sd-webui-segment-anything/sam_utils/segment.py

448 lines
24 KiB
Python

from typing import List
import gc
import os
import torch
import numpy as np
from ultralytics import YOLO
from thirdparty.fastsam import FastSAM, FastSAMPrompt
from modules import shared
from modules.safe import unsafe_torch_load, load
from modules.devices import get_device_for, cpu, torch_gc
from modules.paths import models_path
from scripts.sam_state import sam_extension_dir
from sam_utils.logger import logger
from sam_utils.util import ModelInfo
from thirdparty.sam_hq.build_sam_hq import sam_model_registry
from thirdparty.sam_hq.predictor import SamPredictorHQ
from thirdparty.mam.m2m import SamM2M
class Segmentation:
    """Segmentation related process.

    Holds the currently loaded segmentation model (a SAM-family checkpoint,
    FastSAM or, when enabled in settings, a generic YOLO model), the
    Matting-Anything helper, and the table of known checkpoints. Dispatches
    inference to the matching backend.
    """

    def __init__(self) -> None:
        """Initialize segmentation related process."""
        # filename -> ModelInfo(model type, download URL, ...). For files found on
        # disk the URL is presumably replaced by the local path via
        # ModelInfo.local_path() (see check_model_availability).
        self.sam_model_info = {
            "sam_vit_h_4b8939.pth" : ModelInfo("SAM", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "Meta", "2.56GB"),
            "sam_vit_l_0b3195.pth" : ModelInfo("SAM", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "Meta", "1.25GB"),
            "sam_vit_b_01ec64.pth" : ModelInfo("SAM", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", "Meta", "375MB"),
            "sam_hq_vit_h.pth" : ModelInfo("SAM-HQ", "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth", "SysCV", "2.57GB"),
            "sam_hq_vit_l.pth" : ModelInfo("SAM-HQ", "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth", "SysCV", "1.25GB"),
            "sam_hq_vit_b.pth" : ModelInfo("SAM-HQ", "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth", "SysCV", "379MB"),
            "mobile_sam.pt" : ModelInfo("SAM", "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt", "KHU", "39MB"),
            "FastSAM-x.pt" : ModelInfo("SAM-YOLO", "https://huggingface.co/conrevo/SAM4WebUI-Extension-Models/resolve/main/FastSAM-x.pt", "CASIA-IVA-Lab", "138MB"),
        }
        self.check_model_availability()
        self.sam_model = None  # raw model object; this is what gets moved between devices
        self.sam_model_name = ""  # checkpoint filename of self.sam_model ("" if none)
        self.sam_model_wrapper = None  # object that inference is run through
        self.sam_device = get_device_for("sam")
        self.sam_m2m = SamM2M()

    @staticmethod
    def _as_prompt_list(prompts) -> list:
        """Return ``prompts`` as a list, mapping ``None`` to ``[]``.

        Normalizing here lets the callers concatenate point lists and call
        ``len()`` without special-casing prompts that were not given.
        """
        return [] if prompts is None else list(prompts)

    @staticmethod
    def _merge_prompt_masks(masks, mask_shape, use_or: bool, fill_if_empty: bool) -> np.ndarray:
        """Combine the masks of the prompts that were actually provided.

        Prompts that were not given are passed as ``None`` and left out of the
        merge entirely. An absent prompt therefore neither forces an OR-merge to
        all-True nor an AND-merge to all-False.

        Args:
            masks: prompt masks, ``None`` for prompts that were not provided.
            mask_shape: shape every result is broadcast to.
            use_or (bool): OR the masks together if True, AND them otherwise.
            fill_if_empty (bool): value of the constant mask returned when no
                mask was provided at all.
        Returns:
            np.ndarray: combined mask, broadcast to ``mask_shape``.
        """
        present = [mask for mask in masks if mask is not None]
        if not present:
            return np.full(mask_shape, fill_if_empty, dtype=np.bool_)
        # Start from the identity element of the chosen operation; this also
        # broadcasts prompt masks of a smaller rank up to mask_shape.
        merged = np.zeros(mask_shape, dtype=np.bool_) if use_or else np.ones(mask_shape, dtype=np.bool_)
        for mask in present:
            merged = (merged | mask) if use_or else (merged & mask)
        return merged

    @staticmethod
    def _bboxes_to_mask(bboxes, image_shape) -> np.ndarray:
        """Rasterize xyxy boxes into one (1, 1, H, W) boolean mask (union of the boxes).

        Args:
            bboxes: N * xyxy box corners in pixels (floats allowed).
            image_shape: shape of the image, only the leading (H, W) is used.
        Returns:
            np.ndarray: boolean mask of shape (1, 1, H, W).
        """
        height, width = image_shape[:2]
        mask = np.zeros((1, 1, height, width), dtype=np.bool_)
        for x1, y1, x2, y2 in np.asarray(bboxes, dtype=np.float64).reshape(-1, 4):
            # Float corners cannot be used as slice bounds; round outward so the whole
            # box is covered. x2/y2 stay inclusive. Clamp at 0 so negative values do not wrap.
            left, top = max(int(np.floor(x1)), 0), max(int(np.floor(y1)), 0)
            right, bottom = int(np.ceil(x2)), int(np.ceil(y2))
            mask[:, :, top:bottom + 1, left:right + 1] = True
        return mask

    def check_model_availability(self) -> List[str]:
        """Scan every model directory and register the checkpoints found on disk.

        Known checkpoints found locally get their local path recorded. Other
        ``.pt``/``.pth`` files are registered as generic YOLO models when the
        ``sam_use_yolo_models`` option is enabled.

        Returns:
            List[str]: Model information displayed on the UI
        """
        use_yolo = shared.opts.data.get("sam_use_yolo_models", False)
        user_sam_model_dir = shared.opts.data.get("sam_model_path", "")
        sam_model_dirs = [
            os.path.join(models_path, "sam"),
            os.path.join(sam_extension_dir, "models/sam"),
        ]
        if user_sam_model_dir != "":
            sam_model_dirs.append(user_sam_model_dir)
        if use_yolo:
            sam_model_dirs.append(os.path.join(models_path, "ultralytics"))
        for model_dir in sam_model_dirs:
            if not os.path.isdir(model_dir):
                continue
            for name in os.listdir(model_dir):
                # Match the extension itself; a substring test would also accept
                # names such as "model.pt.bak".
                if not name.endswith((".pth", ".pt")):
                    continue
                model_path = os.path.join(model_dir, name)
                if name in self.sam_model_info:
                    self.sam_model_info[name].local_path(model_path)
                elif use_yolo:
                    logger.warning(f"Model {name} not found in support list, default to use YOLO as initializer")
                    self.sam_model_info[name] = ModelInfo("YOLO", model_path, "?", "?", "downloaded")
        return [info.get_info(name) for name, info in self.sam_model_info.items()]

    def _download_model(self, sam_checkpoint_name: str) -> str:
        """Download a known checkpoint into the preferred model directory.

        The target directory is the user-configured ``sam_model_path`` if set,
        otherwise ``models/sam`` of the WebUI if it exists, otherwise the
        extension's own ``models/sam``.

        Args:
            sam_checkpoint_name (str): key of ``self.sam_model_info``.
        Returns:
            str: local path of the downloaded file.
        Raises:
            RuntimeError: Cannot automatically download model from remote server.
        """
        model_info = self.sam_model_info[sam_checkpoint_name]
        sam_url = model_info.url
        user_dir = shared.opts.data.get("sam_model_path", "")
        sd_dir = os.path.join(models_path, "sam")
        scripts_dir = os.path.join(sam_extension_dir, "models/sam")
        sam_model_dir = user_dir if user_dir != "" else (sd_dir if os.path.isdir(sd_dir) else scripts_dir)
        # download_url_to_file expects the destination *file* path, not a directory.
        model_path = os.path.join(sam_model_dir, sam_checkpoint_name)
        logger.info(f"Downloading segmentation model {sam_checkpoint_name} from {sam_url} to {sam_model_dir}")
        try:
            os.makedirs(sam_model_dir, exist_ok=True)
            torch.hub.download_url_to_file(sam_url, model_path)
        except Exception as err:
            error_msg = f"Cannot connect to {sam_url}. Set terminal proxy or download via browser to {sam_model_dir}"
            logger.error(f"{error_msg}")
            raise RuntimeError(error_msg) from err
        # Record the local file so later loads do not download again.
        model_info.local_path(model_path)
        return model_path

    def load_sam_model(self, sam_checkpoint_name: str) -> None:
        """Load segmentation model (downloading it first if needed) and move it to ``self.sam_device``.

        Args:
            sam_checkpoint_name (str): The model filename. Do not change. Anything after
                the first space (UI decoration) is ignored.
        Raises:
            RuntimeError: Model file not found in either support list or local model directory.
            RuntimeError: Cannot automatically download model from remote server.
            RuntimeError: Unsupported model type.
        """
        sam_checkpoint_name = sam_checkpoint_name.split(" ")[0]
        if self.sam_model is None or self.sam_model_name != sam_checkpoint_name:
            if sam_checkpoint_name not in self.sam_model_info:
                error_msg = f"{sam_checkpoint_name} not found and cannot be auto downloaded"
                raise RuntimeError(error_msg)
            model_info = self.sam_model_info[sam_checkpoint_name]
            model_type = model_info.model_type
            # Only real URLs are downloaded; a local path that merely contains
            # "http" somewhere must not trigger a download.
            if model_info.url.startswith(("http://", "https://")):
                model_path = self._download_model(sam_checkpoint_name)
            else:
                model_path = model_info.url
            if self.sam_model is not None:
                # Release the previous model before building the next one to keep peak memory down.
                self.sam_model = None
                self.sam_model_wrapper = None
                self.sam_model_name = ""
                gc.collect()
                torch_gc()
            device_name = "GPU" if "cuda" in str(self.sam_device).lower() else "CPU"
            logger.info(f"Initializing {sam_checkpoint_name} to {device_name}")
            torch.load = unsafe_torch_load
            try:
                if "YOLO" not in model_type:
                    logger.info(f"Loading SAM model from {model_path}")
                    self.sam_model = sam_model_registry[sam_checkpoint_name](checkpoint=model_path)
                    self.sam_model_wrapper = SamPredictorHQ(self.sam_model, 'HQ' in model_type)
                elif "SAM" in model_type:
                    logger.info(f"Loading FastSAM model from {model_path}")
                    self.sam_model = FastSAM(model_path)
                    self.sam_model_wrapper = self.sam_model
                elif shared.opts.data.get("sam_use_yolo_models", False):
                    logger.info(f"Loading YOLO model from {model_path}")
                    self.sam_model = YOLO(model_path)
                    self.sam_model_wrapper = self.sam_model
                else:
                    error_msg = f"Unsupported model type {model_type}"
                    raise RuntimeError(error_msg)
            finally:
                # Always restore the safe loader, even if loading failed.
                torch.load = load
            self.sam_model_name = sam_checkpoint_name
        # Done on every call: a cached model may have been moved to CPU by
        # unload_model() (lowvram) or the target device may have changed.
        self.sam_model.to(self.sam_device)

    def change_device(self, use_cpu: bool) -> None:
        """Change the device of the segmentation model.

        The model itself is moved on the next ``load_sam_model`` call.

        Args:
            use_cpu (bool): Whether to use CPU for SAM inference.
        """
        self.sam_device = cpu if use_cpu else get_device_for("sam")

    def sam_predict(self,
                    input_image: np.ndarray,
                    positive_points: List[List[int]]=None,
                    negative_points: List[List[int]]=None,
                    positive_bbox: List[List[float]]=None,
                    negative_bbox: List[List[float]]=None,
                    merge_positive=True,
                    multimask_output=True,
                    point_with_box=False,
                    use_mam=False,
                    mam_guidance_mode: str="mask") -> np.ndarray:
        """Run segmentation inference with models based on segment anything.

        Args:
            input_image (np.ndarray): input image, expect shape HW3.
            positive_points (List[List[int]], optional): positive point prompts, expect N * xy. Defaults to None.
            negative_points (List[List[int]], optional): negative point prompts, expect N * xy. Defaults to None.
            positive_bbox (List[List[float]], optional): positive bbox prompts, expect N * xyxy. Defaults to None.
            negative_bbox (List[List[float]], optional): negative bbox prompts, expect N * xyxy. Defaults to None.
            merge_positive (bool, optional): OR (True) or AND (False) the point mask with the positive bbox mask.
                Prompts that were not given are left out of the merge. Defaults to True.
                Valid only if point_with_box is False.
            multimask_output (bool, optional): output 3 masks or not. Defaults to True.
            point_with_box (bool, optional): always send bboxes and points to the model at the same time. Defaults to False.
            use_mam (bool, optional): use Matting-Anything. Defaults to False.
            mam_guidance_mode (str, optional): guidance model for Matting-Anything. Expect "mask" or "bbox".
                Defaults to "mask". Valid only if use_mam is True.
        Returns:
            np.ndarray: mask output, expect shape 11HW or 31HW. Masks of several bboxes are OR-ed together.
        """
        positive_points = self._as_prompt_list(positive_points)
        negative_points = self._as_prompt_list(negative_points)
        positive_bbox = self._as_prompt_list(positive_bbox)
        negative_bbox = self._as_prompt_list(negative_bbox)
        has_points = bool(positive_points or negative_points)
        has_bbox = bool(positive_bbox or negative_bbox)
        assert has_points or has_bbox, "No input provided. Please provide at least one point or one bbox."
        assert isinstance(self.sam_model_wrapper, SamPredictorHQ), "Incorrect SAM model wrapper. Expect SamPredictorHQ here."
        predictor = self.sam_model_wrapper
        # (C, 1, H, W): C = 3 candidate masks with multimask_output, else 1.
        mask_shape = ((3, 1) if multimask_output else (1, 1)) + input_image.shape[:2]
        predictor.set_image(input_image)
        # Matting-Anything is optional: if it cannot be loaded, fall back to plain SAM masks.
        if use_mam:
            try:
                self.sam_m2m.load_m2m()
            except Exception as err:
                logger.warning(f"Matting-Anything unavailable, continuing without it: {err}")
                use_mam = False

        def _mam_infer(mask: np.ndarray, low_res_mask: np.ndarray) -> np.ndarray:
            """Refine one SAM result with Matting-Anything when enabled; otherwise return ``mask`` unchanged."""
            low_res_binary = low_res_mask > predictor.model.mask_threshold
            if use_mam:
                mask = self.sam_m2m.forward(
                    predictor.features, torch.tensor(input_image), low_res_binary, mask,
                    predictor.original_size, predictor.input_size, mam_guidance_mode)
            return mask

        # Every bbox is sent to SAM together with all point prompts.
        if has_points and has_bbox and point_with_box:
            logger.info(f"SAM {self.sam_model_name} inference with "
                        f"{len(positive_points)} positive points, {len(negative_points)} negative points, "
                        f"{len(positive_bbox)} positive bboxes, {len(negative_bbox)} negative bboxes. "
                        f"For each bbox, all point prompts are affective. Masks for each bbox will be merged.")
            point_coords = np.array(positive_points + negative_points)
            point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
            # Same points with swapped labels, used for the negative bboxes.
            point_labels_neg = np.array([0] * len(positive_points) + [1] * len(negative_points))

            def _infer_each_box(boxes, labels):
                """Predict once per box (with all points) and OR the per-box masks."""
                box_masks = []
                for box in boxes:
                    mask, _, low_res_mask = predictor.predict(
                        point_coords=point_coords,
                        point_labels=labels,
                        box=np.array(box),
                        multimask_output=multimask_output)
                    box_masks.append(_mam_infer(mask[:, None, ...], low_res_mask[:, None, ...]))
                return np.logical_or.reduce(box_masks)

            mask_bbox_positive = _infer_each_box(positive_bbox, point_labels) if positive_bbox else np.ones(shape=mask_shape, dtype=np.bool_)
            mask_bbox_negative = _infer_each_box(negative_bbox, point_labels_neg) if negative_bbox else np.zeros(shape=mask_shape, dtype=np.bool_)
            return mask_bbox_positive & ~mask_bbox_negative

        # Otherwise points and bboxes are inferred separately and merged afterwards.
        mask_points = None
        if has_points:
            logger.info(f"SAM {self.sam_model_name} inference with "
                        f"{len(positive_points)} positive points, {len(negative_points)} negative points")
            point_coords = np.array(positive_points + negative_points)
            point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
            point_masks, _, low_res_point_masks = predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                box=None,
                multimask_output=multimask_output)
            mask_points = _mam_infer(point_masks[:, None, ...], low_res_point_masks[:, None, ...])

        def _infer_all_boxes(boxes, character):
            """Batch-predict all boxes of one polarity and OR them into a (C, 1, H, W) mask."""
            logger.info(f"SAM {self.sam_model_name} inference with {len(boxes)} {character} bboxes")
            boxes_tensor = torch.as_tensor(np.asarray(boxes, dtype=np.float32))
            transformed_boxes = predictor.transform.apply_boxes_torch(boxes_tensor, input_image.shape[:2])
            box_masks, _, low_res_box_masks = predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes.to(self.sam_device),
                multimask_output=multimask_output)
            # (B, C, H, W) -> (C, B, H, W)
            box_masks = box_masks.permute(1, 0, 2, 3).cpu().numpy()
            low_res_box_masks = low_res_box_masks.permute(1, 0, 2, 3).cpu().numpy()
            box_masks = _mam_infer(box_masks, low_res_box_masks)
            # Union over the B boxes (max == OR for boolean masks). This keeps the
            # documented C1HW shape, so positive and negative results broadcast
            # whatever number of boxes each side has.
            return np.max(box_masks, axis=1, keepdims=True)

        mask_bbox_positive = _infer_all_boxes(positive_bbox, "positive") if positive_bbox else None
        mask_bbox_negative = _infer_all_boxes(negative_bbox, "negative") if negative_bbox else None
        mask_positive = self._merge_prompt_masks(
            [mask_points, mask_bbox_positive], mask_shape, use_or=merge_positive, fill_if_empty=True)
        mask_negative = self._merge_prompt_masks(
            [mask_bbox_negative], mask_shape, use_or=True, fill_if_empty=False)
        return mask_positive & ~mask_negative

    def fastsam_predict(self,
                        input_image: np.ndarray,
                        positive_points: List[List[int]]=None,
                        negative_points: List[List[int]]=None,
                        positive_bbox: List[List[float]]=None,
                        negative_bbox: List[List[float]]=None,
                        positive_text: str="",
                        negative_text: str="",
                        merge_positive=True,
                        merge_negative=True,
                        conf=0.4, iou=0.9,) -> np.ndarray:
        """Run segmentation inference with models based on FastSAM. (This is a special kind of YOLO model)

        Args:
            input_image (np.ndarray): input image, expect shape HW3.
            positive_points (List[List[int]], optional): positive point prompts, expect N * xy. Defaults to None.
            negative_points (List[List[int]], optional): negative point prompts, expect N * xy. Defaults to None.
            positive_bbox (List[List[float]], optional): positive bbox prompts, expect N * xyxy. Defaults to None.
            negative_bbox (List[List[float]], optional): negative bbox prompts, expect N * xyxy. Defaults to None.
            positive_text (str, optional): positive text prompts. Defaults to "".
            negative_text (str, optional): negative text prompts. Defaults to "".
            merge_positive (bool, optional): OR (True) or AND (False) the given positive masks. Defaults to True.
            merge_negative (bool, optional): OR (True) or AND (False) the given negative masks. Defaults to True.
            conf (float, optional): object confidence threshold. Defaults to 0.4.
            iou (float, optional): iou threshold for filtering the annotations. Defaults to 0.9.
        Returns:
            np.ndarray: mask output, expect shape 11HW. FastSAM does not support multi-mask selection.
        """
        assert isinstance(self.sam_model_wrapper, FastSAM), "Incorrect SAM model wrapper. Expect FastSAM here."
        positive_points = self._as_prompt_list(positive_points)
        negative_points = self._as_prompt_list(negative_points)
        positive_bbox = self._as_prompt_list(positive_bbox)
        negative_bbox = self._as_prompt_list(negative_bbox)
        has_points = bool(positive_points or negative_points)
        has_bbox = bool(positive_bbox or negative_bbox)
        has_text = positive_text != "" or negative_text != ""
        # Validate before running the (expensive) model.
        assert has_points or has_bbox or has_text, "No input provided. Please provide at least one point or one bbox or one text."
        logger.info(f"Running FastSAM {self.sam_model_name} inference.")
        annotation = self.sam_model_wrapper(
            input_image, device=self.sam_device, retina_masks=True, imgsz=1024, conf=conf, iou=iou)
        logger.info("Post-processing FastSAM inference.")
        prompt_process = FastSAMPrompt(input_image, annotation, device=self.sam_device)
        mask_shape = (1, 1) + input_image.shape[:2]
        mask_bbox_positive = prompt_process.box_prompt(bboxes=positive_bbox) if positive_bbox else None
        mask_bbox_negative = prompt_process.box_prompt(bboxes=negative_bbox) if negative_bbox else None
        mask_text_positive = prompt_process.text_prompt(text=positive_text) if positive_text != "" else None
        mask_text_negative = prompt_process.text_prompt(text=negative_text) if negative_text != "" else None
        mask_points_positive = None
        if has_points:
            # Negative points are handled inside point_prompt through their labels.
            point_coords = np.array(positive_points + negative_points)
            point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
            mask_points_positive = prompt_process.point_prompt(points=point_coords, pointlabel=point_labels)
        mask_positive = self._merge_prompt_masks(
            [mask_bbox_positive, mask_text_positive, mask_points_positive],
            mask_shape, use_or=merge_positive, fill_if_empty=True)
        mask_negative = self._merge_prompt_masks(
            [mask_bbox_negative, mask_text_negative],
            mask_shape, use_or=merge_negative, fill_if_empty=False)
        return mask_positive & ~mask_negative

    def yolo_predict(self, input_image: np.ndarray, conf=0.4) -> np.ndarray:
        """Run segmentation inference with models based on YOLO.

        Args:
            input_image (np.ndarray): input image, expect shape HW3.
            conf (float, optional): object confidence threshold. Defaults to 0.4.
        Raises:
            RuntimeError: not getting any bbox. Might be caused by high conf or non-detection/segmentation model.
        Returns:
            np.ndarray: mask output, expect shape 11HW. YOLO does not support multi-mask selection.
        """
        assert shared.opts.data.get("sam_use_yolo_models", False), "YOLO models are not enabled. Please enable in settings/Segment Anything."
        assert type(self.sam_model_wrapper) is YOLO, "Incorrect SAM model wrapper. Expect YOLO here."
        logger.info("Running YOLO inference.")
        # retina_masks=True returns segmentation masks at the input image resolution
        # (as the FastSAM call does), so they line up with the HW of input_image.
        pred = self.sam_model_wrapper(input_image, conf=conf, retina_masks=True)
        result = pred[0]
        bboxes = result.boxes.xyxy.cpu().numpy()
        if bboxes.size == 0:
            error_msg = "You are not getting any bbox. There are 2 possible reasons. "\
                "1. You set up a high conf which means that you should lower the conf. "\
                "2. You are using a non-detection/segmentation model which means that you should check your model type."
            logger.error(error_msg)
            raise RuntimeError(error_msg)
        if result.masks is None:
            logger.warning("You are not using a segmentation model. Will use bbox to create masks.")
            return self._bboxes_to_mask(bboxes, input_image.shape)
        # masks.data is an (N, H, W) tensor of 0/1 values, possibly on the GPU.
        instance_masks = result.masks.data.cpu().numpy() > 0.5
        return np.logical_or.reduce(instance_masks, axis=0)[None, None, ...]

    def clear(self):
        """Clear segmentation model from CPU & GPU."""
        self.sam_model = None
        self.sam_model_name = ""
        self.sam_model_wrapper = None
        self.sam_m2m.clear()
        # Dropping the references alone does not hand cached VRAM back.
        gc.collect()
        torch_gc()

    def unload_model(self):
        """Move all segmentation models to CPU."""
        if self.sam_model is not None:
            self.sam_model.cpu()
        self.sam_m2m.unload_model()

    def __call__(self,
                 sam_checkpoint_name: str,
                 input_image: np.ndarray,
                 positive_points: List[List[int]]=None,
                 negative_points: List[List[int]]=None,
                 positive_bbox: List[List[float]]=None,
                 negative_bbox: List[List[float]]=None,
                 positive_text: str="",
                 negative_text: str="",
                 merge_positive=True,
                 merge_negative=True,
                 multimask_output=True,
                 point_with_box=False,
                 use_mam=False,
                 mam_guidance_mode: str="mask",
                 conf=0.4, iou=0.9,) -> np.ndarray:
        """Entry for segmentation inference. Load model, run inference, unload model if lowvram.

        Args:
            sam_checkpoint_name (str): The model filename. Do not change.
            input_image (np.ndarray): input image, expect shape HW3.
            positive_points (List[List[int]], optional): positive point prompts, expect N * xy. Defaults to None. Valid for SAM & FastSAM.
            negative_points (List[List[int]], optional): negative point prompts, expect N * xy. Defaults to None. Valid for SAM & FastSAM.
            positive_bbox (List[List[float]], optional): positive bbox prompts, expect N * xyxy. Defaults to None. Valid for SAM & FastSAM.
            negative_bbox (List[List[float]], optional): negative bbox prompts, expect N * xyxy. Defaults to None. Valid for SAM & FastSAM.
            positive_text (str, optional): positive text prompts. Defaults to "". Valid for FastSAM.
            negative_text (str, optional): negative text prompts. Defaults to "". Valid for FastSAM.
            merge_positive (bool, optional): OR all positive masks. Defaults to True. Valid for SAM (point_with_box is False) & FastSAM.
            merge_negative (bool, optional): OR all negative masks. Defaults to True. Valid for FastSAM.
            multimask_output (bool, optional): output 3 masks or not. Defaults to True. Valid for SAM.
            point_with_box (bool, optional): always send bboxes and points to the model at the same time. Defaults to False. Valid for SAM.
            use_mam (bool, optional): use Matting-Anything. Defaults to False. Valid for SAM.
            mam_guidance_mode (str, optional): guidance model for Matting-Anything. Expect "mask" or "bbox". Defaults to "mask". Valid for SAM and use_mam is True.
            conf (float, optional): object confidence threshold. Defaults to 0.4. Valid for FastSAM & YOLO.
            iou (float, optional): iou threshold for filtering the annotations. Defaults to 0.9. Valid for FastSAM.
        Returns:
            np.ndarray: mask output, expect shape 11HW or 31HW for SAM, 11HW for FastSAM and YOLO.
        """
        # TODO: expose a use_cpu option (see change_device) once the UI supports it.
        self.load_sam_model(sam_checkpoint_name)
        try:
            if isinstance(self.sam_model_wrapper, SamPredictorHQ):
                masks = self.sam_predict(
                    input_image=input_image,
                    positive_points=positive_points,
                    negative_points=negative_points,
                    positive_bbox=positive_bbox,
                    negative_bbox=negative_bbox,
                    merge_positive=merge_positive,
                    multimask_output=multimask_output,
                    point_with_box=point_with_box,
                    use_mam=use_mam,
                    mam_guidance_mode=mam_guidance_mode)
            elif isinstance(self.sam_model_wrapper, FastSAM):
                masks = self.fastsam_predict(
                    input_image=input_image,
                    positive_points=positive_points,
                    negative_points=negative_points,
                    positive_bbox=positive_bbox,
                    negative_bbox=negative_bbox,
                    positive_text=positive_text,
                    negative_text=negative_text,
                    merge_positive=merge_positive,
                    merge_negative=merge_negative,
                    conf=conf, iou=iou)
            else:
                masks = self.yolo_predict(
                    input_image=input_image,
                    conf=conf)
        finally:
            # Free VRAM in lowvram mode even when inference raised.
            if shared.cmd_opts.lowvram:
                self.unload_model()
                gc.collect()
                torch_gc()
        return masks
# TODO: decide how to use YOLO models for automatic segmentation.
# TODO: category-name dropdown and dynamic UI.
# TODO: support YOLO models for both segmentation and detection.
# TODO: zoom in; unify box + point prompts.
# TODO: make masks smaller.