another save

pull/149/head
Chengsong Zhang 2023-07-10 20:37:37 +08:00
parent a4058c4854
commit 6488b4e2dd
47 changed files with 1599 additions and 372 deletions

284
sam_utils/autosam.py Normal file
View File

@ -0,0 +1,284 @@
from typing import List, Tuple, Union, Optional
import os
import glob
import copy
from PIL import Image
import numpy as np
import torch
from segment_anything.modeling import Sam
from thirdparty.fastsam import FastSAM
from thirdparty.sam_hq.automatic import SamAutomaticMaskGeneratorHQ
from sam_utils.segment import Segmentation
from sam_utils.logger import logger
from sam_utils.util import blend_image_and_seg
class AutoSAM:
"""Automatic segmentation."""
def __init__(self, sam: Segmentation) -> None:
"""AutoSAM initialization.
Args:
sam (Segmentation): global Segmentation instance.
"""
self.sam = sam
self.auto_sam: Union[SamAutomaticMaskGeneratorHQ, FastSAM] = None
self.fastsam_conf = None
self.fastsam_iou = None
def auto_generate(self, img: np.ndarray) -> List[dict]:
"""Generate segmentation.
Args:
img (np.ndarray): input image.
Returns:
List[dict]: list of segmentation masks.
"""
return self.auto_sam.generate(img) if type(self.auto_sam) == SamAutomaticMaskGeneratorHQ else \
self.auto_sam(img, device=self.sam.sam_device, retina_masks=True, imgsz=1024, conf=self.fastsam_conf, iou=self.fastsam_iou)
def strengthen_semantic_seg(self, class_ids: np.ndarray, img: np.ndarray) -> np.ndarray:
# TODO: class_ids use multiple dimensions, categorical mask single and batch
logger.info("AutoSAM strengthening semantic segmentation")
from sam_utils.util import install_pycocotools
install_pycocotools()
import pycocotools.mask as maskUtils
semantc_mask = copy.deepcopy(class_ids)
annotations = self.auto_generate(img)
annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
logger.info(f"AutoSAM generated {len(annotations)} masks")
for ann in annotations:
valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool()
propose_classes_ids = torch.tensor(class_ids[valid_mask])
num_class_proposals = len(torch.unique(propose_classes_ids))
if num_class_proposals == 1:
semantc_mask[valid_mask] = propose_classes_ids[0].numpy()
continue
top_1_propose_class_ids = torch.bincount(propose_classes_ids.flatten()).topk(1).indices
semantc_mask[valid_mask] = top_1_propose_class_ids.numpy()
logger.info("AutoSAM strengthening process end")
return semantc_mask
def random_segmentation(self, img: Image.Image) -> Tuple[List[Image.Image], str]:
"""Random segmentation for EditAnything
Args:
img (Image.Image): input image.
Raises:
ModuleNotFoundError: ControlNet not installed.
Returns:
Tuple[List[Image.Image], str]: List of 3 displayed images and output message.
"""
logger.info("AutoSAM generating random segmentation for EditAnything")
img_np = np.array(img.convert("RGB"))
annotations = self.auto_generate(img_np)
logger.info(f"AutoSAM generated {len(annotations)} masks")
H, W, _ = img_np.shape
color_map = np.zeros((H, W, 3), dtype=np.uint8)
detected_map_tmp = np.zeros((H, W), dtype=np.uint16)
for idx, annotation in enumerate(annotations):
current_seg = annotation['segmentation']
color_map[current_seg] = np.random.randint(0, 255, (3))
detected_map_tmp[current_seg] = idx + 1
detected_map = np.zeros((detected_map_tmp.shape[0], detected_map_tmp.shape[1], 3))
detected_map[:, :, 0] = detected_map_tmp % 256
detected_map[:, :, 1] = detected_map_tmp // 256
try:
from scripts.processor import HWC3
except:
raise ModuleNotFoundError("ControlNet extension not found.")
detected_map = HWC3(detected_map.astype(np.uint8))
logger.info("AutoSAM generation process end")
return [blend_image_and_seg(img_np, color_map), Image.fromarray(color_map), Image.fromarray(detected_map)], \
"Random segmentation done. Left above (0) is blended image, right above (1) is random segmentation, left below (2) is Edit-Anything control input."
def layout_single_image(self, input_image: Image.Image, output_path: str) -> None:
"""Single image layout generation.
Args:
input_image (Image.Image): input image.
output_path (str): output path.
"""
img_np = np.array(input_image.convert("RGB"))
annotations = self.auto_generate(img_np)
logger.info(f"AutoSAM generated {len(annotations)} annotations")
annotations = sorted(annotations, key=lambda x: x['area'])
for idx, annotation in enumerate(annotations):
img_tmp = np.zeros((img_np.shape[0], img_np.shape[1], 3))
img_tmp[annotation['segmentation']] = img_np[annotation['segmentation']]
img_np[annotation['segmentation']] = np.array([0, 0, 0])
img_tmp = Image.fromarray(img_tmp.astype(np.uint8))
img_tmp.save(os.path.join(output_path, f"{idx}.png"))
img_np = Image.fromarray(img_np.astype(np.uint8))
img_np.save(os.path.join(output_path, "leftover.png"))
def layout(self, input_image_or_path: Union[str, Image.Image], output_path: str) -> str:
"""Single or bath layout generation.
Args:
input_image_or_path (Union[str, Image.Image]): input imag or path.
output_path (str): output path.
Returns:
str: generation message.
"""
if isinstance(input_image_or_path, str):
logger.info("Image layer division batch processing")
all_files = glob.glob(os.path.join(input_image_or_path, "*"))
for image_index, input_image_file in enumerate(all_files):
logger.info(f"Processing {image_index}/{len(all_files)} {input_image_file}")
try:
input_image = Image.open(input_image_file)
output_directory = os.path.join(output_path, os.path.splitext(os.path.basename(input_image_file))[0])
from pathlib import Path
Path(output_directory).mkdir(exist_ok=True)
except:
logger.warn(f"File {input_image_file} not image, skipped.")
continue
self.layout_single_image(input_image, output_directory)
else:
self.layout_single_image(input_image_or_path, output_path)
return "Done"
def semantic_segmentation(self, input_image: Image.Image, annotator_name: str, processor_res: int,
use_pixel_perfect: bool, resize_mode: int, target_W: int, target_H: int) -> Tuple[List[Image.Image], str]:
"""Semantic segmentation enhanced by segment anything.
Args:
input_image (Image.Image): input image.
annotator_name (str): annotator name. Should be one of "seg_ufade20k"|"seg_ofade20k"|"seg_ofcoco".
processor_res (int): processor resolution. Support 64-2048.
use_pixel_perfect (bool): whether to use pixel perfect written by lllyasviel.
resize_mode (int): resize mode for pixel perfect, should be 0|1|2.
target_W (int): target width for pixel perfect.
target_H (int): target height for pixel perfect.
Raises:
ModuleNotFoundError: ControlNet not installed.
Returns:
Tuple[List[Image.Image], str]: list of 4 displayed images and message.
"""
assert input_image is not None, "No input image."
if "seg" in annotator_name:
try:
from scripts.processor import uniformer, oneformer_coco, oneformer_ade20k
from scripts.external_code import pixel_perfect_resolution
oneformers = {
"ade20k": oneformer_ade20k,
"coco": oneformer_coco
}
except:
raise ModuleNotFoundError("ControlNet extension not found.")
input_image_np = np.array(input_image)
if use_pixel_perfect:
processor_res = pixel_perfect_resolution(input_image_np, resize_mode, target_W, target_H)
logger.info("Generating semantic segmentation without SAM")
if annotator_name == "seg_ufade20k":
original_semantic = uniformer(input_image_np, processor_res)
else:
dataset = annotator_name.split('_')[-1][2:]
original_semantic = oneformers[dataset](input_image_np, processor_res)
logger.info("Generating semantic segmentation with SAM")
sam_semantic = self.strengthen_semantic_seg(np.array(original_semantic), input_image_np)
output_gallery = [original_semantic, sam_semantic, blend_image_and_seg(input_image, original_semantic), blend_image_and_seg(input_image, sam_semantic)]
return output_gallery, "Done. Left is segmentation before SAM, right is segmentation after SAM."
else:
return self.random_segmentation(input_image)
def categorical_mask_image(self, annotator_name: str, processor_res: int, category_input: List[int], input_image: Image.Image,
use_pixel_perfect: bool, resize_mode: int, target_W: int, target_H: int) -> Tuple[np.ndarray, Image.Image]:
"""Single image categorical mask.
Args:
annotator_name (str): annotator name. Should be one of "seg_ufade20k"|"seg_ofade20k"|"seg_ofcoco".
processor_res (int): processor resolution. Support 64-2048.
category_input (List[int]): category input.
input_image (Image.Image): input image.
use_pixel_perfect (bool): whether to use pixel perfect written by lllyasviel.
resize_mode (int): resize mode for pixel perfect, should be 0|1|2.
target_W (int): target width for pixel perfect.
target_H (int): target height for pixel perfect.
Raises:
ModuleNotFoundError: ControlNet not installed.
AssertionError: Illegal class id.
Returns:
Tuple[np.ndarray, Image.Image]: mask in resized shape and resized input image.
"""
assert input_image is not None, "No input image."
try:
from scripts.processor import uniformer, oneformer_coco, oneformer_ade20k
from scripts.external_code import pixel_perfect_resolution
oneformers = {
"ade20k": oneformer_ade20k,
"coco": oneformer_coco
}
except:
raise ModuleNotFoundError("ControlNet extension not found.")
filter_classes = category_input
assert len(filter_classes) > 0, "No class selected."
try:
filter_classes = [int(i) for i in filter_classes]
except:
raise AssertionError("Illegal class id. You may have input some string.")
input_image_np = np.array(input_image)
if use_pixel_perfect:
processor_res = pixel_perfect_resolution(input_image_np, resize_mode, target_W, target_H)
crop_input_image_copy = copy.deepcopy(input_image)
logger.info(f"Generating categories with processor {annotator_name}")
if annotator_name == "seg_ufade20k":
original_semantic = uniformer(input_image_np, processor_res)
else:
dataset = annotator_name.split('_')[-1][2:]
original_semantic = oneformers[dataset](input_image_np, processor_res)
sam_semantic = self.strengthen_semantic_seg(np.array(original_semantic), input_image_np)
mask = np.zeros(sam_semantic.shape, dtype=np.bool_)
from sam_utils.config import SEMANTIC_CATEGORIES
for i in filter_classes:
mask[np.equal(sam_semantic, SEMANTIC_CATEGORIES[annotator_name][i])] = True
return mask, crop_input_image_copy
def register(
self,
sam_model_name: str,
points_per_side: Optional[int] = 32,
points_per_batch: int = 64,
pred_iou_thresh: float = 0.88,
stability_score_thresh: float = 0.95,
stability_score_offset: float = 1.0,
box_nms_thresh: float = 0.7,
crop_n_layers: int = 0,
crop_nms_thresh: float = 0.7,
crop_overlap_ratio: float = 512 / 1500,
crop_n_points_downscale_factor: int = 1,
min_mask_region_area: int = 0,
output_mode: str = "binary_mask",
fastsam_conf: float = 0.4,
fastsam_iou: float = 0.9) -> None:
"""Register AutoSAM module."""
self.sam.load_sam_model(sam_model_name)
assert type(self.sam.sam_model) in [FastSAM, Sam], f"{sam_model_name} does not support auto segmentation."
if type(self.sam.sam_model) == FastSAM:
self.fastsam_conf = fastsam_conf
self.fastsam_iou = fastsam_iou
self.auto_sam = self.sam.sam_model
else:
self.auto_sam = SamAutomaticMaskGeneratorHQ(
self.sam.sam_model, points_per_side, points_per_batch, pred_iou_thresh,
stability_score_thresh, stability_score_offset, box_nms_thresh,
crop_n_layers, crop_nms_thresh, crop_overlap_ratio, crop_n_points_downscale_factor, None,
min_mask_region_area, output_mode)

View File

@ -1,22 +1,24 @@
from typing import Tuple
from typing import List
import os
import gc
from PIL import Image
import torch
from modules import shared
from modules.devices import device, torch_gc
from scripts.sam_log import logger
from modules.devices import device
from sam_utils.logger import logger
# TODO: support YOLO models
class Detection:
"""Detection related process.
"""
def __init__(self) -> None:
"""Initialize detection related process.
"""
self.dino_model = None
self.dino_model_type = ""
from scripts.sam_state import sam_extension_dir
self.dino_model_dir = os.path.join(sam_extension_dir, "models/grounding-dino")
self.dino_model_list = ["GroundingDINO_SwinT_OGC (694MB)", "GroundingDINO_SwinB (938MB)"]
self.dino_model_info = {
"GroundingDINO_SwinT_OGC (694MB)": {
"checkpoint": "groundingdino_swint_ogc.pth",
@ -29,68 +31,26 @@ class Detection:
"url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth"
},
}
self.dino_install_issue_text = "Please permanently switch to local GroundingDINO on Settings/Segment Anything or submit an issue to https://github.com/IDEA-Research/Grounded-Segment-Anything/issues."
def _install_goundingdino(self) -> bool:
if shared.opts.data.get("sam_use_local_groundingdino", False):
logger.info("Using local groundingdino.")
return False
def _load_dino_model(self, dino_checkpoint: str, use_pip_dino: bool) -> None:
"""Load GroundignDINO model to device.
def verify_dll(install_local=True):
try:
from groundingdino import _C
logger.info("GroundingDINO dynamic library have been successfully built.")
return True
except Exception:
import traceback
traceback.print_exc()
def run_pip_uninstall(command, desc=None):
from launch import python, run
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
return run(f'"{python}" -m pip uninstall -y {command}', desc=f"Uninstalling {desc}", errdesc=f"Couldn't uninstall {desc}", live=default_command_live)
if install_local:
logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and fall back to local GroundingDINO this time. {self.dino_install_issue_text}")
run_pip_uninstall(
f"groundingdino",
f"sd-webui-segment-anything requirement: groundingdino")
else:
logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and re-try installing from GitHub source code. {self.dino_install_issue_text}")
run_pip_uninstall(
f"uninstall groundingdino",
f"sd-webui-segment-anything requirement: groundingdino")
return False
import launch
if launch.is_installed("groundingdino"):
logger.info("Found GroundingDINO in pip. Verifying if dynamic library build success.")
if verify_dll(install_local=False):
return True
try:
launch.run_pip(
f"install git+https://github.com/IDEA-Research/GroundingDINO",
f"sd-webui-segment-anything requirement: groundingdino")
logger.info("GroundingDINO install success. Verifying if dynamic library build success.")
return verify_dll()
except Exception:
import traceback
traceback.print_exc()
logger.warn(f"GroundingDINO install failed. Will fall back to local groundingdino this time. {self.dino_install_issue_text}")
return False
def _load_dino_model(self, dino_checkpoint: str, dino_install_success: bool) -> torch.nn.Module:
Args:
dino_checkpoint (str): GroundingDINO checkpoint name.
use_pip_dino (bool): If True, use pip installed GroundingDINO. If False, use local GroundingDINO.
"""
logger.info(f"Initializing GroundingDINO {dino_checkpoint}")
if self.dino_model is None or dino_checkpoint != self.dino_model_type:
self.clear()
if dino_install_success:
if use_pip_dino:
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
else:
from local_groundingdino.models import build_model
from local_groundingdino.util.slconfig import SLConfig
from local_groundingdino.util.utils import clean_state_dict
from thirdparty.groundingdino.models import build_model
from thirdparty.groundingdino.util.slconfig import SLConfig
from thirdparty.groundingdino.util.utils import clean_state_dict
args = SLConfig.fromfile(self.dino_model_info[dino_checkpoint]["config"])
dino = build_model(args)
checkpoint = torch.hub.load_state_dict_from_url(
@ -103,11 +63,20 @@ class Detection:
self.dino_model.to(device=device)
def _load_dino_image(self, image_pil: Image.Image, dino_install_success: bool) -> torch.Tensor:
if dino_install_success:
def _load_dino_image(self, image_pil: Image.Image, use_pip_dino: bool) -> torch.Tensor:
"""Transform image to make the image applicable to GroundingDINO.
Args:
image_pil (Image.Image): Input image in PIL format.
use_pip_dino (bool): If True, use pip installed GroundingDINO. If False, use local GroundingDINO.
Returns:
torch.Tensor: Transformed image in torch.Tensor format.
"""
if use_pip_dino:
import groundingdino.datasets.transforms as T
else:
from local_groundingdino.datasets import transforms as T
from thirdparty.groundingdino.datasets import transforms as T
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
@ -119,7 +88,17 @@ class Detection:
return image
def _get_grounding_output(self, image: torch.Tensor, caption: str, box_threshold: float):
def _get_grounding_output(self, image: torch.Tensor, caption: str, box_threshold: float) -> torch.Tensor:
"""Inference GroundingDINO model.
Args:
image (torch.Tensor): transformed input image.
caption (str): string caption.
box_threshold (float): bbox threshold.
Returns:
torch.Tensor: generated bounding boxes.
"""
caption = caption.lower()
caption = caption.strip()
if not caption.endswith("."):
@ -128,7 +107,7 @@ class Detection:
with torch.no_grad():
outputs = self.dino_model(image[None], captions=[caption])
if shared.cmd_opts.lowvram:
self.dino_model.cpu()
self.unload_model()
logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256)
boxes = outputs["pred_boxes"][0] # (nq, 4)
@ -142,12 +121,23 @@ class Detection:
return boxes_filt.cpu()
def dino_predict(self, input_image: Image.Image, dino_model_name: str, text_prompt: str, box_threshold: float) -> Tuple[torch.Tensor, bool]:
install_success = self._install_goundingdino()
def dino_predict(self, input_image: Image.Image, dino_model_name: str, text_prompt: str, box_threshold: float) -> List[List[float]]:
"""Exposed API for GroundingDINO inference.
Args:
input_image (Image.Image): input image.
dino_model_name (str): GroundingDINO model name.
text_prompt (str): string prompt.
box_threshold (float): bbox threshold.
Returns:
List[List[float]]: generated N * xyxy bounding boxes.
"""
from sam_utils.util import install_goundingdino
install_success = install_goundingdino()
logger.info("Running GroundingDINO Inference")
dino_image = self._load_dino_image(input_image.convert("RGB"), install_success)
self._load_dino_model(dino_model_name, install_success)
using_groundingdino = install_success or shared.opts.data.get("sam_use_local_groundingdino", False)
boxes_filt = self._get_grounding_output(
dino_image, text_prompt, box_threshold
@ -158,16 +148,74 @@ class Detection:
boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
boxes_filt[i][2:] += boxes_filt[i][:2]
gc.collect()
torch_gc()
return boxes_filt, using_groundingdino
return boxes_filt.tolist()
def check_yolo_availability(self) -> List[str]:
"""Check if YOLO models are available. Do not check if YOLO not enabled.
Returns:
List[str]: available YOLO models.
"""
if shared.opts.data.get("sam_use_local_yolo", False):
from modules.paths import models_path
sd_yolo_model_dir = os.path.join(models_path, "ultralytics")
return [name for name in os.listdir(sd_yolo_model_dir) if (".pth" in name or ".pt" in name)]
else:
return []
def yolo_predict(self, input_image: Image.Image, yolo_model_name: str, conf=0.4) -> List[List[float]]:
"""Run detection inference with models based on YOLO.
Args:
input_image (np.ndarray): input image, expect shape HW3.
conf (float, optional): object confidence threshold. Defaults to 0.4.
Raises:
RuntimeError: not getting any bbox. Might be caused by high conf or non-detection/segmentation model.
Returns:
np.ndarray: generated N * xyxy bounding boxes.
"""
from ultralytics import YOLO
assert shared.opts.data.get("sam_use_yolo_models", False), "YOLO models are not enabled. Please enable in settings/Segment Anything."
logger.info("Loading YOLO model.")
from modules.paths import models_path
sd_yolo_model_dir = os.path.join(models_path, "ultralytics")
self.dino_model = YOLO(os.path.join(sd_yolo_model_dir, yolo_model_name)).to(device)
self.dino_model_type = yolo_model_name
logger.info("Running YOLO inference.")
pred = self.dino_model(input_image, conf=conf)
bboxes = pred[0].boxes.xyxy.cpu().numpy()
if bboxes.size == 0:
error_msg = "You are not getting any bbox. There are 2 possible reasons. "\
"1. You set up a high conf which means that you should lower the conf. "\
"2. You are using a non-detection model which means that you should check your model type."
raise RuntimeError(error_msg)
else:
return bboxes.tolist()
def __call__(self, input_image: Image.Image, model_name: str, text_prompt: str, box_threshold: float, conf=0.4) -> List[List[float]]:
"""Exposed API for detection inference."""
if model_name in self.dino_model_info.keys():
return self.dino_predict(input_image, model_name, text_prompt, box_threshold)
elif model_name in self.check_yolo_availability():
return self.yolo_predict(input_image, model_name, conf)
else:
raise ValueError(f"Detection model {model_name} not found.")
def clear(self) -> None:
"""Clear detection model from any memory.
"""
del self.dino_model
self.dino_model = None
def unload_model(self) -> None:
"""Unload detection model from GPU to CPU.
"""
if self.dino_model is not None:
self.dino_model.cpu()

View File

@ -1,23 +1,27 @@
from typing import List
import gc
import os
import torch
import numpy as np
from ultralytics import YOLO
from thirdparty.fastsam import FastSAM, FastSAMPrompt
from modules import shared
from modules.safe import unsafe_torch_load, load
from modules.devices import get_device_for, cpu
from modules.devices import get_device_for, cpu, torch_gc
from modules.paths import models_path
from scripts.sam_state import sam_extension_dir
from scripts.sam_log import logger
from sam_utils.logger import logger
from sam_utils.util import ModelInfo
from sam_hq.build_sam_hq import sam_model_registry
from sam_hq.predictor import SamPredictorHQ
from mam.m2m import SamM2M
from thirdparty.sam_hq.build_sam_hq import sam_model_registry
from thirdparty.sam_hq.predictor import SamPredictorHQ
from thirdparty.mam.m2m import SamM2M
class Segmentation:
"""Segmentation related process."""
def __init__(self) -> None:
"""Initialize segmentation related process."""
self.sam_model_info = {
"sam_vit_h_4b8939.pth" : ModelInfo("SAM", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "Meta", "2.56GB"),
"sam_vit_l_0b3195.pth" : ModelInfo("SAM", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "Meta", "1.25GB"),
@ -37,7 +41,11 @@ class Segmentation:
def check_model_availability(self) -> List[str]:
# retrieve all models in all the model directories
"""retrieve all models in all the model directories
Returns:
List[str]: Model information displayed on the UI
"""
user_sam_model_dir = shared.opts.data.get("sam_model_path", "")
sd_sam_model_dir = os.path.join(models_path, "sam")
scripts_sam_model_dir = os.path.join(sam_extension_dir, "models/sam")
@ -45,7 +53,7 @@ class Segmentation:
sam_model_dirs = [sd_sam_model_dir, scripts_sam_model_dir]
if user_sam_model_dir != "":
sam_model_dirs.append(user_sam_model_dir)
if shared.opts.data.get("use_yolo_models", False):
if shared.opts.data.get("sam_use_yolo_models", False):
sam_model_dirs.append(sd_yolo_model_dir)
for dir in sam_model_dirs:
if os.path.isdir(dir):
@ -54,19 +62,27 @@ class Segmentation:
for name in sam_model_names:
if name in self.sam_model_info.keys():
self.sam_model_info[name].local_path(os.path.join(dir, name))
elif shared.opts.data.get("use_yolo_models", False):
elif shared.opts.data.get("sam_use_yolo_models", False):
logger.warn(f"Model {name} not found in support list, default to use YOLO as initializer")
self.sam_model_info[name] = ModelInfo("YOLO", os.path.join(dir, name), "?", "?", "downloaded")
return [val.get_info(key) for key, val in self.sam_model_info.items()]
def load_sam_model(self, sam_checkpoint_name: str) -> None:
"""Load segmentation model.
Args:
sam_checkpoint_name (str): The model filename. Do not change.
Raises:
RuntimeError: Model file not found in either support list or local model directory.
RuntimeError: Cannot automatically download model from remote server.
"""
sam_checkpoint_name = sam_checkpoint_name.split(" ")[0]
if self.sam_model is None or self.sam_model_name != sam_checkpoint_name:
if sam_checkpoint_name not in self.sam_model_info.keys():
error_msg = f"{sam_checkpoint_name} not found and cannot be auto downloaded"
logger.error(f"{error_msg}")
raise Exception(error_msg)
raise RuntimeError(error_msg)
if "http" in self.sam_model_info[sam_checkpoint_name].url:
sam_url = self.sam_model_info[sam_checkpoint_name].url
user_dir = shared.opts.data.get("sam_model_path", "")
@ -89,158 +105,343 @@ class Segmentation:
logger.info(f"Loading SAM model from {model_path}")
self.sam_model = sam_model_registry[sam_checkpoint_name](checkpoint=model_path)
self.sam_model_wrapper = SamPredictorHQ(self.sam_model, 'HQ' in model_type)
else:
elif "SAM" in model_type:
logger.info(f"Loading FastSAM model from {model_path}")
self.sam_model = FastSAM(model_path)
self.sam_model_wrapper = self.sam_model
elif shared.opts.data.get("sam_use_yolo_models", False):
logger.info(f"Loading YOLO model from {model_path}")
self.sam_model = YOLO(model_path)
self.sam_model_wrapper = self.sam_model
else:
error_msg = f"Unsupported model type {model_type}"
raise RuntimeError(error_msg)
self.sam_model_name = sam_checkpoint_name
torch.load = load
self.sam_model.to(self.sam_device)
def change_device(self, use_cpu: bool) -> None:
"""Change the device of the segmentation model.
Args:
use_cpu (bool): Whether to use CPU for SAM inference.
"""
self.sam_device = cpu if use_cpu else get_device_for("sam")
def __call__(self,
input_image: np.ndarray,
point_coords: List[List[int]]=None,
point_labels: List[List[int]]=None,
boxes: torch.Tensor=None,
multimask_output=True,
use_mam=False):
pass
def sam_predict(self,
input_image: np.ndarray,
positive_points: List[List[int]]=None,
negative_points: List[List[int]]=None,
boxes_coords: torch.Tensor=None,
boxes_labels: List[bool]=None,
multimask_output=True,
merge_point_and_box=True,
point_with_box=False,
use_mam=False,
use_mam_for_each_infer=False,
mam_guidance_mode: str="mask") -> np.ndarray:
input_image: np.ndarray,
positive_points: List[List[int]]=None,
negative_points: List[List[int]]=None,
positive_bbox: List[List[float]]=None,
negative_bbox: List[List[float]]=None,
merge_positive=True,
multimask_output=True,
point_with_box=False,
use_mam=False,
mam_guidance_mode: str="mask") -> np.ndarray:
"""Run segmentation inference with models based on segment anything.
Args:
input_image (np.ndarray): input image, expect shape HW3.
positive_points (List[List[int]], optional): positive point prompts. Defaults to None.
negative_points (List[List[int]], optional): negative point prompts. Defaults to None.
boxes_coords (torch.Tensor, optional): bbox inputs, expect shape xyxy. Defaults to None.
boxes_labels (List[bool], optional): bbox labels, support positive & negative bboxes. Defaults to None.
positive_points (List[List[int]], optional): positive point prompts, expect N * xy. Defaults to None.
negative_points (List[List[int]], optional): negative point prompts, expect N * xy. Defaults to None.
positive_bbox (List[List[float]], optional): positive bbox prompts, expect N * xyxy. Defaults to None.
negative_bbox (List[List[float]], optional): negative bbox prompts, expect N * xyxy. Defaults to None.
merge_positive (bool, optional): OR all positive masks. Defaults to True. Valid only if point_with_box is False.
multimask_output (bool, optional): output 3 masks or not. Defaults to True.
merge_point_and_box (bool, optional): if True, output point masks || bbox masks; otherwise, output point masks && bbox masks. Valid only if point_with_box is False. Defaults to True.
point_with_box (bool, optional): always send bboxes and points to the model at the same time. Defaults to False.
use_mam (bool, optional): use Matting-Anything. Defaults to False.
use_mam_for_each_infer (bool, optional): use Matting-Anything for each SAM inference. Valid only if use_mam is True. Defaults to False.
mam_guidance_mode (str, optional): guidance model for Matting-Anything. Expect "mask" or "bbox". Valid only if use_mam is True. Defaults to "mask".
mam_guidance_mode (str, optional): guidance model for Matting-Anything. Expect "mask" or "bbox". Defaults to "mask". Valid only if use_mam is True.
Returns:
np.ndarray: mask output, expect shape 11HW or 31HW.
"""
masks_for_points, masks_for_boxes, masks = None, None, None
has_points = positive_points is not None or negative_points is not None
has_bbox = positive_bbox is not None or negative_bbox is not None
assert has_points or has_bbox, "No input provided. Please provide at least one point or one bbox."
assert type(self.sam_model_wrapper) == SamPredictorHQ, "Incorrect SAM model wrapper. Expect SamPredictorHQ here."
mask_shape = ((3, 1) if multimask_output else (1, 1)) + input_image.shape[:2]
self.sam_model_wrapper.set_image(input_image)
# If use Matting-Anything, load mam model.
if use_mam:
self.sam_m2m.load_m2m() # TODO: raise exception if network problem
# When always send bboxes and points to the model at the same time.
if point_coords is not None and point_labels is not None and boxes_coords is not None and point_with_box:
try:
self.sam_m2m.load_m2m()
except:
use_mam = False
# Matting-Anything inference for each SAM inference.
def _mam_infer(mask: np.ndarray, low_res_mask: np.ndarray) -> np.ndarray:
low_res_mask_logits = low_res_mask > self.sam_model_wrapper.model.mask_threshold
if use_mam:
mask = self.sam_m2m.forward(
self.sam_model_wrapper.features, torch.tensor(input_image), low_res_mask_logits, mask,
self.sam_model_wrapper.original_size, self.sam_model_wrapper.input_size, mam_guidance_mode)
return mask
# When always send bboxes and points to SAM at the same time.
if has_points and has_bbox and point_with_box:
logger.info(f"SAM {self.sam_model_name} inference with "
f"{len(positive_points)} positive points, {len(negative_points)} negative points, "
f"{len(positive_bbox)} positive bboxes, {len(negative_bbox)} negative bboxes. "
f"For each bbox, all point prompts are affective. Masks for each bbox will be merged.")
point_coords = np.array(positive_points + negative_points)
point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
point_labels_neg = np.array([0] * len(positive_points) + [1] * len(negative_points))
positive_masks, negative_masks, positive_low_res_masks, negative_low_res_masks = [], [], [], []
# Inference for each positive bbox.
for box in boxes_coords[boxes_labels]:
mask, _, low_res_mask = self.sam_model_wrapper.predict(
point_coords=point_coords,
point_labels=point_labels,
box=box.numpy(),
multimask_output=multimask_output)
mask = mask[:, None, ...]
low_res_mask = low_res_mask[:, None, ...]
if use_mam:
low_res_mask_logits = low_res_mask > self.sam_model_wrapper.model.mask_threshold
if use_mam_for_each_infer:
mask = self.sam_m2m.forward(
self.sam_model_wrapper.features, torch.tensor(input_image), low_res_mask_logits, mask,
self.sam_model_wrapper.original_size, self.sam_model_wrapper.input_size, mam_guidance_mode)
else:
positive_low_res_masks.append(low_res_mask_logits)
positive_masks.append(mask)
positive_masks = np.logical_or(np.stack(positive_masks, 0))
# Inference for each negative bbox.
for box in boxes_coords[[not i for i in boxes_labels]]:
mask, _, low_res_mask = self.sam_model_wrapper.predict(
point_coords=point_coords,
point_labels=point_labels_neg,
box=box.numpy(),
multimask_output=multimask_output)
mask = mask[:, None, ...]
low_res_mask = low_res_mask[:, None, ...]
if use_mam:
low_res_mask_logits = low_res_mask > self.sam_model_wrapper.model.mask_threshold
if use_mam_for_each_infer:
mask = self.sam_m2m.forward(
self.sam_model_wrapper.features, torch.tensor(input_image), low_res_mask_logits, mask,
self.sam_model_wrapper.original_size, self.sam_model_wrapper.input_size, mam_guidance_mode)
else:
negative_low_res_masks.append(low_res_mask_logits)
negative_masks.append(mask)
negative_masks = np.logical_or(np.stack(negative_masks, 0))
masks = np.logical_and(positive_masks, ~negative_masks)
# Matting-Anything inference if not for each inference.
if use_mam and not use_mam_for_each_infer:
positive_low_res_masks = np.logical_or(np.stack(positive_low_res_masks, 0))
negative_low_res_masks = np.logical_or(np.stack(negative_low_res_masks, 0))
low_res_masks = np.logical_and(positive_low_res_masks, ~negative_low_res_masks)
masks = self.sam_m2m.forward(
self.sam_model_wrapper.features, torch.tensor(input_image), low_res_masks, masks,
self.sam_model_wrapper.original_size, self.sam_model_wrapper.input_size, mam_guidance_mode)
return masks
# When separate bbox inference from point inference.
if point_coords is not None and point_labels is not None:
def _box_infer(_bbox, _point_labels):
_masks = []
for box in _bbox:
mask, _, low_res_mask = self.sam_model_wrapper.predict(
point_coords=point_coords,
point_labels=_point_labels,
box=np.array(box),
multimask_output=multimask_output)
mask = mask[:, None, ...]
low_res_mask = low_res_mask[:, None, ...]
_masks.append(_mam_infer(mask, low_res_mask))
return np.logical_or.reduce(_masks)
mask_bbox_positive = _box_infer(positive_bbox, point_labels) if positive_bbox is not None else np.ones(shape=mask_shape, dtype=np.bool_)
mask_bbox_negative = _box_infer(negative_bbox, point_labels_neg) if negative_bbox is not None else np.zeros(shape=mask_shape, dtype=np.bool_)
return mask_bbox_positive & ~mask_bbox_negative
# When separate bbox inference from point inference.
if has_points:
logger.info(f"SAM {self.sam_model_name} inference with "
f"{len(positive_points)} positive points, {len(negative_points)} negative points")
point_coords = np.array(positive_points + negative_points)
point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
masks_for_points, _, _ = self.sam_model_wrapper.predict(
mask_points_positive, _, low_res_masks_points_positive = self.sam_model_wrapper.predict(
point_coords=point_coords,
point_labels=point_labels,
box=None,
multimask_output=multimask_output)
masks_for_points = masks_for_points[:, None, ...]
# TODO: m2m
if boxes_coords is not None:
transformed_boxes = self.sam_model_wrapper.transform.apply_boxes_torch(boxes_coords, input_image.shape[:2])
masks_for_boxes, _, _ = self.sam_model_wrapper.predict_torch(
mask_points_positive = mask_points_positive[:, None, ...]
low_res_masks_points_positive = low_res_masks_points_positive[:, None, ...]
mask_points_positive = _mam_infer(mask_points_positive, low_res_masks_points_positive)
else:
mask_points_positive = np.ones(shape=mask_shape, dtype=np.bool_)
def _box_infer(_bbox, _character):
logger.info(f"SAM {self.sam_model_name} inference with {len(positive_bbox)} {_character} bboxes")
transformed_boxes = self.sam_model_wrapper.transform.apply_boxes_torch(torch.tensor(_bbox), input_image.shape[:2])
mask, _, low_res_mask = self.sam_model_wrapper.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes.to(self.sam_device),
multimask_output=multimask_output)
masks_for_boxes = masks_for_boxes.permute(1, 0, 2, 3).cpu().numpy()
# TODO: m2m
if masks_for_boxes is not None and masks_for_points is not None:
if merge_point_and_box:
masks = np.logical_or(masks_for_points, masks_for_boxes)
else:
masks = np.logical_and(masks_for_points, masks_for_boxes)
elif masks_for_boxes is not None:
masks = masks_for_boxes
elif masks_for_points is not None:
masks = masks_for_points
# TODO: m2m
return masks
mask = mask.permute(1, 0, 2, 3).cpu().numpy()
low_res_mask = low_res_mask.permute(1, 0, 2, 3).cpu().numpy()
return _mam_infer(mask, low_res_mask)
mask_bbox_positive = _box_infer(positive_bbox, "positive") if positive_bbox is not None else np.ones(shape=mask_shape, dtype=np.bool_)
mask_bbox_negative = _box_infer(negative_bbox, "negative") if negative_bbox is not None else np.zeros(shape=mask_shape, dtype=np.bool_)
if merge_positive:
return (mask_points_positive | mask_bbox_positive) & ~mask_bbox_negative
else:
return (mask_points_positive & mask_bbox_positive) & ~mask_bbox_negative
def yolo_predict():
pass
def fastsam_predict(self,
input_image: np.ndarray,
positive_points: List[List[int]]=None,
negative_points: List[List[int]]=None,
positive_bbox: List[List[float]]=None,
negative_bbox: List[List[float]]=None,
positive_text: str="",
negative_text: str="",
merge_positive=True,
merge_negative=True,
conf=0.4, iou=0.9,) -> np.ndarray:
"""Run segmentation inference with models based on FastSAM. (This is a special kind of YOLO model)
Args:
input_image (np.ndarray): input image, expect shape HW3.
positive_points (List[List[int]], optional): positive point prompts, expect N * xy Defaults to None.
negative_points (List[List[int]], optional): negative point prompts, expect N * xy Defaults to None.
positive_bbox (List[List[float]], optional): positive bbox prompts, expect N * xyxy. Defaults to None.
negative_bbox (List[List[float]], optional): negative bbox prompts, expect N * xyxy. Defaults to None.
positive_text (str, optional): positive text prompts. Defaults to "".
negative_text (str, optional): negative text prompts. Defaults to "".
merge_positive (bool, optional): OR all positive masks. Defaults to True.
merge_negative (bool, optional): OR all negative masks. Defaults to True.
conf (float, optional): object confidence threshold. Defaults to 0.4.
iou (float, optional): iou threshold for filtering the annotations. Defaults to 0.9.
Returns:
np.ndarray: mask output, expect shape 11HW. FastSAM does not support multi-mask selection.
"""
assert type(self.sam_model_wrapper) == FastSAM, "Incorrect SAM model wrapper. Expect FastSAM here."
logger.info(f"Running FastSAM {self.sam_model_name} inference.")
annotation = self.sam_model_wrapper(
input_image, device=self.sam_device, retina_masks=True, imgsz=1024, conf=conf, iou=iou)
has_points = positive_points is not None or negative_points is not None
has_bbox = positive_bbox is not None or negative_bbox is not None
has_text = positive_text != "" or negative_text != ""
assert has_points or has_bbox or has_text, "No input provided. Please provide at least one point or one bbox or one text."
logger.info("Post-processing FastSAM inference.")
prompt_process = FastSAMPrompt(input_image, annotation, device=self.sam_device)
mask_shape = (1, 1) + input_image.shape[:2]
mask_bbox_positive = prompt_process.box_prompt(bboxes=positive_bbox) if positive_bbox is not None else np.ones(shape=mask_shape, dtype=np.bool_)
mask_bbox_negative = prompt_process.box_prompt(bboxes=negative_bbox) if negative_bbox is not None else np.zeros(shape=mask_shape, dtype=np.bool_)
mask_text_positive = prompt_process.text_prompt(text=positive_text) if positive_text != "" else np.ones(shape=mask_shape, dtype=np.bool_)
mask_text_negative = prompt_process.text_prompt(text=negative_text) if negative_text != "" else np.zeros(shape=mask_shape, dtype=np.bool_)
point_coords = np.array(positive_points + negative_points)
point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
mask_points_positive = prompt_process.point_prompt(points=point_coords, pointlabel=point_labels) if has_points else np.ones(shape=mask_shape, dtype=np.bool_)
if merge_positive:
mask_positive = mask_bbox_positive | mask_text_positive | mask_points_positive
else:
mask_positive = mask_bbox_positive & mask_text_positive & mask_points_positive
if merge_negative:
mask_negative = mask_bbox_negative | mask_text_negative
else:
mask_negative = mask_bbox_negative & mask_text_negative
return mask_positive & ~mask_negative
def yolo_predict(self, input_image: np.ndarray, conf=0.4) -> np.ndarray:
"""Run segmentation inference with models based on YOLO.
Args:
input_image (np.ndarray): input image, expect shape HW3.
conf (float, optional): object confidence threshold. Defaults to 0.4.
Raises:
RuntimeError: not getting any bbox. Might be caused by high conf or non-detection/segmentation model.
Returns:
np.ndarray: mask output, expect shape 11HW. YOLO does not support multi-mask selection.
"""
assert shared.opts.data.get("sam_use_yolo_models", False), "YOLO models are not enabled. Please enable in settings/Segment Anything."
assert type(self.sam_model_wrapper) == YOLO, "Incorrect SAM model wrapper. Expect YOLO here."
logger.info("Running YOLO inference.")
pred = self.sam_model_wrapper(input_image, conf=conf)
bboxes = pred[0].boxes.xyxy.cpu().numpy()
if bboxes.size == 0:
error_msg = "You are not getting any bbox. There are 2 possible reasons. "\
"1. You set up a high conf which means that you should lower the conf. "\
"2. You are using a non-detection/segmentation model which means that you should check your model type."
logger.error(error_msg)
raise RuntimeError(error_msg)
if pred[0].masks is None:
logger.warn("You are not using a segmentation model. Will use bbox to create masks.")
masks = []
for bbox in bboxes:
mask_shape = (1, 1) + input_image.shape[:2]
mask = np.zeros(mask_shape, dtype=bool)
x1, y1, x2, y2 = bbox
mask[:, :, y1:y2+1, x1:x2+1] = True
return np.logical_or.reduce(masks, axis=0)
else:
return np.logical_or.reduce(pred[0].masks.data, axis=0)
def clear(self):
"""Clear segmentation model from CPU & GPU."""
del self.sam_model
self.sam_model = None
self.sam_model_name = ""
self.sam_model_wrapper = None
self.sam_m2m.clear()
def unload_model(self):
"""Move all segmentation models to CPU."""
if self.sam_model is not None:
self.sam_model.cpu()
self.sam_m2m.unload_model()
def __call__(self,
sam_checkpoint_name: str,
input_image: np.ndarray,
positive_points: List[List[int]]=None,
negative_points: List[List[int]]=None,
positive_bbox: List[List[float]]=None,
negative_bbox: List[List[float]]=None,
positive_text: str="",
negative_text: str="",
merge_positive=True,
merge_negative=True,
multimask_output=True,
point_with_box=False,
use_mam=False,
mam_guidance_mode: str="mask",
conf=0.4, iou=0.9,) -> np.ndarray:
# use_cpu: bool=False,) -> np.ndarray:
"""Entry for segmentation inference. Load model, run inference, unload model if lowvram.
Args:
sam_checkpoint_name (str): The model filename. Do not change.
input_image (np.ndarray): input image, expect shape HW3.
positive_points (List[List[int]], optional): positive point prompts, expect N * xy. Defaults to None. Valid for SAM & FastSAM.
negative_points (List[List[int]], optional): negative point prompts, expect N * xy. Defaults to None. Valid for SAM & FastSAM.
positive_bbox (List[List[float]], optional): positive bbox prompts, expect N * xyxy. Defaults to None. Valid for SAM & FastSAM.
negative_bbox (List[List[float]], optional): negative bbox prompts, expect N * xyxy. Defaults to None. Valid for SAM & FastSAM.
positive_text (str, optional): positive text prompts. Defaults to "". Valid for FastSAM.
negative_text (str, optional): negative text prompts. Defaults to "". Valid for FastSAM.
merge_positive (bool, optional): OR all positive masks. Defaults to True. Valid for SAM (point_with_box is True) & FastSAM.
merge_negative (bool, optional): OR all negative masks. Defaults to True. Valid for FastSAM.
multimask_output (bool, optional): output 3 masks or not. Defaults to True. Valid for SAM.
point_with_box (bool, optional): always send bboxes and points to the model at the same time. Defaults to False. Valid for SAM.
use_mam (bool, optional): use Matting-Anything. Defaults to False. Valid for SAM.
mam_guidance_mode (str, optional): guidance model for Matting-Anything. Expect "mask" or "bbox". Defaults to "mask". Valid for SAM and use_mam is True.
conf (float, optional): object confidence threshold. Defaults to 0.4. Valid for FastSAM & YOLO.
iou (float, optional): iou threshold for filtering the annotations. Defaults to 0.9. Valid for FastSAM.
use_cpu (bool, optional): use CPU for SAM inference. Defaults to False.
Returns:
np.ndarray: _description_
"""
# self.change_device(use_cpu)
self.load_sam_model(sam_checkpoint_name)
if type(self.sam_model_wrapper) == SamPredictorHQ:
masks = self.sam_predict(
input_image=input_image,
positive_points=positive_points,
negative_points=negative_points,
positive_bbox=positive_bbox,
negative_bbox=negative_bbox,
merge_positive=merge_positive,
multimask_output=multimask_output,
point_with_box=point_with_box,
use_mam=use_mam,
mam_guidance_mode=mam_guidance_mode)
elif type(self.sam_model_wrapper) == FastSAM:
masks = self.fastsam_predict(
input_image=input_image,
positive_points=positive_points,
negative_points=negative_points,
positive_bbox=positive_bbox,
negative_bbox=negative_bbox,
positive_text=positive_text,
negative_text=negative_text,
merge_positive=merge_positive,
merge_negative=merge_negative,
conf=conf, iou=iou)
else:
masks = self.yolo_predict(
input_image=input_image,
conf=conf)
if shared.cmd_opts.lowvram:
self.unload_model()
gc.collect()
torch_gc()
return masks
# check device for each model type
# how to use yolo for auto
# category name dropdown
# category name dropdown and dynamic ui
# yolo model for segmentation and detection
# zoom in and unify box+point
# zoom in, unify box+point
# make masks smaller

View File

@ -3,7 +3,6 @@ import os
import numpy as np
import cv2
import copy
from scipy.ndimage import binary_dilation
from PIL import Image
@ -51,6 +50,7 @@ def show_masks(image_np: np.ndarray, masks: np.ndarray, alpha=0.5) -> np.ndarray
def dilate_mask(mask: np.ndarray, dilation_amt: int) -> Tuple[Image.Image, np.ndarray]:
from scipy.ndimage import binary_dilation
x, y = np.meshgrid(np.arange(dilation_amt), np.arange(dilation_amt))
center = dilation_amt // 2
dilation_kernel = ((x - center)**2 + (y - center)**2 <= center**2).astype(np.uint8)
@ -110,3 +110,95 @@ def create_mask_batch_output(
if save_image_with_mask:
output_blend = Image.fromarray(blended_image)
output_blend.save(os.path.join(dest_dir, f"{filename}_{idx}_blend{ext}"))
def blend_image_and_seg(image: np.ndarray, seg: np.ndarray, alpha=0.5) -> Image.Image:
image_blend = image * (1 - alpha) + np.array(seg) * alpha
return Image.fromarray(image_blend.astype(np.uint8))
def install_pycocotools():
# install pycocotools if needed
from sam_utils.logger import logger
try:
import pycocotools.mask as maskUtils
except:
logger.warn("pycocotools not found, will try installing C++ based pycocotools")
try:
from launch import run_pip
run_pip(f"install pycocotools", f"AutoSAM requirement: pycocotools")
import pycocotools.mask as maskUtils
except:
import traceback
traceback.print_exc()
import sys
if sys.platform == "win32":
logger.warn("Unable to install pycocotools, will try installing pycocotools-windows")
try:
run_pip("install pycocotools-windows", "AutoSAM requirement: pycocotools-windows")
import pycocotools.mask as maskUtils
except:
error_msg = "Unable to install pycocotools-windows"
logger.error(error_msg)
traceback.print_exc()
raise RuntimeError(error_msg)
else:
error_msg = "Unable to install pycocotools"
logger.error(error_msg)
traceback.print_exc()
raise RuntimeError(error_msg)
def install_goundingdino() -> bool:
"""Automatically install GroundingDINO.
Returns:
bool: False if use local GroundingDINO, True if use pip installed GroundingDINO.
"""
from sam_utils.logger import logger
from modules import shared
dino_install_issue_text = "Please permanently switch to local GroundingDINO on Settings/Segment Anything or submit an issue to https://github.com/IDEA-Research/Grounded-Segment-Anything/issues."
if shared.opts.data.get("sam_use_local_groundingdino", False):
logger.info("Using local groundingdino.")
return False
def verify_dll(install_local=True):
try:
from groundingdino import _C
logger.info("GroundingDINO dynamic library have been successfully built.")
return True
except Exception:
import traceback
traceback.print_exc()
def run_pip_uninstall(command, desc=None):
from launch import python, run
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
return run(f'"{python}" -m pip uninstall -y {command}', desc=f"Uninstalling {desc}", errdesc=f"Couldn't uninstall {desc}", live=default_command_live)
if install_local:
logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and fall back to local GroundingDINO this time. {dino_install_issue_text}")
run_pip_uninstall(
f"groundingdino",
f"sd-webui-segment-anything requirement: groundingdino")
else:
logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and re-try installing from GitHub source code. {dino_install_issue_text}")
run_pip_uninstall(
f"uninstall groundingdino",
f"sd-webui-segment-anything requirement: groundingdino")
return False
import launch
if launch.is_installed("groundingdino"):
logger.info("Found GroundingDINO in pip. Verifying if dynamic library build success.")
if verify_dll(install_local=False):
return True
try:
launch.run_pip(
f"install git+https://github.com/IDEA-Research/GroundingDINO",
f"sd-webui-segment-anything requirement: groundingdino")
logger.info("GroundingDINO install success. Verifying if dynamic library build success.")
return verify_dll()
except Exception:
import traceback
traceback.print_exc()
logger.warn(f"GroundingDINO install failed. Will fall back to local groundingdino this time. {dino_install_issue_text}")
return False

View File

@ -1,183 +0,0 @@
from typing import List, Tuple
import os
import glob
import copy
from PIL import Image
import numpy as np
import torch
from sam_hq.automatic import SamAutomaticMaskGeneratorHQ
from scripts.sam_log import logger
class AutoSAM:
def __init__(self) -> None:
self.auto_sam: SamAutomaticMaskGeneratorHQ = None
def blend_image_and_seg(self, image: np.ndarray, seg: np.ndarray, alpha=0.5) -> Image.Image:
image_blend = image * (1 - alpha) + np.array(seg) * alpha
return Image.fromarray(image_blend.astype(np.uint8))
def strengthen_semmantic_seg(self, class_ids: np.ndarray, img: np.ndarray) -> np.ndarray:
logger.info("AutoSAM strengthening semantic segmentation")
import pycocotools.mask as maskUtils
semantc_mask = copy.deepcopy(class_ids)
annotations = self.auto_sam.generate(img)
annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
logger.info(f"AutoSAM generated {len(annotations)} masks")
for ann in annotations:
valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool()
propose_classes_ids = torch.tensor(class_ids[valid_mask])
num_class_proposals = len(torch.unique(propose_classes_ids))
if num_class_proposals == 1:
semantc_mask[valid_mask] = propose_classes_ids[0].numpy()
continue
top_1_propose_class_ids = torch.bincount(propose_classes_ids.flatten()).topk(1).indices
semantc_mask[valid_mask] = top_1_propose_class_ids.numpy()
logger.info("AutoSAM strengthening process end")
return semantc_mask
def random_segmentation(self, img: Image.Image) -> Tuple[List[Image.Image], str]:
logger.info("AutoSAM generating random segmentation for EditAnything")
img_np = np.array(img.convert("RGB"))
annotations = self.auto_sam.generate(img_np)
logger.info(f"AutoSAM generated {len(annotations)} masks")
H, W, _ = img_np.shape
color_map = np.zeros((H, W, 3), dtype=np.uint8)
detected_map_tmp = np.zeros((H, W), dtype=np.uint16)
for idx, annotation in enumerate(annotations):
current_seg = annotation['segmentation']
color_map[current_seg] = np.random.randint(0, 255, (3))
detected_map_tmp[current_seg] = idx + 1
detected_map = np.zeros((detected_map_tmp.shape[0], detected_map_tmp.shape[1], 3))
detected_map[:, :, 0] = detected_map_tmp % 256
detected_map[:, :, 1] = detected_map_tmp // 256
try:
from scripts.processor import HWC3
except:
return [], "ControlNet extension not found."
detected_map = HWC3(detected_map.astype(np.uint8))
logger.info("AutoSAM generation process end")
return [self.blend_image_and_seg(img_np, color_map), Image.fromarray(color_map), Image.fromarray(detected_map)], \
"Random segmentation done. Left above (0) is blended image, right above (1) is random segmentation, left below (2) is Edit-Anything control input."
def layer_single_image(self, layout_input_image: Image.Image, layout_output_path: str) -> None:
img_np = np.array(layout_input_image.convert("RGB"))
annotations = self.auto_sam.generate(img_np)
logger.info(f"AutoSAM generated {len(annotations)} annotations")
annotations = sorted(annotations, key=lambda x: x['area'])
for idx, annotation in enumerate(annotations):
img_tmp = np.zeros((img_np.shape[0], img_np.shape[1], 3))
img_tmp[annotation['segmentation']] = img_np[annotation['segmentation']]
img_np[annotation['segmentation']] = np.array([0, 0, 0])
img_tmp = Image.fromarray(img_tmp.astype(np.uint8))
img_tmp.save(os.path.join(layout_output_path, f"{idx}.png"))
img_np = Image.fromarray(img_np.astype(np.uint8))
img_np.save(os.path.join(layout_output_path, "leftover.png"))
def image_layer(self, layout_input_image_or_path, layout_output_path: str) -> str:
if isinstance(layout_input_image_or_path, str):
logger.info("Image layer division batch processing")
all_files = glob.glob(os.path.join(layout_input_image_or_path, "*"))
for image_index, input_image_file in enumerate(all_files):
logger.info(f"Processing {image_index}/{len(all_files)} {input_image_file}")
try:
input_image = Image.open(input_image_file)
output_directory = os.path.join(layout_output_path, os.path.splitext(os.path.basename(input_image_file))[0])
from pathlib import Path
Path(output_directory).mkdir(exist_ok=True)
except:
logger.warn(f"File {input_image_file} not image, skipped.")
continue
self.layer_single_image(input_image, output_directory)
else:
self.layer_single_image(layout_input_image_or_path, layout_output_path)
return "Done"
def semantic_segmentation(self, input_image: Image.Image, annotator_name: str, processor_res: int,
use_pixel_perfect: bool, resize_mode: int, target_W: int, target_H: int) -> Tuple[List[Image.Image], str]:
if input_image is None:
return [], "No input image."
if "seg" in annotator_name:
try:
from scripts.processor import uniformer, oneformer_coco, oneformer_ade20k
from scripts.external_code import pixel_perfect_resolution
oneformers = {
"ade20k": oneformer_ade20k,
"coco": oneformer_coco
}
except:
return [], "ControlNet extension not found."
input_image_np = np.array(input_image)
if use_pixel_perfect:
processor_res = pixel_perfect_resolution(input_image_np, resize_mode, target_W, target_H)
logger.info("Generating semantic segmentation without SAM")
if annotator_name == "seg_ufade20k":
original_semantic = uniformer(input_image_np, processor_res)
else:
dataset = annotator_name.split('_')[-1][2:]
original_semantic = oneformers[dataset](input_image_np, processor_res)
logger.info("Generating semantic segmentation with SAM")
sam_semantic = self.strengthen_semmantic_seg(np.array(original_semantic), input_image_np)
output_gallery = [original_semantic, sam_semantic, self.blend_image_and_seg(input_image, original_semantic), self.blend_image_and_seg(input_image, sam_semantic)]
return output_gallery, f"Done. Left is segmentation before SAM, right is segmentation after SAM."
else:
return self.random_segmentation(input_image)
def categorical_mask_image(self, crop_processor: str, crop_processor_res: int, crop_category_input: List[int], crop_input_image: Image.Image,
crop_pixel_perfect: bool, crop_resize_mode: int, target_W: int, target_H: int) -> Tuple[np.ndarray, Image.Image]:
if crop_input_image is None:
return "No input image."
try:
from scripts.processor import uniformer, oneformer_coco, oneformer_ade20k
from scripts.external_code import pixel_perfect_resolution
oneformers = {
"ade20k": oneformer_ade20k,
"coco": oneformer_coco
}
except:
return [], "ControlNet extension not found."
filter_classes = crop_category_input
if len(filter_classes) == 0:
return "No class selected."
try:
filter_classes = [int(i) for i in filter_classes]
except:
return "Illegal class id. You may have input some string."
crop_input_image_np = np.array(crop_input_image)
if crop_pixel_perfect:
crop_processor_res = pixel_perfect_resolution(crop_input_image_np, crop_resize_mode, target_W, target_H)
crop_input_image_copy = copy.deepcopy(crop_input_image)
logger.info(f"Generating categories with processor {crop_processor}")
if crop_processor == "seg_ufade20k":
original_semantic = uniformer(crop_input_image_np, crop_processor_res)
else:
dataset = crop_processor.split('_')[-1][2:]
original_semantic = oneformers[dataset](crop_input_image_np, crop_processor_res)
sam_semantic = self.strengthen_semmantic_seg(np.array(original_semantic), crop_input_image_np)
mask = np.zeros(sam_semantic.shape, dtype=np.bool_)
from scripts.sam_config import SEMANTIC_CATEGORIES
for i in filter_classes:
mask[np.equal(sam_semantic, SEMANTIC_CATEGORIES[crop_processor][i])] = True
return mask, crop_input_image_copy
def register_auto_sam(self, sam,
auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, auto_sam_output_mode):
self.auto_sam = SamAutomaticMaskGeneratorHQ(
sam, auto_sam_points_per_side, auto_sam_points_per_batch,
auto_sam_pred_iou_thresh, auto_sam_stability_score_thresh,
auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
auto_sam_crop_n_points_downscale_factor, None,
auto_sam_min_mask_region_area, auto_sam_output_mode)

9
thirdparty/fastsam/__init__.py vendored Normal file
View File

@ -0,0 +1,9 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import FastSAM
from .predict import FastSAMPredictor
from .prompt import FastSAMPrompt
# from .val import FastSAMValidator
from .decoder import FastSAMDecoder
__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder'

131
thirdparty/fastsam/decoder.py vendored Normal file
View File

@ -0,0 +1,131 @@
from .model import FastSAM
import numpy as np
from PIL import Image
import clip
from typing import Optional, List, Tuple, Union
class FastSAMDecoder:
def __init__(
self,
model: FastSAM,
device: str='cpu',
conf: float=0.4,
iou: float=0.9,
imgsz: int=1024,
retina_masks: bool=True,
):
self.model = model
self.device = device
self.retina_masks = retina_masks
self.imgsz = imgsz
self.conf = conf
self.iou = iou
self.image = None
self.image_embedding = None
def run_encoder(self, image):
if isinstance(image,str):
image = np.array(Image.open(image))
self.image = image
image_embedding = self.model(
self.image,
device=self.device,
retina_masks=self.retina_masks,
imgsz=self.imgsz,
conf=self.conf,
iou=self.iou
)
return image_embedding[0].numpy()
def run_decoder(
self,
image_embedding,
point_prompt: Optional[np.ndarray]=None,
point_label: Optional[np.ndarray]=None,
box_prompt: Optional[np.ndarray]=None,
text_prompt: Optional[str]=None,
)->np.ndarray:
self.image_embedding = image_embedding
if point_prompt is not None:
ann = self.point_prompt(points=point_prompt, pointlabel=point_label)
return ann
elif box_prompt is not None:
ann = self.box_prompt(bbox=box_prompt)
return ann
elif text_prompt is not None:
ann = self.text_prompt(text=text_prompt)
return ann
else:
return None
def box_prompt(self, bbox):
assert (bbox[2] != 0 and bbox[3] != 0)
masks = self.image_embedding.masks.data
target_height = self.ori_img.shape[0]
target_width = self.ori_img.shape[1]
h = masks.shape[1]
w = masks.shape[2]
if h != target_height or w != target_width:
bbox = [
int(bbox[0] * w / target_width),
int(bbox[1] * h / target_height),
int(bbox[2] * w / target_width),
int(bbox[3] * h / target_height), ]
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2))
orig_masks_area = np.sum(masks, axis=(1, 2))
union = bbox_area + orig_masks_area - masks_area
IoUs = masks_area / union
max_iou_index = np.argmax(IoUs)
return np.array([masks[max_iou_index].cpu().numpy()])
def point_prompt(self, points, pointlabel): # numpy
masks = self._format_results(self.results[0], 0)
target_height = self.ori_img.shape[0]
target_width = self.ori_img.shape[1]
h = masks[0]['segmentation'].shape[0]
w = masks[0]['segmentation'].shape[1]
if h != target_height or w != target_width:
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
onemask = np.zeros((h, w))
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
for i, annotation in enumerate(masks):
if type(annotation) == dict:
mask = annotation['segmentation']
else:
mask = annotation
for i, point in enumerate(points):
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
onemask[mask] = 1
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
onemask[mask] = 0
onemask = onemask >= 1
return np.array([onemask])
def _format_results(self, result, filter=0):
annotations = []
n = len(result.masks.data)
for i in range(n):
annotation = {}
mask = result.masks.data[i] == 1.0
if np.sum(mask) < filter:
continue
annotation['id'] = i
annotation['segmentation'] = mask
annotation['bbox'] = result.boxes.data[i]
annotation['score'] = result.boxes.conf[i]
annotation['area'] = annotation['segmentation'].sum()
annotations.append(annotation)
return annotations

104
thirdparty/fastsam/model.py vendored Normal file
View File

@ -0,0 +1,104 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
YOLO-NAS model interface.
Usage - Predict:
from ultralytics import FastSAM
model = FastSAM('last.pt')
results = model.predict('ultralytics/assets/bus.jpg')
"""
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
from .predict import FastSAMPredictor
class FastSAM(YOLO):
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs) # prefer kwargs
overrides['mode'] = kwargs.get('mode', 'predict')
assert overrides['mode'] in ['track', 'predict']
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
self.predictor = FastSAMPredictor(overrides=overrides)
self.predictor.setup_model(model=self.model, verbose=False)
return self.predictor(source, stream=stream)
def train(self, **kwargs):
"""Function trains models but raises an error as FastSAM models do not support training."""
raise NotImplementedError("FastSAM models don't support training")
def val(self, **kwargs):
"""Run validation given dataset."""
overrides = dict(task='segment', mode='val')
overrides.update(kwargs) # prefer kwargs
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = FastSAM(args=args)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
"""
overrides = dict(task='detect')
overrides.update(kwargs)
overrides['mode'] = 'export'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

52
thirdparty/fastsam/predict.py vendored Normal file
View File

@ -0,0 +1,52 @@
import torch
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import DEFAULT_CFG, ops
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
from .utils import bbox_iou
class FastSAMPredictor(DetectionPredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'segment'
def postprocess(self, preds, img, orig_imgs):
"""TODO: filter by classes."""
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes)
full_box = torch.zeros_like(p[0][0])
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
full_box = full_box.view(1, -1)
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
if critical_iou_index.numel() != 0:
full_box[0][4] = p[0][critical_iou_index][:,4]
full_box[0][6:] = p[0][critical_iou_index][:,6:]
p[0][critical_iou_index] = full_box
results = []
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
for i, pred in enumerate(p):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
if not len(pred): # save empty boxes
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
continue
if self.args.retina_masks:
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
if not isinstance(orig_imgs, torch.Tensor):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
results.append(
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
return results

417
thirdparty/fastsam/prompt.py vendored Normal file
View File

@ -0,0 +1,417 @@
import os
import sys
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
try:
import clip # for linear_assignment
except (ImportError, AssertionError, AttributeError):
from ultralytics.yolo.utils.checks import check_requirements
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
import clip
class FastSAMPrompt:
def __init__(self, img_path, results, device='cuda') -> None:
# self.img_path = img_path
self.device = device
self.results = results
self.img_path = img_path
self.ori_img = cv2.imread(img_path)
def _segment_image(self, image, bbox):
image_array = np.array(image)
segmented_image_array = np.zeros_like(image_array)
x1, y1, x2, y2 = bbox
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
segmented_image = Image.fromarray(segmented_image_array)
black_image = Image.new('RGB', image.size, (255, 255, 255))
# transparency_mask = np.zeros_like((), dtype=np.uint8)
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
transparency_mask[y1:y2, x1:x2] = 255
transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
black_image.paste(segmented_image, mask=transparency_mask_image)
return black_image
def _format_results(self, result, filter=0):
annotations = []
n = len(result.masks.data)
for i in range(n):
annotation = {}
mask = result.masks.data[i] == 1.0
if torch.sum(mask) < filter:
continue
annotation['id'] = i
annotation['segmentation'] = mask.cpu().numpy()
annotation['bbox'] = result.boxes.data[i]
annotation['score'] = result.boxes.conf[i]
annotation['area'] = annotation['segmentation'].sum()
annotations.append(annotation)
return annotations
def filter_masks(annotations): # filte the overlap mask
annotations.sort(key=lambda x: x['area'], reverse=True)
to_remove = set()
for i in range(0, len(annotations)):
a = annotations[i]
for j in range(i + 1, len(annotations)):
b = annotations[j]
if i != j and j not in to_remove:
# check if
if b['area'] < a['area']:
if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
to_remove.add(j)
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
def _get_bbox_from_mask(self, mask):
mask = mask.astype(np.uint8)
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
x1, y1, w, h = cv2.boundingRect(contours[0])
x2, y2 = x1 + w, y1 + h
if len(contours) > 1:
for b in contours:
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
# Merge multiple bounding boxes into one.
x1 = min(x1, x_t)
y1 = min(y1, y_t)
x2 = max(x2, x_t + w_t)
y2 = max(y2, y_t + h_t)
h = y2 - y1
w = x2 - x1
return [x1, y1, x2, y2]
def plot(self,
annotations,
output,
bboxes=None,
points=None,
point_label=None,
mask_random_color=True,
better_quality=True,
retina=False,
withContours=True):
if isinstance(annotations[0], dict):
annotations = [annotation['segmentation'] for annotation in annotations]
result_name = os.path.basename(self.img_path)
image = self.ori_img
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
original_h = image.shape[0]
original_w = image.shape[1]
if sys.platform == "darwin":
plt.switch_backend("TkAgg")
plt.figure(figsize=(original_w / 100, original_h / 100))
# Add subplot with no margin.
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.imshow(image)
if better_quality:
if isinstance(annotations[0], torch.Tensor):
annotations = np.array(annotations.cpu())
for i, mask in enumerate(annotations):
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
if self.device == 'cpu':
annotations = np.array(annotations)
self.fast_show_mask(
annotations,
plt.gca(),
random_color=mask_random_color,
bboxes=bboxes,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
else:
if isinstance(annotations[0], np.ndarray):
annotations = torch.from_numpy(annotations)
self.fast_show_mask_gpu(
annotations,
plt.gca(),
random_color=mask_random_color,
bboxes=bboxes,
points=points,
pointlabel=point_label,
retinamask=retina,
target_height=original_h,
target_width=original_w,
)
if isinstance(annotations, torch.Tensor):
annotations = annotations.cpu().numpy()
if withContours:
contour_all = []
temp = np.zeros((original_h, original_w, 1))
for i, mask in enumerate(annotations):
if type(mask) == dict:
mask = mask['segmentation']
annotation = mask.astype(np.uint8)
if not retina:
annotation = cv2.resize(
annotation,
(original_w, original_h),
interpolation=cv2.INTER_NEAREST,
)
contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
contour_all.append(contour)
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
contour_mask = temp / 255 * color.reshape(1, 1, -1)
plt.imshow(contour_mask)
save_path = output
if not os.path.exists(save_path):
os.makedirs(save_path)
plt.axis('off')
fig = plt.gcf()
plt.draw()
try:
buf = fig.canvas.tostring_rgb()
except AttributeError:
fig.canvas.draw()
buf = fig.canvas.tostring_rgb()
cols, rows = fig.canvas.get_width_height()
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
# CPU post process
def fast_show_mask(
self,
annotation,
ax,
random_color=False,
bboxes=None,
points=None,
pointlabel=None,
retinamask=True,
target_height=960,
target_width=960,
):
msak_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
#Sort annotations based on area.
areas = np.sum(annotation, axis=(1, 2))
sorted_indices = np.argsort(areas)
annotation = annotation[sorted_indices]
index = (annotation != 0).argmax(axis=0)
if random_color:
color = np.random.random((msak_sum, 1, 1, 3))
else:
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual
show = np.zeros((height, weight, 4))
h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
# Use vectorized indexing to update the values of 'show'.
show[h_indices, w_indices, :] = mask_image[indices]
if bboxes is not None:
for bbox in bboxes:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
# draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
s=20,
c='y',
)
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
s=20,
c='m',
)
if not retinamask:
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
ax.imshow(show)
def fast_show_mask_gpu(
self,
annotation,
ax,
random_color=False,
bboxes=None,
points=None,
pointlabel=None,
retinamask=True,
target_height=960,
target_width=960,
):
msak_sum = annotation.shape[0]
height = annotation.shape[1]
weight = annotation.shape[2]
areas = torch.sum(annotation, dim=(1, 2))
sorted_indices = torch.argsort(areas, descending=False)
annotation = annotation[sorted_indices]
# Find the index of the first non-zero value at each position.
index = (annotation != 0).to(torch.long).argmax(dim=0)
if random_color:
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
else:
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
visual = torch.cat([color, transparency], dim=-1)
mask_image = torch.unsqueeze(annotation, -1) * visual
# Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form.
show = torch.zeros((height, weight, 4)).to(annotation.device)
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij')
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
# Use vectorized indexing to update the values of 'show'.
show[h_indices, w_indices, :] = mask_image[indices]
show_cpu = show.cpu().numpy()
if bboxes is not None:
for bbox in bboxes:
x1, y1, x2, y2 = bbox
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
# draw point
if points is not None:
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
s=20,
c='y',
)
plt.scatter(
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
s=20,
c='m',
)
if not retinamask:
show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
ax.imshow(show_cpu)
# clip
@torch.no_grad()
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
preprocessed_images = [preprocess(image).to(device) for image in elements]
tokenized_text = clip.tokenize([search_text]).to(device)
stacked_images = torch.stack(preprocessed_images)
image_features = model.encode_image(stacked_images)
text_features = model.encode_text(tokenized_text)
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
probs = 100.0 * image_features @ text_features.T
return probs[:, 0].softmax(dim=0)
def _crop_image(self, format_results):
image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB))
ori_w, ori_h = image.size
annotations = format_results
mask_h, mask_w = annotations[0]['segmentation'].shape
if ori_w != mask_w or ori_h != mask_h:
image = image.resize((mask_w, mask_h))
cropped_boxes = []
cropped_images = []
not_crop = []
filter_id = []
# annotations, _ = filter_masks(annotations)
# filter_id = list(_)
for _, mask in enumerate(annotations):
if np.sum(mask['segmentation']) <= 100:
filter_id.append(_)
continue
bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox
cropped_boxes.append(self._segment_image(image, bbox))
# cropped_boxes.append(segment_image(image,mask["segmentation"]))
cropped_images.append(bbox) # Save the bounding box of the cropped image.
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
def box_prompt(self, bbox=None, bboxes=None):
assert bbox or bboxes
if bboxes is None:
bboxes = [bbox]
max_iou_index = []
for bbox in bboxes:
assert (bbox[2] != 0 and bbox[3] != 0)
masks = self.results[0].masks.data
target_height = self.ori_img.shape[0]
target_width = self.ori_img.shape[1]
h = masks.shape[1]
w = masks.shape[2]
if h != target_height or w != target_width:
bbox = [
int(bbox[0] * w / target_width),
int(bbox[1] * h / target_height),
int(bbox[2] * w / target_width),
int(bbox[3] * h / target_height), ]
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
orig_masks_area = torch.sum(masks, dim=(1, 2))
union = bbox_area + orig_masks_area - masks_area
IoUs = masks_area / union
max_iou_index.append(int(torch.argmax(IoUs)))
max_iou_index = list(set(max_iou_index))
return np.array(masks[max_iou_index].cpu().numpy())
def point_prompt(self, points, pointlabel): # numpy
masks = self._format_results(self.results[0], 0)
target_height = self.ori_img.shape[0]
target_width = self.ori_img.shape[1]
h = masks[0]['segmentation'].shape[0]
w = masks[0]['segmentation'].shape[1]
if h != target_height or w != target_width:
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
onemask = np.zeros((h, w))
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
for i, annotation in enumerate(masks):
if type(annotation) == dict:
mask = annotation['segmentation']
else:
mask = annotation
for i, point in enumerate(points):
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
onemask[mask] = 1
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
onemask[mask] = 0
onemask = onemask >= 1
return np.array([onemask])
def text_prompt(self, text):
format_results = self._format_results(self.results[0], 0)
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
max_idx = scores.argsort()
max_idx = max_idx[-1]
max_idx += sum(np.array(filter_id) <= int(max_idx))
return np.array([annotations[max_idx]['segmentation']])
def everything_prompt(self):
return self.results[0].masks.data

66
thirdparty/fastsam/utils.py vendored Normal file
View File

@ -0,0 +1,66 @@
import torch
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
'''Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes: (n, 4)
image_shape: (height, width)
threshold: pixel threshold
Returns:
adjusted_boxes: adjusted bounding boxes
'''
# Image dimensions
h, w = image_shape
# Adjust boxes
boxes[:, 0] = torch.where(boxes[:, 0] < threshold, 0, boxes[:, 0]) # x1
boxes[:, 1] = torch.where(boxes[:, 1] < threshold, 0, boxes[:, 1]) # y1
boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, w, boxes[:, 2]) # x2
boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, h, boxes[:, 3]) # y2
return boxes
def convert_box_xywh_to_xyxy(box):
x1 = box[0]
y1 = box[1]
x2 = box[0] + box[2]
y2 = box[1] + box[3]
return [x1, y1, x2, y2]
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
'''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
Args:
box1: (4, )
boxes: (n, 4)
Returns:
high_iou_indices: Indices of boxes with IoU > thres
'''
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
# obtain coordinates for intersections
x1 = torch.max(box1[0], boxes[:, 0])
y1 = torch.max(box1[1], boxes[:, 1])
x2 = torch.min(box1[2], boxes[:, 2])
y2 = torch.min(box1[3], boxes[:, 3])
# compute the area of intersection
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
# compute the area of both individual boxes
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
# compute the area of union
union = box1_area + box2_area - intersection
# compute the IoU
iou = intersection / union # Should be shape (n, )
if raw_output:
if iou.numel() == 0:
return 0
return iou
# get indices of boxes with IoU > thres
high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
return high_iou_indices

View File

@ -31,7 +31,12 @@ class SamM2M(Module):
except:
mam_url = "https://huggingface.co/conrevo/SAM4WebUI-Extension-Models/resolve/main/mam.pth"
logger.info(f"Loading mam from url: {mam_url} to path: {ckpt_path}, device: {self.m2m_device}")
state_dict = torch.hub.load_state_dict_from_url(mam_url, ckpt_path, self.m2m_device)
try:
state_dict = torch.hub.load_state_dict_from_url(mam_url, ckpt_path, self.m2m_device)
except:
error_msg = f"Unable to connect to {mam_url}, thus unable to download Matting-Anything model."
logger.error(error_msg)
raise RuntimeError(error_msg)
self.m2m.load_state_dict(state_dict)
self.m2m.eval()
@ -39,6 +44,7 @@ class SamM2M(Module):
def forward(self, features: torch.Tensor, image: torch.Tensor,
low_res_masks: torch.Tensor, masks: torch.Tensor,
ori_shape: torch.Tensor, pad_shape: torch.Tensor, guidance_mode: str):
logger.info("Applying Matting-Anything.")
self.m2m.to(self.m2m_device)
pred = self.m2m(features, image, low_res_masks)
alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred['alpha_os1'], pred['alpha_os4'], pred['alpha_os8']

View File