From d168dca02b25fd8eb57ed91dd8d3744ad706cbaa Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Wed, 19 Apr 2023 02:03:41 +0800 Subject: [PATCH] get rid of annoying js change button logic --- javascript/sam.js | 76 ++--------------------------------------------- scripts/sam.py | 27 +++++++++-------- style.css | 6 ---- 3 files changed, 16 insertions(+), 93 deletions(-) diff --git a/javascript/sam.js b/javascript/sam.js index 9b0a606..0c9a6e6 100644 --- a/javascript/sam.js +++ b/javascript/sam.js @@ -16,33 +16,6 @@ function samGetRealCoordinate(image, x1, y1) { } } -function samChangeRunButton() { - const sam_run_button = gradioApp().getElementById(samTabPrefix() + "run_button"); - const sam_mode = (samHasImageInput() && (samCanSubmit() || (dinoCanSubmit() && dinoPreviewCanSubmit()))) ? "block" : "none"; - if (sam_run_button && sam_run_button.style.display != sam_mode) { - sam_run_button.style.display = sam_mode; - } - - const dino_run_button = gradioApp().getElementById(samTabPrefix() + "dino_run_button"); - const dino_mode = (samHasImageInput() && dinoCanSubmit()) ? "block" : "none"; - if (dino_run_button && dino_run_button.style.display != dino_mode) { - dino_run_button.style.display = dino_mode; - } -} - -function dinoRegisterTextObserver() { - const dino_text_prompt = gradioApp().getElementById(samTabPrefix() + "dino_text_prompt").querySelector("textarea") - const observer = new MutationObserver(mutations => { - mutations.forEach(mutation => { - if (mutation.target === dino_text_prompt) { - samChangeRunButton(); - } - }); - }); - observer.observe(dino_text_prompt, { attributes: true }); - return arguments; -} - function switchToInpaintUpload() { switch_to_img2img_tab(4) return arguments; @@ -95,16 +68,13 @@ function samCreateDot(sam_image, image, coord, label) { circle.addEventListener("click", e => { e.stopPropagation(); circle.remove(); - if (gradioApp().querySelectorAll("." + samTabPrefix() + "positive").length == 0 && - gradioApp().querySelectorAll("." + samTabPrefix() + "negative").length == 0) { - samChangeRunButton(); - } else { + if (gradioApp().querySelectorAll("." + samTabPrefix() + "positive").length != 0 || + gradioApp().querySelectorAll("." + samTabPrefix() + "negative").length != 0) { if (samIsRealTimePreview()) { samImmediatelyGenerate(); } } }); - samChangeRunButton(); if (samIsRealTimePreview()) { samImmediatelyGenerate(); } @@ -172,43 +142,6 @@ function submit_sam() { return res } -function dinoPreviewCanSubmit() { - const dino_preview_enable_checkbox = gradioApp().getElementById(samTabPrefix() + "dino_preview_checkbox"); - if (!dino_preview_enable_checkbox || - (dino_preview_enable_checkbox.querySelector("input") && - !dino_preview_enable_checkbox.querySelector("input").checked)) { - return true; - } else { - let dino_preview_selected = false; - gradioApp().getElementById(samTabPrefix() + "dino_preview_boxes_selection").querySelectorAll("input").forEach(element => dino_preview_selected = element.checked || dino_preview_selected); - return dino_preview_selected; - } -} - -function dinoCanSubmit() { - const dino_enable_checkbox = gradioApp().getElementById(samTabPrefix() + "dino_enable_checkbox") - const dino_text_prompt = gradioApp().getElementById(samTabPrefix() + "dino_text_prompt") - return (dino_enable_checkbox && dino_text_prompt && - dino_enable_checkbox.querySelector("input") && - dino_enable_checkbox.querySelector("input").checked && - dino_text_prompt.querySelector("textarea") && - dino_text_prompt.querySelector("textarea").value != "") -} - -function samCanSubmit() { - return (gradioApp().querySelectorAll("." + samTabPrefix() + "positive").length > 0 || - gradioApp().querySelectorAll("." + samTabPrefix() + "negative").length > 0) -} - -function samHasImageInput() { - const sam_image = gradioApp().getElementById(samTabPrefix() + "input_image") - return sam_image && sam_image.querySelector('img') -} - -function dinoOnChangePreviewBoxesSelection() { - samChangeRunButton() -} - samPrevImg = { "txt2img_sam_": null, "img2img_sam_": null, @@ -219,7 +152,6 @@ onUiUpdate(() => { if (sam_image) { const image = sam_image.querySelector('img') if (image && samPrevImg[samTabPrefix()] != image.src) { - console.log("remove points 1") samRemoveDots(); samPrevImg[samTabPrefix()] = image.src; @@ -243,7 +175,6 @@ onUiUpdate(() => { const observer = new MutationObserver(mutations => { mutations.forEach(mutation => { if (mutation.type === 'attributes' && mutation.attributeName === 'src' && mutation.target === image) { - console.log("remove points 2") samRemoveDots(); samPrevImg[samTabPrefix()] = image.src; } @@ -252,11 +183,8 @@ onUiUpdate(() => { observer.observe(image, { attributes: true }); } else if (!image) { - console.log("remove points 3") samRemoveDots(); samPrevImg[samTabPrefix()] = null; } } - - samChangeRunButton(); }) diff --git a/scripts/sam.py b/scripts/sam.py index f070eb2..a14a735 100644 --- a/scripts/sam.py +++ b/scripts/sam.py @@ -138,6 +138,10 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points, dino_checkbox, dino_model_name, text_prompt, box_threshold, dino_preview_checkbox, dino_preview_boxes_selection, gui=True): print("Start SAM Processing") + if sam_model_name is None: + return [], "SAM model not found. Please download SAM model from extension README." + if input_image is None: + return [], "SAM requires an input image. Please upload an image first." image_np = np.array(input_image) image_np_rgb = image_np[..., :3] dino_enabled = dino_checkbox and text_prompt is not None @@ -177,7 +181,9 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points, num_points = len(positive_points) + len(negative_points) if num_box == 0 and num_points == 0: garbage_collect(sam) - return [], "It seems that you are using a high box threshold with no point prompts. Please lower your box threshold and re-try." + if dino_enabled and num_box == 0: + return [], "It seems that you are using a high box threshold with no point prompts. Please lower your box threshold and re-try." + return [], "You neither added point prompts nor enabled GroundingDINO. Segmentation cannot be generated." sam_predict_status = f"SAM inference with {num_box} box, {len(positive_points)} positive prompts, {len(negative_points)} negative prompts" print(sam_predict_status) point_coords = np.array(positive_points + negative_points) @@ -214,6 +220,10 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points, def dino_predict(input_image, dino_model_name, text_prompt, box_threshold): + if input_image is None: + return None, gr.update(), gr.update(visible=True, value=f"GroundingDINO requires input image.") + if text_prompt is None or text_prompt == "": + return None, gr.update(), gr.update(visible=True, value=f"GroundingDINO requires text prompt.") image_np = np.array(input_image) boxes_filt, install_success = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold) if not install_success: @@ -348,20 +358,13 @@ class Script(scripts.Script): with gr.Column(visible=False) as dino_column: gr.HTML(value="

Due to the limitation of Segment Anything, when there are point prompts, at most 1 box prompt will be allowed; when there are multiple box prompts, no point prompts are allowed.

") dino_model_name = gr.Dropdown(label="GroundingDINO Model (Auto download from huggingface)", choices=dino_model_list, value=dino_model_list[0]) - - text_prompt = gr.Textbox(label="GroundingDINO Detection Prompt", elem_id=f"{tab_prefix}dino_text_prompt") - text_prompt.change(fn=lambda _: None, inputs=[dummy_component], outputs=None, _js="dinoRegisterTextObserver") - + text_prompt = gr.Textbox(placeholder="You must enter text prompts to enable groundingdino. Otherwise this extension will fall back to point prompts only.", label="GroundingDINO Detection Prompt", elem_id=f"{tab_prefix}dino_text_prompt") box_threshold = gr.Slider(label="GroundingDINO Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001) - dino_preview_checkbox = gr.Checkbox(value=False, label="I want to preview GroundingDINO detection result and select the boxes I want.", elem_id=f"{tab_prefix}dino_preview_checkbox") with gr.Column(visible=False) as dino_preview: dino_preview_boxes = gr.Image(label="Image for GroundingDINO", show_label=False, type="pil", image_mode="RGBA") - with gr.Row(elem_classes="generate-box"): - gr.Button(value="Add text prompt to generate bounding box", elem_id=f"{tab_prefix}dino_no_button") - dino_preview_boxes_button = gr.Button(value="Generate bounding box", elem_id=f"{tab_prefix}dino_run_button") + dino_preview_boxes_button = gr.Button(value="Generate bounding box", elem_id=f"{tab_prefix}dino_run_button") dino_preview_boxes_selection = gr.CheckboxGroup(label="Select your favorite boxes: ", elem_id=f"{tab_prefix}dino_preview_boxes_selection") - dino_preview_boxes_selection.change(fn=lambda _: None, inputs=[dino_preview_boxes_selection], outputs=None, _js="dinoOnChangePreviewBoxesSelection") dino_preview_result = gr.Text(value="", show_label=False, visible=False) dino_preview_boxes_button.click( @@ -372,9 +375,7 @@ class Script(scripts.Script): mask_image = gr.Gallery(label='Segment Anything Output', show_label=False).style(grid=3) - with gr.Row(elem_classes="generate-box"): - gr.Button(value="Add dot prompt or enable GroundingDINO with text prompts to preview segmentation", elem_id=f"{tab_prefix}sam_no_button") - run_button = gr.Button(value="Preview Segmentation", elem_id=f"{tab_prefix}run_button") + run_button = gr.Button(value="Preview Segmentation", elem_id=f"{tab_prefix}run_button") run_result = gr.Text(value="", show_label=False) gr.Checkbox(value=False, label="Preview automatically when add/remove points", elem_id=f"{tab_prefix}realtime_preview_checkbox") diff --git a/style.css b/style.css index 0036089..9072ec1 100644 --- a/style.css +++ b/style.css @@ -1,9 +1,3 @@ -#txt2img_sam_run_button, #img2img_sam_run_button, #txt2img_sam_dino_run_button, #img2img_sam_dino_run_button { - position: absolute; - display: none; - width: 100%; -} - #txt2img_sam_run_button:hover, #img2img_sam_run_button:hover, #txt2img_sam_dino_run_button:hover, #img2img_sam_dino_run_button:hover { background: #b4c0cc; }