# sd-webui-segment-anything/sam_utils/util.py
# (as published: 205 lines, 8.9 KiB, Python)

from typing import Tuple, List
import os
import numpy as np
import cv2
import copy
from PIL import Image
class ModelInfo:
    """Descriptive record for one model: its type, location, author and size.

    ``url`` initially holds whatever location the caller supplies; calling
    :meth:`local_path` replaces it with a path and flips ``download_info``
    to ``"downloaded"``.
    """

    def __init__(self, model_type: str, url: str, author: str, size: str, download_info: str = "auto download"):
        """Store every constructor argument unchanged on the instance."""
        self.model_type, self.url, self.author = model_type, url, author
        self.size, self.download_info = size, download_info

    def get_info(self, model_name: str):
        """Return a one-line label: ``name (size, author, model_type, download_info)``."""
        return "{} ({}, {}, {}, {})".format(
            model_name,
            self.size,
            self.author,
            self.model_type,
            self.download_info,
        )

    def local_path(self, path: str):
        """Point this record at ``path`` and mark the model as downloaded."""
        self.download_info = "downloaded"
        self.url = path
def show_boxes(image_np: np.ndarray, boxes: np.ndarray, color=(255, 0, 0, 255), thickness=2, show_index=False) -> np.ndarray:
    """Draw box outlines (and optionally their indices) on a copy of an image.

    Args:
        image_np: Image array to annotate. It is never modified.
        boxes: Iterable of 4-element boxes, or ``None``. Each box is unpacked
            into two points that are handed to ``cv2.rectangle`` as opposite
            corners, i.e. ``(x0, y0, x1, y1)``.
        color: Color passed to the cv2 drawing calls.
        thickness: Line thickness for both the rectangle and the index text.
        show_index: When true, write each box's position in ``boxes`` at its
            first corner.

    Returns:
        ``image_np`` itself when ``boxes`` is ``None``; otherwise an annotated
        deep copy.
    """
    if boxes is None:
        return image_np
    image = copy.deepcopy(image_np)
    for idx, box in enumerate(boxes):
        # cv2 drawing calls need plain integer pixel coordinates. Callers do
        # not always cast their boxes first, so coerce here.
        x0, y0, x1, y1 = (int(coord) for coord in box)
        cv2.rectangle(image, (x0, y0), (x1, y1), color, thickness)
        if show_index:
            font = cv2.FONT_HERSHEY_SIMPLEX
            text = str(idx)
            # Shift the baseline down by the text height so the label sits
            # inside the box, just below its first corner.
            text_height = cv2.getTextSize(text, font, 1, 2)[0][1]
            cv2.putText(image, text, (x0, y0 + text_height), font, 1, color, thickness)
    return image
def show_masks(image_np: np.ndarray, masks: np.ndarray, alpha=0.5) -> np.ndarray:
    """Blend a pseudo-random color over each masked region of an image copy.

    Args:
        image_np: Image array to overlay. It is never modified. The overlay
            color has four components, so the pixels selected by a mask must
            broadcast against four values.
        masks: Iterable of index arrays (boolean masks in the callers visible
            in this file) selecting the pixels to tint, one per region.
        alpha: Blend weight of the overlay color; ``1 - alpha`` is kept from
            the original pixel.

    Returns:
        The blended image as ``uint8``.
    """
    image = copy.deepcopy(image_np)
    # A private generator seeded with 0 yields the same color sequence as
    # ``np.random.seed(0)`` followed by ``np.random.random`` calls, but leaves
    # the process-wide NumPy random state untouched.
    rng = np.random.RandomState(0)
    for mask in masks:
        # Three random channels plus a fixed fourth component of 0.6.
        color = np.concatenate([rng.random_sample(3), np.array([0.6])], axis=0)
        image[mask] = image[mask] * (1 - alpha) + 255 * color * alpha
    return image.astype(np.uint8)
def dilate_mask(mask: np.ndarray, dilation_amt: int) -> Tuple[Image.Image, np.ndarray]:
    """Grow a binary mask using a disk-shaped structuring element.

    Args:
        mask: Binary mask array handed to ``scipy.ndimage.binary_dilation``.
        dilation_amt: Side length of the square grid the disk is cut from;
            the disk radius is ``dilation_amt // 2``.

    Returns:
        A pair ``(image, array)``: the dilated mask as a PIL image with values
        0/255, and the dilated array returned by ``binary_dilation``.
    """
    from scipy.ndimage import binary_dilation

    radius = dilation_amt // 2
    axis = np.arange(dilation_amt)
    cols, rows = np.meshgrid(axis, axis)
    # Keep grid cells whose squared distance from the center is within the radius.
    squared_distance = (cols - radius) ** 2 + (rows - radius) ** 2
    disk = (squared_distance <= radius ** 2).astype(np.uint8)
    grown = binary_dilation(mask, disk)
    grown_image = Image.fromarray(grown.astype(np.uint8) * 255)
    return grown_image, grown
def update_mask(mask_gallery, chosen_mask: int, dilation_amt: float, input_image: Image.Image):
    """Rebuild the preview images for one chosen mask, optionally dilated.

    Args:
        mask_gallery: Either a list of gallery entries (dicts with a ``'name'``
            file path) or an already-loaded mask image.
        chosen_mask: Index of the selected mask; the gallery entry read is
            ``chosen_mask + 3``.
        dilation_amt: Dilation size; a falsy value skips dilation.
        input_image: Source image the mask applies to.

    Returns:
        ``[blended_image, mask_image, matted_image]`` as PIL images.
    """
    if not isinstance(mask_gallery, list):
        mask_image = mask_gallery
    else:
        # NOTE(review): the +3 offset presumably skips the blended previews that
        # precede the masks in the gallery; verify against the caller.
        gallery_entry = mask_gallery[chosen_mask + 3]
        mask_image = Image.open(gallery_entry['name'])

    binary_img = np.array(mask_image.convert('1'))
    if dilation_amt:
        mask_image, binary_img = dilate_mask(binary_img, dilation_amt)

    # Overlay: a single mask wrapped in a leading axis so show_masks iterates once.
    single_mask_stack = binary_img.astype(np.bool_)[None, ...]
    overlay_np = show_masks(np.array(input_image), single_mask_stack)
    blended_image = Image.fromarray(overlay_np)

    # Matting: zero every pixel outside the mask.
    matted_np = np.array(input_image)
    matted_np[~binary_img] = np.array([0, 0, 0, 0])
    matted_image = Image.fromarray(matted_np)

    return [blended_image, mask_image, matted_image]
def create_mask_output(image_np: np.ndarray, masks: np.ndarray, boxes_filt: np.ndarray) -> List[Image.Image]:
    """Build the gallery images (blended, mask, matted) for every mask.

    Args:
        image_np: Source image array. It is never modified.
        masks: Iterable of mask stacks; each stack is merged into one mask
            with ``np.any`` over its first axis.
        boxes_filt: Boxes to outline on the blended previews, or ``None``.
            They are cast to ``int`` before drawing.

    Returns:
        One flat list: all blended previews, then all merged masks, then all
        matted images, each group in the order of ``masks``.
    """
    if boxes_filt is not None:
        boxes_filt = boxes_filt.astype(int)
    # The box overlay is the same for every mask, so draw it once.
    # show_masks works on its own copy, so sharing this array is safe.
    boxed_image = show_boxes(image_np, boxes_filt)

    mask_images, masks_gallery, matted_images = [], [], []
    for mask in masks:
        merged_mask = np.any(mask, axis=0)
        masks_gallery.append(Image.fromarray(merged_mask))
        mask_images.append(Image.fromarray(show_masks(boxed_image, mask)))
        # Matting: zero every pixel outside the merged mask.
        matted_np = copy.deepcopy(image_np)
        matted_np[~merged_mask] = np.array([0, 0, 0, 0])
        matted_images.append(Image.fromarray(matted_np))
    return mask_images + masks_gallery + matted_images
def create_mask_batch_output(
        input_image_filename: str, dest_dir: str,
        image_np: np.ndarray, masks: np.ndarray, boxes_filt: np.ndarray, dilation_amt: float,
        save_image: bool, save_mask: bool, save_background: bool, save_image_with_mask: bool):
    """Write the requested per-mask outputs for one image into ``dest_dir``.

    Files are named ``<input basename>_<mask index>_<kind>.png`` where kind is
    ``output`` (matted image), ``mask`` or ``blend``.

    Args:
        input_image_filename: Path of the source image; only its base name
            (without extension) is used for the output names.
        dest_dir: Directory the files are written to.
        image_np: Source image array. It is never modified.
        masks: Iterable of mask stacks; each stack is merged with ``np.any``
            over its first axis.
        boxes_filt: Boxes to outline on the blended output, or ``None``.
        dilation_amt: Dilation size; a falsy value skips dilation.
        save_image: Save the matted image (pixels outside the mask zeroed).
        save_mask: Save the merged mask itself.
        save_background: Invert the merged mask before dilation, so the
            background is kept instead of the masked object.
        save_image_with_mask: Save the image with mask and box overlays.
    """
    filename = os.path.splitext(os.path.basename(input_image_filename))[0]
    ext = ".png"  # JPEG not compatible with RGBA

    boxed_image = None
    if save_image_with_mask:
        # Same for every mask, so draw once. Boxes are cast to int first,
        # because cv2 drawing calls need integer coordinates.
        if boxes_filt is not None:
            boxes_filt = np.asarray(boxes_filt).astype(int)
        boxed_image = show_boxes(image_np, boxes_filt)

    for idx, mask in enumerate(masks):
        merged_mask = np.any(mask, axis=0)
        if save_background:
            merged_mask = ~merged_mask
        if dilation_amt:
            _, merged_mask = dilate_mask(merged_mask, dilation_amt)
        if save_image:
            matted_np = copy.deepcopy(image_np)
            matted_np[~merged_mask] = np.array([0, 0, 0, 0])
            output_image = Image.fromarray(matted_np)
            output_image.save(os.path.join(dest_dir, f"{filename}_{idx}_output{ext}"))
        if save_mask:
            output_mask = Image.fromarray(merged_mask)
            output_mask.save(os.path.join(dest_dir, f"{filename}_{idx}_mask{ext}"))
        if save_image_with_mask:
            # The overlay uses the original mask stack, not the merged,
            # inverted or dilated mask.
            output_blend = Image.fromarray(show_masks(boxed_image, mask))
            output_blend.save(os.path.join(dest_dir, f"{filename}_{idx}_blend{ext}"))
def blend_image_and_seg(image: np.ndarray, seg: np.ndarray, alpha=0.5) -> Image.Image:
    """Alpha-blend a segmentation rendering over an image.

    Args:
        image: Base image array.
        seg: Segmentation rendering; converted with ``np.array`` before use.
        alpha: Weight of ``seg``; the base image keeps ``1 - alpha``.

    Returns:
        The weighted sum, cast to ``uint8``, as a PIL image.
    """
    base_part = image * (1 - alpha)
    seg_part = np.array(seg) * alpha
    blended = (base_part + seg_part).astype(np.uint8)
    return Image.fromarray(blended)
def install_pycocotools():
    """Make sure ``pycocotools`` is importable, installing it through pip if needed.

    Tries, in order: importing the existing package, installing
    ``pycocotools``, and (on Windows only) installing ``pycocotools-windows``.

    Raises:
        RuntimeError: If no attempt leaves ``pycocotools.mask`` importable.
    """
    import sys
    import traceback
    from sam_utils.logger import logger

    try:
        import pycocotools.mask  # noqa: F401
        return
    except Exception:
        # Deliberately broad: any failure to import is treated as "not
        # installed". ``Exception`` (not a bare ``except``) still lets
        # KeyboardInterrupt and SystemExit propagate.
        logger.warning("pycocotools not found, will try installing C++ based pycocotools")

    try:
        from launch import run_pip
        run_pip("install pycocotools", "AutoSAM requirement: pycocotools")
        import pycocotools.mask  # noqa: F401
        return
    except Exception as install_error:
        traceback.print_exc()
        if sys.platform != "win32":
            error_msg = "Unable to install pycocotools"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from install_error

    # Windows only: fall back to the pycocotools-windows package.
    logger.warning("Unable to install pycocotools, will try installing pycocotools-windows")
    try:
        # Re-import here so this attempt does not depend on the previous
        # ``try`` having got far enough to bind ``run_pip``.
        from launch import run_pip
        run_pip("install pycocotools-windows", "AutoSAM requirement: pycocotools-windows")
        import pycocotools.mask  # noqa: F401
    except Exception as install_error:
        error_msg = "Unable to install pycocotools-windows"
        logger.error(error_msg)
        traceback.print_exc()
        raise RuntimeError(error_msg) from install_error
def install_goundingdino() -> bool:
    """Automatically install GroundingDINO.

    If the ``sam_use_local_groundingdino`` option is set, nothing is installed.
    Otherwise an existing pip installation is verified by importing its
    compiled ``_C`` extension. If that fails it is uninstalled and
    GroundingDINO is installed from GitHub, then verified again.

    Returns:
        bool: False if use local GroundingDINO, True if use pip installed GroundingDINO.
    """
    import traceback
    from sam_utils.logger import logger
    from modules import shared

    dino_install_issue_text = "Please permanently switch to local GroundingDINO on Settings/Segment Anything or submit an issue to https://github.com/IDEA-Research/Grounded-Segment-Anything/issues."
    requirement_desc = "sd-webui-segment-anything requirement: groundingdino"

    if shared.opts.data.get("sam_use_local_groundingdino", False):
        logger.info("Using local groundingdino.")
        return False

    def run_pip_uninstall(command, desc=None):
        """Run ``pip uninstall -y <command>`` through the launcher's ``run`` helper."""
        from launch import python, run
        default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
        return run(f'"{python}" -m pip uninstall -y {command}', desc=f"Uninstalling {desc}", errdesc=f"Couldn't uninstall {desc}", live=default_command_live)

    def verify_dll(install_local=True):
        """Return True if the compiled extension imports; otherwise uninstall and return False.

        ``install_local`` only selects which warning is logged: falling back to
        the local copy (True) or retrying from GitHub (False).
        """
        try:
            from groundingdino import _C  # noqa: F401
            logger.info("GroundingDINO dynamic library have been successfully built.")
            return True
        except Exception:
            traceback.print_exc()

        if install_local:
            logger.warning(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and fall back to local GroundingDINO this time. {dino_install_issue_text}")
        else:
            logger.warning(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and re-try installing from GitHub source code. {dino_install_issue_text}")
        # Pass the package name only: run_pip_uninstall already supplies the
        # "uninstall" sub-command, so repeating it would ask pip to remove a
        # package literally named "uninstall".
        run_pip_uninstall("groundingdino", requirement_desc)
        return False

    import launch
    if launch.is_installed("groundingdino"):
        logger.info("Found GroundingDINO in pip. Verifying if dynamic library build success.")
        if verify_dll(install_local=False):
            return True
    try:
        launch.run_pip(
            "install git+https://github.com/IDEA-Research/GroundingDINO",
            requirement_desc)
        logger.info("GroundingDINO install success. Verifying if dynamic library build success.")
        return verify_dll()
    except Exception:
        traceback.print_exc()
        logger.warning(f"GroundingDINO install failed. Will fall back to local groundingdino this time. {dino_install_issue_text}")
        return False