resolve groundingdino install problem
parent
dfc14911a3
commit
b30e51d38b
|
|
@ -8,7 +8,6 @@ with open(req_file) as file:
|
||||||
for lib in file:
|
for lib in file:
|
||||||
lib = lib.strip()
|
lib = lib.strip()
|
||||||
if not launch.is_installed(lib):
|
if not launch.is_installed(lib):
|
||||||
if lib == "groundingdino":
|
|
||||||
lib = "git+https://github.com/IDEA-Research/GroundingDINO"
|
|
||||||
launch.run_pip(
|
launch.run_pip(
|
||||||
f"install {lib}", f"sd-webui-segment-anything requirement: {lib}")
|
f"install {lib}",
|
||||||
|
f"sd-webui-segment-anything requirement: {lib}")
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1 @@
|
||||||
segment_anything
|
segment_anything
|
||||||
groundingdino
|
|
||||||
|
|
@ -8,12 +8,6 @@ from collections import OrderedDict
|
||||||
from modules import scripts, shared
|
from modules import scripts, shared
|
||||||
from modules.devices import device, torch_gc, cpu
|
from modules.devices import device, torch_gc, cpu
|
||||||
|
|
||||||
# Grounding DINO
|
|
||||||
import groundingdino.datasets.transforms as T
|
|
||||||
from groundingdino.models import build_model
|
|
||||||
from groundingdino.util.slconfig import SLConfig
|
|
||||||
from groundingdino.util.utils import clean_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
dino_model_cache = OrderedDict()
|
dino_model_cache = OrderedDict()
|
||||||
dino_model_dir = os.path.join(scripts.basedir(), "models/grounding-dino")
|
dino_model_dir = os.path.join(scripts.basedir(), "models/grounding-dino")
|
||||||
|
|
@ -33,6 +27,23 @@ dino_model_info = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def install_goundingdino():
|
||||||
|
import launch
|
||||||
|
if launch.is_installed("groundingdino"):
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
launch.run_pip(
|
||||||
|
f"install git+https://github.com/IDEA-Research/GroundingDINO",
|
||||||
|
f"sd-webui-segment-anything requirement: groundingdino")
|
||||||
|
print("GroundingDINO install success.")
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
import traceback
|
||||||
|
print(traceback.print_exc())
|
||||||
|
print("GroundingDINO install failed. Submit an issue to https://github.com/continue-revolution/sd-webui-segment-anything/issues.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=False):
|
def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=False):
|
||||||
if boxes is None:
|
if boxes is None:
|
||||||
return image_np
|
return image_np
|
||||||
|
|
@ -64,6 +75,9 @@ def load_dino_model(dino_checkpoint):
|
||||||
dino.to(device=device)
|
dino.to(device=device)
|
||||||
else:
|
else:
|
||||||
clear_dino_cache()
|
clear_dino_cache()
|
||||||
|
from groundingdino.models import build_model
|
||||||
|
from groundingdino.util.slconfig import SLConfig
|
||||||
|
from groundingdino.util.utils import clean_state_dict
|
||||||
args = SLConfig.fromfile(dino_model_info[dino_checkpoint]["config"])
|
args = SLConfig.fromfile(dino_model_info[dino_checkpoint]["config"])
|
||||||
dino = build_model(args)
|
dino = build_model(args)
|
||||||
checkpoint = torch.hub.load_state_dict_from_url(
|
checkpoint = torch.hub.load_state_dict_from_url(
|
||||||
|
|
@ -77,6 +91,7 @@ def load_dino_model(dino_checkpoint):
|
||||||
|
|
||||||
|
|
||||||
def load_dino_image(image_pil):
|
def load_dino_image(image_pil):
|
||||||
|
import groundingdino.datasets.transforms as T
|
||||||
transform = T.Compose(
|
transform = T.Compose(
|
||||||
[
|
[
|
||||||
T.RandomResize([800], max_size=1333),
|
T.RandomResize([800], max_size=1333),
|
||||||
|
|
@ -112,6 +127,9 @@ def get_grounding_output(model, image, caption, box_threshold):
|
||||||
|
|
||||||
|
|
||||||
def dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold):
|
def dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold):
|
||||||
|
install_success = install_goundingdino()
|
||||||
|
if not install_success:
|
||||||
|
return None, False
|
||||||
print("Running GroundingDINO Inference")
|
print("Running GroundingDINO Inference")
|
||||||
dino_image = load_dino_image(input_image.convert("RGB"))
|
dino_image = load_dino_image(input_image.convert("RGB"))
|
||||||
dino_model = load_dino_model(dino_model_name)
|
dino_model = load_dino_model(dino_model_name)
|
||||||
|
|
@ -127,4 +145,4 @@ def dino_predict_internal(input_image, dino_model_name, text_prompt, box_thresho
|
||||||
boxes_filt[i][2:] += boxes_filt[i][:2]
|
boxes_filt[i][2:] += boxes_filt[i][:2]
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch_gc()
|
torch_gc()
|
||||||
return boxes_filt
|
return boxes_filt, True
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ from modules.paths import models_path
|
||||||
from segment_anything import SamPredictor, sam_model_registry
|
from segment_anything import SamPredictor, sam_model_registry
|
||||||
from scripts.dino import dino_model_list, dino_predict_internal, show_boxes, clear_dino_cache
|
from scripts.dino import dino_model_list, dino_predict_internal, show_boxes, clear_dino_cache
|
||||||
|
|
||||||
|
|
||||||
sam_model_cache = OrderedDict()
|
sam_model_cache = OrderedDict()
|
||||||
scripts_sam_model_dir = os.path.join(scripts.basedir(), "models/sam")
|
scripts_sam_model_dir = os.path.join(scripts.basedir(), "models/sam")
|
||||||
sd_sam_model_dir = os.path.join(models_path, "sam")
|
sd_sam_model_dir = os.path.join(models_path, "sam")
|
||||||
|
|
@ -132,11 +133,18 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
|
||||||
image_np = np.array(input_image)
|
image_np = np.array(input_image)
|
||||||
image_np_rgb = image_np[..., :3]
|
image_np_rgb = image_np[..., :3]
|
||||||
dino_enabled = dino_checkbox and text_prompt is not None
|
dino_enabled = dino_checkbox and text_prompt is not None
|
||||||
|
boxes_filt = None
|
||||||
boxes_filt = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold) if dino_enabled else None
|
sam_predict_result = " done."
|
||||||
if dino_enabled and dino_preview_checkbox is not None and dino_preview_checkbox and dino_preview_boxes_selection is not None:
|
if dino_enabled:
|
||||||
valid_indices = [int(i) for i in dino_preview_boxes_selection if int(i) < boxes_filt.shape[0]]
|
boxes_filt, install_success = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold)
|
||||||
boxes_filt = boxes_filt[valid_indices]
|
if install_success and dino_preview_checkbox is not None and dino_preview_checkbox and dino_preview_boxes_selection is not None:
|
||||||
|
valid_indices = [int(i) for i in dino_preview_boxes_selection if int(i) < boxes_filt.shape[0]]
|
||||||
|
boxes_filt = boxes_filt[valid_indices]
|
||||||
|
if not install_success:
|
||||||
|
if len(positive_points) == 0 and len(negative_points) == 0:
|
||||||
|
return [], "GroundingDINO installment has failed. Check your terminal for more detail and submit an issue to https://github.com/continue-revolution/sd-webui-segment-anything/issues."
|
||||||
|
else:
|
||||||
|
sam_predict_result += " However, GroundingDINO installment has failed. Your process automatically fall back to point prompt only. Check your terminal for more detail and submit an issue to https://github.com/continue-revolution/sd-webui-segment-anything/issues."
|
||||||
|
|
||||||
sam = init_sam_model(sam_model_name)
|
sam = init_sam_model(sam_model_name)
|
||||||
|
|
||||||
|
|
@ -145,7 +153,8 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
|
||||||
predictor.set_image(image_np_rgb)
|
predictor.set_image(image_np_rgb)
|
||||||
|
|
||||||
if dino_enabled and boxes_filt.shape[0] > 1:
|
if dino_enabled and boxes_filt.shape[0] > 1:
|
||||||
print(f"SAM inference with {boxes_filt.shape[0]} boxes, point prompts disgarded.")
|
sam_predict_status = f"SAM inference with {boxes_filt.shape[0]} boxes, point prompts disgarded"
|
||||||
|
print(sam_predict_status)
|
||||||
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
|
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
|
||||||
masks, _, _ = predictor.predict_torch(
|
masks, _, _ = predictor.predict_torch(
|
||||||
point_coords=None,
|
point_coords=None,
|
||||||
|
|
@ -156,7 +165,8 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
|
||||||
|
|
||||||
masks = masks.permute(1, 0, 2, 3).cpu().numpy()
|
masks = masks.permute(1, 0, 2, 3).cpu().numpy()
|
||||||
else:
|
else:
|
||||||
print(f"SAM inference with {0 if boxes_filt is None else boxes_filt.shape[0]} box, {len(positive_points)} positive prompts, {len(negative_points)} negative prompts.")
|
sam_predict_status = f"SAM inference with {0 if boxes_filt is None else boxes_filt.shape[0]} box, {len(positive_points)} positive prompts, {len(negative_points)} negative prompts"
|
||||||
|
print(sam_predict_status)
|
||||||
point_coords = np.array(positive_points + negative_points)
|
point_coords = np.array(positive_points + negative_points)
|
||||||
point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
|
point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
|
||||||
|
|
||||||
|
|
@ -190,14 +200,17 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
|
||||||
image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
|
image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
|
||||||
matted_images.append(Image.fromarray(image_np_copy))
|
matted_images.append(Image.fromarray(image_np_copy))
|
||||||
|
|
||||||
return mask_images + masks_gallery + matted_images
|
return mask_images + masks_gallery + matted_images, sam_predict_status + sam_predict_result
|
||||||
|
|
||||||
|
|
||||||
def dino_predict(input_image, dino_model_name, text_prompt, box_threshold):
|
def dino_predict(input_image, dino_model_name, text_prompt, box_threshold):
|
||||||
image_np = np.array(input_image)
|
image_np = np.array(input_image)
|
||||||
boxes_filt = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold).numpy()
|
boxes_filt, install_success = dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold)
|
||||||
|
if not install_success:
|
||||||
|
return None, gr.update(), gr.update(visible=True, value="GroundingDINO installment failed. Preview failed. See your terminal for more detail and submit an issue to https://github.com/continue-revolution/sd-webui-segment-anything/issues.")
|
||||||
|
boxes_filt = boxes_filt.numpy()
|
||||||
boxes_choice = [str(i) for i in range(boxes_filt.shape[0])]
|
boxes_choice = [str(i) for i in range(boxes_filt.shape[0])]
|
||||||
return Image.fromarray(show_boxes(image_np, boxes_filt.astype(int), show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice)
|
return Image.fromarray(show_boxes(image_np, boxes_filt.astype(int), show_index=True)), gr.update(choices=boxes_choice, value=boxes_choice), gr.update(visible=False)
|
||||||
|
|
||||||
def dino_batch_process(
|
def dino_batch_process(
|
||||||
batch_sam_model_name, batch_dino_model_name, batch_text_prompt, batch_box_threshold, batch_dilation_amt,
|
batch_sam_model_name, batch_dino_model_name, batch_text_prompt, batch_box_threshold, batch_dilation_amt,
|
||||||
|
|
@ -216,7 +229,9 @@ def dino_batch_process(
|
||||||
image_np = np.array(input_image)
|
image_np = np.array(input_image)
|
||||||
image_np_rgb = image_np[..., :3]
|
image_np_rgb = image_np[..., :3]
|
||||||
|
|
||||||
boxes_filt = dino_predict_internal(input_image, batch_dino_model_name, batch_text_prompt, batch_box_threshold)
|
boxes_filt, install_success = dino_predict_internal(input_image, batch_dino_model_name, batch_text_prompt, batch_box_threshold)
|
||||||
|
if not install_success:
|
||||||
|
return "GroundingDINO installment failed. Batch processing failed. See your terminal for more detail and submit an issue to https://github.com/continue-revolution/sd-webui-segment-anything/issues."
|
||||||
|
|
||||||
predictor.set_image(image_np_rgb)
|
predictor.set_image(image_np_rgb)
|
||||||
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
|
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image_np.shape[:2])
|
||||||
|
|
@ -301,20 +316,21 @@ class Script(scripts.Script):
|
||||||
dino_preview_boxes_button = gr.Button(value="Generate bounding box", elem_id="dino_run_button")
|
dino_preview_boxes_button = gr.Button(value="Generate bounding box", elem_id="dino_run_button")
|
||||||
dino_preview_boxes_selection = gr.CheckboxGroup(label="Select your favorite boxes: ", elem_id="dino_preview_boxes_selection")
|
dino_preview_boxes_selection = gr.CheckboxGroup(label="Select your favorite boxes: ", elem_id="dino_preview_boxes_selection")
|
||||||
dino_preview_boxes_selection.change(fn=lambda _: None, inputs=[dino_preview_boxes_selection], outputs=None, _js="onChangeDinoPreviewBoxesSelection")
|
dino_preview_boxes_selection.change(fn=lambda _: None, inputs=[dino_preview_boxes_selection], outputs=None, _js="onChangeDinoPreviewBoxesSelection")
|
||||||
|
dino_preview_result = gr.Text(value="", show_label=False, visible=False)
|
||||||
|
|
||||||
dino_preview_boxes_button.click(
|
dino_preview_boxes_button.click(
|
||||||
fn=dino_predict,
|
fn=dino_predict,
|
||||||
_js="submit_dino",
|
_js="submit_dino",
|
||||||
inputs=[input_image, dino_model_name, text_prompt, box_threshold],
|
inputs=[input_image, dino_model_name, text_prompt, box_threshold],
|
||||||
outputs=[dino_preview_boxes, dino_preview_boxes_selection]
|
outputs=[dino_preview_boxes, dino_preview_boxes_selection, dino_preview_result])
|
||||||
)
|
|
||||||
|
|
||||||
mask_image = gr.Gallery(label='Segment Anything Output', show_label=False, elem_id='sam_gallery').style(grid=3)
|
mask_image = gr.Gallery(label='Segment Anything Output', show_label=False, elem_id='sam_gallery').style(grid=3)
|
||||||
|
|
||||||
with gr.Row(elem_id="sam_generate_box", elem_classes="generate-box"):
|
with gr.Row(elem_id="sam_generate_box", elem_classes="generate-box"):
|
||||||
gr.Button(value="Add dot prompt or enable GroundingDINO with text prompts to preview segmentation", elem_id="sam_no_button")
|
gr.Button(value="Add dot prompt or enable GroundingDINO with text prompts to preview segmentation", elem_id="sam_no_button")
|
||||||
run_button = gr.Button(value="Preview Segmentation", elem_id="sam_run_button")
|
run_button = gr.Button(value="Preview Segmentation", elem_id="sam_run_button")
|
||||||
|
run_result = gr.Text(value="", show_label=False)
|
||||||
|
|
||||||
gr.Checkbox(value=False, label="Preview automatically when add/remove points", elem_id="sam_realtime_preview_checkbox")
|
gr.Checkbox(value=False, label="Preview automatically when add/remove points", elem_id="sam_realtime_preview_checkbox")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|
@ -376,7 +392,7 @@ class Script(scripts.Script):
|
||||||
dummy_component, dummy_component, # Point prompts
|
dummy_component, dummy_component, # Point prompts
|
||||||
dino_checkbox, dino_model_name, text_prompt, box_threshold, # DINO prompts
|
dino_checkbox, dino_model_name, text_prompt, box_threshold, # DINO prompts
|
||||||
dino_preview_checkbox, dino_preview_boxes_selection], # DINO preview prompts
|
dino_preview_checkbox, dino_preview_boxes_selection], # DINO preview prompts
|
||||||
outputs=[mask_image],)
|
outputs=[mask_image, run_result])
|
||||||
|
|
||||||
dino_checkbox.change(
|
dino_checkbox.change(
|
||||||
fn=gr_show,
|
fn=gr_show,
|
||||||
|
|
@ -394,21 +410,18 @@ class Script(scripts.Script):
|
||||||
fn=lambda _: None,
|
fn=lambda _: None,
|
||||||
_js="switchToInpaintUpload",
|
_js="switchToInpaintUpload",
|
||||||
inputs=[dummy_component],
|
inputs=[dummy_component],
|
||||||
outputs=None
|
outputs=None)
|
||||||
)
|
|
||||||
|
|
||||||
remove_dots.click(
|
remove_dots.click(
|
||||||
fn=lambda _: None,
|
fn=lambda _: None,
|
||||||
_js="removeDots",
|
_js="removeDots",
|
||||||
inputs=[dummy_component],
|
inputs=[dummy_component],
|
||||||
outputs=None
|
outputs=None)
|
||||||
)
|
|
||||||
|
|
||||||
unload.click(
|
unload.click(
|
||||||
fn=clear_cache,
|
fn=clear_cache,
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[]
|
outputs=[])
|
||||||
)
|
|
||||||
|
|
||||||
dilation_checkbox.change(
|
dilation_checkbox.change(
|
||||||
fn=gr_show,
|
fn=gr_show,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue