From b30e51d38b58edb25cbbc107af9857829373bea9 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sat, 15 Apr 2023 20:18:31 +0800 Subject: [PATCH] resolve groundingdino install problem --- install.py | 5 ++--- requirements.txt | 3 +-- scripts/dino.py | 32 ++++++++++++++++++++++------ scripts/sam.py | 55 ++++++++++++++++++++++++++++++------------------ 4 files changed, 62 insertions(+), 33 deletions(-) diff --git a/install.py b/install.py index 984e858..5988f03 100644 --- a/install.py +++ b/install.py @@ -8,7 +8,6 @@ with open(req_file) as file: for lib in file: lib = lib.strip() if not launch.is_installed(lib): - if lib == "groundingdino": - lib = "git+https://github.com/IDEA-Research/GroundingDINO" launch.run_pip( - f"install {lib}", f"sd-webui-segment-anything requirement: {lib}") + f"install {lib}", + f"sd-webui-segment-anything requirement: {lib}") diff --git a/requirements.txt b/requirements.txt index 9417c1e..4032d22 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ -segment_anything -groundingdino \ No newline at end of file +segment_anything \ No newline at end of file diff --git a/scripts/dino.py b/scripts/dino.py index 6b33c7b..afe39a3 100644 --- a/scripts/dino.py +++ b/scripts/dino.py @@ -8,12 +8,6 @@ from collections import OrderedDict from modules import scripts, shared from modules.devices import device, torch_gc, cpu -# Grounding DINO -import groundingdino.datasets.transforms as T -from groundingdino.models import build_model -from groundingdino.util.slconfig import SLConfig -from groundingdino.util.utils import clean_state_dict - dino_model_cache = OrderedDict() dino_model_dir = os.path.join(scripts.basedir(), "models/grounding-dino") @@ -33,6 +27,23 @@ dino_model_info = { } +def install_goundingdino(): + import launch + if launch.is_installed("groundingdino"): + return True + try: + launch.run_pip( + f"install git+https://github.com/IDEA-Research/GroundingDINO", + f"sd-webui-segment-anything requirement: groundingdino") + print("GroundingDINO install success.") + return True + except Exception: + import traceback + print(traceback.print_exc()) + print("GroundingDINO install failed. Submit an issue to https://github.com/continue-revolution/sd-webui-segment-anything/issues.") + return False + + def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=False): if boxes is None: return image_np @@ -64,6 +75,9 @@ def load_dino_model(dino_checkpoint): dino.to(device=device) else: clear_dino_cache() + from groundingdino.models import build_model + from groundingdino.util.slconfig import SLConfig + from groundingdino.util.utils import clean_state_dict args = SLConfig.fromfile(dino_model_info[dino_checkpoint]["config"]) dino = build_model(args) checkpoint = torch.hub.load_state_dict_from_url( @@ -77,6 +91,7 @@ def load_dino_model(dino_checkpoint): def load_dino_image(image_pil): + import groundingdino.datasets.transforms as T transform = T.Compose( [ T.RandomResize([800], max_size=1333), @@ -112,6 +127,9 @@ def get_grounding_output(model, image, caption, box_threshold): def dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold): + install_success = install_goundingdino() + if not install_success: + return None, False print("Running GroundingDINO Inference") dino_image = load_dino_image(input_image.convert("RGB")) dino_model = load_dino_model(dino_model_name) @@ -127,4 +145,4 @@ def dino_predict_internal(input_image, dino_model_name, text_prompt, box_thresho boxes_filt[i][2:] += boxes_filt[i][:2] gc.collect() torch_gc() - return boxes_filt + return boxes_filt, True diff --git a/scripts/sam.py b/scripts/sam.py index 0154040..da745fa 100644 --- a/scripts/sam.py +++ b/scripts/sam.py @@ -17,6 +17,7 @@ from modules.paths import models_path from segment_anything import SamPredictor, sam_model_registry from scripts.dino import dino_model_list, dino_predict_internal, show_boxes, clear_dino_cache + sam_model_cache = OrderedDict() scripts_sam_model_dir = os.path.join(scripts.basedir(), "models/sam") sd_sam_model_dir = os.path.join(models_path, "sam") @@ -132,11 +133,18 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points, image_np = np.array(input_image) image_np_rgb = image_np[..., :3] dino_enabled = dino_checkbox and text_prompt is not None - - boxes_filt = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold) if dino_enabled else None - if dino_enabled and dino_preview_checkbox is not None and dino_preview_checkbox and dino_preview_boxes_selection is not None: - valid_indices = [int(i) for i in dino_preview_boxes_selection if int(i) < boxes_filt.shape[0]] - boxes_filt = boxes_filt[valid_indices] + boxes_filt = None + sam_predict_result = " done." + if dino_enabled: + boxes_filt, install_success = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold) + if install_success and dino_preview_checkbox is not None and dino_preview_checkbox and dino_preview_boxes_selection is not None: + valid_indices = [int(i) for i in dino_preview_boxes_selection if int(i) < boxes_filt.shape[0]] + boxes_filt = boxes_filt[valid_indices] + if not install_success: + if len(positive_points) == 0 and len(negative_points) == 0: + return [], "GroundingDINO installment has failed. Check your terminal for more detail and submit an issue to https://github.com/continue-revolution/sd-webui-segment-anything/issues." + else: + sam_predict_result += " However, GroundingDINO installment has failed. Your process automatically fall back to point prompt only. Check your terminal for more detail and submit an issue to https://github.com/continue-revolution/sd-webui-segment-anything/issues." sam = init_sam_model(sam_model_name) @@ -145,7 +153,8 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points, predictor.set_image(image_np_rgb) if dino_enabled and boxes_filt.shape[0] > 1: - print(f"SAM inference with {boxes_filt.shape[0]} boxes, point prompts disgarded.") + sam_predict_status = f"SAM inference with {boxes_filt.shape[0]} boxes, point prompts disgarded" + print(sam_predict_status) transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2]) masks, _, _ = predictor.predict_torch( point_coords=None, @@ -156,7 +165,8 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points, masks = masks.permute(1, 0, 2, 3).cpu().numpy() else: - print(f"SAM inference with {0 if boxes_filt is None else boxes_filt.shape[0]} box, {len(positive_points)} positive prompts, {len(negative_points)} negative prompts.") + sam_predict_status = f"SAM inference with {0 if boxes_filt is None else boxes_filt.shape[0]} box, {len(positive_points)} positive prompts, {len(negative_points)} negative prompts" + print(sam_predict_status) point_coords = np.array(positive_points + negative_points) point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points)) @@ -190,14 +200,17 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points, image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0]) matted_images.append(Image.fromarray(image_np_copy)) - return mask_images + masks_gallery + matted_images + return mask_images + masks_gallery + matted_images, sam_predict_status + sam_predict_result def dino_predict(input_image, dino_model_name, text_prompt, box_threshold): image_np = np.array(input_image) - boxes_filt = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold).numpy() + boxes_filt, install_success = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold) + if not install_success: + return None, gr.update(), gr.update(visible=True, value="GroundingDINO installment failed. Preview failed. See your terminal for more detail and submit an issue to https://github.com/continue-revolution/sd-webui-segment-anything/issues.") + boxes_filt = boxes_filt.numpy() boxes_choice = [str(i) for i in range(boxes_filt.shape[0])] - return Image.fromarray(show_boxes(image_np, boxes_filt.astype(int), show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice) + return Image.fromarray(show_boxes(image_np, boxes_filt.astype(int), show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice), gr.update(visible=False) def dino_batch_process( batch_sam_model_name, batch_dino_model_name, batch_text_prompt, batch_box_threshold, batch_dilation_amt, @@ -216,7 +229,9 @@ def dino_batch_process( image_np = np.array(input_image) image_np_rgb = image_np[..., :3] - boxes_filt = dino_predict_internal(input_image, batch_dino_model_name, batch_text_prompt, batch_box_threshold) + boxes_filt, install_success = dino_predict_internal(input_image, batch_dino_model_name, batch_text_prompt, batch_box_threshold) + if not install_success: + return "GroundingDINO installment failed. Batch processing failed. See your terminal for more detail and submit an issue to https://github.com/continue-revolution/sd-webui-segment-anything/issues." predictor.set_image(image_np_rgb) transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2]) @@ -301,20 +316,21 @@ class Script(scripts.Script): dino_preview_boxes_button = gr.Button(value="Generate bounding box", elem_id="dino_run_button") dino_preview_boxes_selection = gr.CheckboxGroup(label="Select your favorite boxes: ", elem_id="dino_preview_boxes_selection") dino_preview_boxes_selection.change(fn=lambda _: None, inputs=[dino_preview_boxes_selection], outputs=None, _js="onChangeDinoPreviewBoxesSelection") + dino_preview_result = gr.Text(value="", show_label=False, visible=False) dino_preview_boxes_button.click( fn=dino_predict, _js="submit_dino", inputs=[input_image, dino_model_name, text_prompt, box_threshold], - outputs=[dino_preview_boxes, dino_preview_boxes_selection] - ) + outputs=[dino_preview_boxes, dino_preview_boxes_selection, dino_preview_result]) mask_image = gr.Gallery(label='Segment Anything Output', show_label=False, elem_id='sam_gallery').style(grid=3) with gr.Row(elem_id="sam_generate_box", elem_classes="generate-box"): gr.Button(value="Add dot prompt or enable GroundingDINO with text prompts to preview segmentation", elem_id="sam_no_button") run_button = gr.Button(value="Preview Segmentation", elem_id="sam_run_button") - + run_result = gr.Text(value="", show_label=False) + gr.Checkbox(value=False, label="Preview automatically when add/remove points", elem_id="sam_realtime_preview_checkbox") with gr.Row(): @@ -376,7 +392,7 @@ class Script(scripts.Script): dummy_component, dummy_component, # Point prompts dino_checkbox, dino_model_name, text_prompt, box_threshold, # DINO prompts dino_preview_checkbox, dino_preview_boxes_selection], # DINO preview prompts - outputs=[mask_image],) + outputs=[mask_image, run_result]) dino_checkbox.change( fn=gr_show, @@ -394,21 +410,18 @@ class Script(scripts.Script): fn=lambda _: None, _js="switchToInpaintUpload", inputs=[dummy_component], - outputs=None - ) + outputs=None) remove_dots.click( fn=lambda _: None, _js="removeDots", inputs=[dummy_component], - outputs=None - ) + outputs=None) unload.click( fn=clear_cache, inputs=[], - outputs=[] - ) + outputs=[]) dilation_checkbox.change( fn=gr_show,