fix bounding box draw bug
parent
370663a13a
commit
520e7fa74f
|
|
@ -40,7 +40,7 @@ def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=
|
||||||
image = copy.deepcopy(image_np)
|
image = copy.deepcopy(image_np)
|
||||||
for idx, box in enumerate(boxes):
|
for idx, box in enumerate(boxes):
|
||||||
x, y, w, h = box
|
x, y, w, h = box
|
||||||
cv2.rectangle(image, (x, y), (x+w, y+h), color, thickness)
|
cv2.rectangle(image, (x, y), (w, h), color, thickness)
|
||||||
if show_index:
|
if show_index:
|
||||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||||
text = str(idx)
|
text = str(idx)
|
||||||
|
|
|
||||||
|
|
@ -146,22 +146,21 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
|
||||||
|
|
||||||
if dino_enabled and boxes_filt.shape[0] > 1:
|
if dino_enabled and boxes_filt.shape[0] > 1:
|
||||||
print(f"SAM inference with {boxes_filt.shape[0]} boxes, point prompts disgarded.")
|
print(f"SAM inference with {boxes_filt.shape[0]} boxes, point prompts disgarded.")
|
||||||
boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
|
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
|
||||||
masks, _, _ = predictor.predict_torch(
|
masks, _, _ = predictor.predict_torch(
|
||||||
point_coords=None,
|
point_coords=None,
|
||||||
point_labels=None,
|
point_labels=None,
|
||||||
boxes=boxes.to(device),
|
boxes=transformed_boxes.to(device),
|
||||||
multimask_output=True,
|
multimask_output=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
masks = masks.permute(1, 0, 2, 3).cpu().numpy()
|
masks = masks.permute(1, 0, 2, 3).cpu().numpy()
|
||||||
boxes = boxes.cpu().numpy().astype(int)
|
|
||||||
else:
|
else:
|
||||||
print(f"SAM inference with {0 if boxes_filt is None else boxes_filt.shape[0]} box, {len(positive_points)} positive prompts, {len(negative_points)} negative prompts.")
|
print(f"SAM inference with {0 if boxes_filt is None else boxes_filt.shape[0]} box, {len(positive_points)} positive prompts, {len(negative_points)} negative prompts.")
|
||||||
point_coords = np.array(positive_points + negative_points)
|
point_coords = np.array(positive_points + negative_points)
|
||||||
point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
|
point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
|
||||||
|
|
||||||
box = boxes_filt[0].numpy() if boxes_filt is not None and boxes_filt.shape[0] > 0 else None
|
box = copy.deepcopy(boxes_filt[0].numpy()) if boxes_filt is not None and boxes_filt.shape[0] > 0 else None
|
||||||
masks, _, _ = predictor.predict(
|
masks, _, _ = predictor.predict(
|
||||||
point_coords=point_coords if len(point_coords) > 0 else None,
|
point_coords=point_coords if len(point_coords) > 0 else None,
|
||||||
point_labels=point_labels if len(point_coords) > 0 else None,
|
point_labels=point_labels if len(point_coords) > 0 else None,
|
||||||
|
|
@ -170,7 +169,6 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
|
||||||
)
|
)
|
||||||
|
|
||||||
masks = masks[:, None, ...]
|
masks = masks[:, None, ...]
|
||||||
boxes = predictor.transform.apply_boxes(box, image_np.shape[:2]).astype(int) if box is not None else None
|
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram:
|
if shared.cmd_opts.lowvram:
|
||||||
sam.to(cpu)
|
sam.to(cpu)
|
||||||
|
|
@ -182,8 +180,9 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
|
||||||
masks_gallery = []
|
masks_gallery = []
|
||||||
matted_images = []
|
matted_images = []
|
||||||
|
|
||||||
|
boxes_filt = boxes_filt.numpy().astype(int) if boxes_filt is not None else None
|
||||||
for mask in masks:
|
for mask in masks:
|
||||||
blended_image = show_masks(show_boxes(image_np, boxes), mask)
|
blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
|
||||||
masks_gallery.append(Image.fromarray(np.any(mask, axis=0)))
|
masks_gallery.append(Image.fromarray(np.any(mask, axis=0)))
|
||||||
mask_images.append(Image.fromarray(blended_image))
|
mask_images.append(Image.fromarray(blended_image))
|
||||||
image_np_copy = copy.deepcopy(image_np)
|
image_np_copy = copy.deepcopy(image_np)
|
||||||
|
|
@ -193,26 +192,11 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
|
||||||
return mask_images + masks_gallery + matted_images
|
return mask_images + masks_gallery + matted_images
|
||||||
|
|
||||||
|
|
||||||
def dino_predict(sam_model_name, input_image, dino_model_name, text_prompt, box_threshold):
|
def dino_predict(input_image, dino_model_name, text_prompt, box_threshold):
|
||||||
image_np = np.array(input_image)
|
image_np = np.array(input_image)
|
||||||
image_np_rgb = image_np[..., :3]
|
boxes_filt = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold).numpy()
|
||||||
boxes_filt = dino_predict_internal(
|
|
||||||
input_image, dino_model_name, text_prompt, box_threshold).numpy()
|
|
||||||
|
|
||||||
sam = init_sam_model(sam_model_name)
|
|
||||||
|
|
||||||
predictor = SamPredictor(sam)
|
|
||||||
predictor.set_image(image_np_rgb)
|
|
||||||
boxes_filt = predictor.transform.apply_boxes(
|
|
||||||
boxes_filt, image_np.shape[:2]).astype(int)
|
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram:
|
|
||||||
sam.to(cpu)
|
|
||||||
gc.collect()
|
|
||||||
torch_gc()
|
|
||||||
|
|
||||||
boxes_choice = [str(i) for i in range(boxes_filt.shape[0])]
|
boxes_choice = [str(i) for i in range(boxes_filt.shape[0])]
|
||||||
return Image.fromarray(show_boxes(image_np, boxes_filt, show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice)
|
return Image.fromarray(show_boxes(image_np, boxes_filt.astype(int), show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice)
|
||||||
|
|
||||||
def dino_batch_process(
|
def dino_batch_process(
|
||||||
batch_sam_model_name, batch_dino_model_name, batch_text_prompt, batch_box_threshold, batch_dilation_amt,
|
batch_sam_model_name, batch_dino_model_name, batch_text_prompt, batch_box_threshold, batch_dilation_amt,
|
||||||
|
|
@ -234,21 +218,21 @@ def dino_batch_process(
|
||||||
boxes_filt = dino_predict_internal(input_image, batch_dino_model_name, batch_text_prompt, batch_box_threshold)
|
boxes_filt = dino_predict_internal(input_image, batch_dino_model_name, batch_text_prompt, batch_box_threshold)
|
||||||
|
|
||||||
predictor.set_image(image_np_rgb)
|
predictor.set_image(image_np_rgb)
|
||||||
boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
|
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
|
||||||
masks, _, _ = predictor.predict_torch(
|
masks, _, _ = predictor.predict_torch(
|
||||||
point_coords=None,
|
point_coords=None,
|
||||||
point_labels=None,
|
point_labels=None,
|
||||||
boxes=boxes.to(device),
|
boxes=transformed_boxes.to(device),
|
||||||
multimask_output=(dino_batch_output_per_image == 1),
|
multimask_output=(dino_batch_output_per_image == 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
masks = masks.permute(1, 0, 2, 3).cpu().numpy()
|
masks = masks.permute(1, 0, 2, 3).cpu().numpy()
|
||||||
boxes = boxes.cpu().numpy().astype(int)
|
boxes_filt = boxes_filt.cpu().numpy().astype(int)
|
||||||
|
|
||||||
filename, ext = os.path.splitext(os.path.basename(input_image_file))
|
filename, ext = os.path.splitext(os.path.basename(input_image_file))
|
||||||
|
|
||||||
for idx, mask in enumerate(masks):
|
for idx, mask in enumerate(masks):
|
||||||
blended_image = show_masks(show_boxes(image_np, boxes), mask)
|
blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
|
||||||
_, merged_mask = dilate_mask(np.any(mask, axis=0), batch_dilation_amt)
|
_, merged_mask = dilate_mask(np.any(mask, axis=0), batch_dilation_amt)
|
||||||
image_np_copy = copy.deepcopy(image_np)
|
image_np_copy = copy.deepcopy(image_np)
|
||||||
image_np_copy[~merged_mask] = np.array([0, 0, 0, 0])
|
image_np_copy[~merged_mask] = np.array([0, 0, 0, 0])
|
||||||
|
|
@ -318,7 +302,7 @@ class Script(scripts.Script):
|
||||||
dino_preview_boxes_button.click(
|
dino_preview_boxes_button.click(
|
||||||
fn=dino_predict,
|
fn=dino_predict,
|
||||||
_js="submit_dino",
|
_js="submit_dino",
|
||||||
inputs=[sam_model_name, input_image, dino_model_name, text_prompt, box_threshold],
|
inputs=[input_image, dino_model_name, text_prompt, box_threshold],
|
||||||
outputs=[dino_preview_boxes, dino_preview_boxes_selection]
|
outputs=[dino_preview_boxes, dino_preview_boxes_selection]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue