From 520e7fa74f2c6cb05b4fc84d4e8004c30bf87c97 Mon Sep 17 00:00:00 2001
From: Chengsong Zhang
Date: Sat, 15 Apr 2023 04:59:44 +0800
Subject: [PATCH] fix bounding box draw bug

---
 scripts/dino.py |  2 +-
 scripts/sam.py  | 44 ++++++++++++++------------------------------
 2 files changed, 15 insertions(+), 31 deletions(-)

diff --git a/scripts/dino.py b/scripts/dino.py
index cece84e..6b33c7b 100644
--- a/scripts/dino.py
+++ b/scripts/dino.py
@@ -40,7 +40,7 @@ def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=
     image = copy.deepcopy(image_np)
     for idx, box in enumerate(boxes):
         x, y, w, h = box
-        cv2.rectangle(image, (x, y), (x+w, y+h), color, thickness)
+        cv2.rectangle(image, (x, y), (w, h), color, thickness)
         if show_index:
             font = cv2.FONT_HERSHEY_SIMPLEX
             text = str(idx)
diff --git a/scripts/sam.py b/scripts/sam.py
index 2d08433..2090bea 100644
--- a/scripts/sam.py
+++ b/scripts/sam.py
@@ -146,22 +146,21 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
     if dino_enabled and boxes_filt.shape[0] > 1:
         print(f"SAM inference with {boxes_filt.shape[0]} boxes, point prompts disgarded.")
-        boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
+        transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
         masks, _, _ = predictor.predict_torch(
             point_coords=None,
             point_labels=None,
-            boxes=boxes.to(device),
+            boxes=transformed_boxes.to(device),
             multimask_output=True,
         )
         masks = masks.permute(1, 0, 2, 3).cpu().numpy()
-        boxes = boxes.cpu().numpy().astype(int)
     else:
         print(f"SAM inference with {0 if boxes_filt is None else boxes_filt.shape[0]} box, {len(positive_points)} positive prompts, {len(negative_points)} negative prompts.")
         point_coords = np.array(positive_points + negative_points)
         point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
-        box = boxes_filt[0].numpy() if boxes_filt is not None and boxes_filt.shape[0] > 0 else None
+        box = copy.deepcopy(boxes_filt[0].numpy()) if boxes_filt is not None and boxes_filt.shape[0] > 0 else None
         masks, _, _ = predictor.predict(
             point_coords=point_coords if len(point_coords) > 0 else None,
             point_labels=point_labels if len(point_coords) > 0 else None,
@@ -170,7 +169,6 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
         )
         masks = masks[:, None, ...]
-        boxes = predictor.transform.apply_boxes(box, image_np.shape[:2]).astype(int) if box is not None else None
     if shared.cmd_opts.lowvram:
         sam.to(cpu)
         gc.collect()
         torch_gc()
@@ -181,9 +179,10 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
     mask_images = []
     masks_gallery = []
     matted_images = []
-    
+
+    boxes_filt = boxes_filt.numpy().astype(int) if boxes_filt is not None else None
     for mask in masks:
-        blended_image = show_masks(show_boxes(image_np, boxes), mask)
+        blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
         masks_gallery.append(Image.fromarray(np.any(mask, axis=0)))
         mask_images.append(Image.fromarray(blended_image))
         image_np_copy = copy.deepcopy(image_np)
@@ -193,26 +192,11 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
     return mask_images + masks_gallery + matted_images
 
 
-def dino_predict(sam_model_name, input_image, dino_model_name, text_prompt, box_threshold):
+def dino_predict(input_image, dino_model_name, text_prompt, box_threshold):
     image_np = np.array(input_image)
-    image_np_rgb = image_np[..., :3]
-    boxes_filt = dino_predict_internal(
-        input_image, dino_model_name, text_prompt, box_threshold).numpy()
-
-    sam = init_sam_model(sam_model_name)
-
-    predictor = SamPredictor(sam)
-    predictor.set_image(image_np_rgb)
-    boxes_filt = predictor.transform.apply_boxes(
-        boxes_filt, image_np.shape[:2]).astype(int)
-
-    if shared.cmd_opts.lowvram:
-        sam.to(cpu)
-        gc.collect()
-        torch_gc()
-
+    boxes_filt = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold).numpy()
     boxes_choice = [str(i) for i in range(boxes_filt.shape[0])]
-    return Image.fromarray(show_boxes(image_np, boxes_filt, show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice)
+    return Image.fromarray(show_boxes(image_np, boxes_filt.astype(int), show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice)
 
 def dino_batch_process(
     batch_sam_model_name, batch_dino_model_name, batch_text_prompt, batch_box_threshold, batch_dilation_amt,
@@ -234,21 +218,21 @@ def dino_batch_process(
         boxes_filt = dino_predict_internal(input_image, batch_dino_model_name, batch_text_prompt, batch_box_threshold)
         predictor.set_image(image_np_rgb)
-        boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
+        transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
         masks, _, _ = predictor.predict_torch(
             point_coords=None,
             point_labels=None,
-            boxes=boxes.to(device),
+            boxes=transformed_boxes.to(device),
             multimask_output=(dino_batch_output_per_image == 1),
         )
         masks = masks.permute(1, 0, 2, 3).cpu().numpy()
-        boxes = boxes.cpu().numpy().astype(int)
+        boxes_filt = boxes_filt.cpu().numpy().astype(int)
 
         filename, ext = os.path.splitext(os.path.basename(input_image_file))
         for idx, mask in enumerate(masks):
-            blended_image = show_masks(show_boxes(image_np, boxes), mask)
+            blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
             _, merged_mask = dilate_mask(np.any(mask, axis=0), batch_dilation_amt)
             image_np_copy = copy.deepcopy(image_np)
             image_np_copy[~merged_mask] = np.array([0, 0, 0, 0])
@@ -318,7 +302,7 @@ class Script(scripts.Script):
         dino_preview_boxes_button.click(
             fn=dino_predict,
             _js="submit_dino",
-            inputs=[sam_model_name, input_image, dino_model_name, text_prompt, box_threshold],
+            inputs=[input_image, dino_model_name, text_prompt, box_threshold],
             outputs=[dino_preview_boxes, dino_preview_boxes_selection]
         )
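
For reference, below is a minimal standalone sketch of the drawing behaviour the patch relies on: boxes from dino_predict_internal are treated as pixel-space (x1, y1, x2, y2) corners of the original image and are drawn directly with cv2.rectangle, with no predictor.transform.apply_boxes call beforehand. The function mirrors show_boxes in scripts/dino.py, but the snippet (including the toy image and box values) is illustrative only, not part of the patch.

```python
# Illustrative sketch, not part of the patch: draw xyxy boxes on the original
# image without any SAM coordinate transform.
import copy

import cv2
import numpy as np


def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=False):
    if boxes is None:
        return image_np
    image = copy.deepcopy(image_np)
    for idx, box in enumerate(boxes):
        # Corner coordinates (x1, y1, x2, y2), not x/y/width/height.
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
        if show_index:
            cv2.putText(image, str(idx), (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, thickness)
    return image


if __name__ == "__main__":
    img = np.zeros((256, 256, 4), dtype=np.uint8)          # blank RGBA canvas
    boxes = np.array([[32, 32, 128, 160]], dtype=int)       # one box in pixel xyxy form
    out = show_boxes(img, boxes, show_index=True)
    assert out[32, 32, 0] == 255                             # box outline lands at its corner
```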