diff --git a/sam_utils/autosam.py b/sam_utils/autosam.py new file mode 100644 index 0000000..3d8b938 --- /dev/null +++ b/sam_utils/autosam.py @@ -0,0 +1,284 @@ +from typing import List, Tuple, Union, Optional +import os +import glob +import copy +from PIL import Image +import numpy as np +import torch +from segment_anything.modeling import Sam +from thirdparty.fastsam import FastSAM +from thirdparty.sam_hq.automatic import SamAutomaticMaskGeneratorHQ +from sam_utils.segment import Segmentation +from sam_utils.logger import logger +from sam_utils.util import blend_image_and_seg + + +class AutoSAM: + """Automatic segmentation.""" + + def __init__(self, sam: Segmentation) -> None: + """AutoSAM initialization. + + Args: + sam (Segmentation): global Segmentation instance. + """ + self.sam = sam + self.auto_sam: Union[SamAutomaticMaskGeneratorHQ, FastSAM] = None + self.fastsam_conf = None + self.fastsam_iou = None + + + def auto_generate(self, img: np.ndarray) -> List[dict]: + """Generate segmentation. + + Args: + img (np.ndarray): input image. + + Returns: + List[dict]: list of segmentation masks. + """ + return self.auto_sam.generate(img) if type(self.auto_sam) == SamAutomaticMaskGeneratorHQ else \ + self.auto_sam(img, device=self.sam.sam_device, retina_masks=True, imgsz=1024, conf=self.fastsam_conf, iou=self.fastsam_iou) + + + def strengthen_semantic_seg(self, class_ids: np.ndarray, img: np.ndarray) -> np.ndarray: + # TODO: class_ids use multiple dimensions, categorical mask single and batch + logger.info("AutoSAM strengthening semantic segmentation") + from sam_utils.util import install_pycocotools + install_pycocotools() + import pycocotools.mask as maskUtils + semantc_mask = copy.deepcopy(class_ids) + annotations = self.auto_generate(img) + annotations = sorted(annotations, key=lambda x: x['area'], reverse=True) + logger.info(f"AutoSAM generated {len(annotations)} masks") + for ann in annotations: + valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool() + propose_classes_ids = torch.tensor(class_ids[valid_mask]) + num_class_proposals = len(torch.unique(propose_classes_ids)) + if num_class_proposals == 1: + semantc_mask[valid_mask] = propose_classes_ids[0].numpy() + continue + top_1_propose_class_ids = torch.bincount(propose_classes_ids.flatten()).topk(1).indices + semantc_mask[valid_mask] = top_1_propose_class_ids.numpy() + logger.info("AutoSAM strengthening process end") + return semantc_mask + + + def random_segmentation(self, img: Image.Image) -> Tuple[List[Image.Image], str]: + """Random segmentation for EditAnything + + Args: + img (Image.Image): input image. + + Raises: + ModuleNotFoundError: ControlNet not installed. + + Returns: + Tuple[List[Image.Image], str]: List of 3 displayed images and output message. 
+ """ + logger.info("AutoSAM generating random segmentation for EditAnything") + img_np = np.array(img.convert("RGB")) + annotations = self.auto_generate(img_np) + logger.info(f"AutoSAM generated {len(annotations)} masks") + H, W, _ = img_np.shape + color_map = np.zeros((H, W, 3), dtype=np.uint8) + detected_map_tmp = np.zeros((H, W), dtype=np.uint16) + for idx, annotation in enumerate(annotations): + current_seg = annotation['segmentation'] + color_map[current_seg] = np.random.randint(0, 255, (3)) + detected_map_tmp[current_seg] = idx + 1 + detected_map = np.zeros((detected_map_tmp.shape[0], detected_map_tmp.shape[1], 3)) + detected_map[:, :, 0] = detected_map_tmp % 256 + detected_map[:, :, 1] = detected_map_tmp // 256 + try: + from scripts.processor import HWC3 + except: + raise ModuleNotFoundError("ControlNet extension not found.") + detected_map = HWC3(detected_map.astype(np.uint8)) + logger.info("AutoSAM generation process end") + return [blend_image_and_seg(img_np, color_map), Image.fromarray(color_map), Image.fromarray(detected_map)], \ + "Random segmentation done. Left above (0) is blended image, right above (1) is random segmentation, left below (2) is Edit-Anything control input." + + + def layout_single_image(self, input_image: Image.Image, output_path: str) -> None: + """Single image layout generation. + + Args: + input_image (Image.Image): input image. + output_path (str): output path. + """ + img_np = np.array(input_image.convert("RGB")) + annotations = self.auto_generate(img_np) + logger.info(f"AutoSAM generated {len(annotations)} annotations") + annotations = sorted(annotations, key=lambda x: x['area']) + for idx, annotation in enumerate(annotations): + img_tmp = np.zeros((img_np.shape[0], img_np.shape[1], 3)) + img_tmp[annotation['segmentation']] = img_np[annotation['segmentation']] + img_np[annotation['segmentation']] = np.array([0, 0, 0]) + img_tmp = Image.fromarray(img_tmp.astype(np.uint8)) + img_tmp.save(os.path.join(output_path, f"{idx}.png")) + img_np = Image.fromarray(img_np.astype(np.uint8)) + img_np.save(os.path.join(output_path, "leftover.png")) + + + def layout(self, input_image_or_path: Union[str, Image.Image], output_path: str) -> str: + """Single or bath layout generation. + + Args: + input_image_or_path (Union[str, Image.Image]): input imag or path. + output_path (str): output path. + + Returns: + str: generation message. + """ + if isinstance(input_image_or_path, str): + logger.info("Image layer division batch processing") + all_files = glob.glob(os.path.join(input_image_or_path, "*")) + for image_index, input_image_file in enumerate(all_files): + logger.info(f"Processing {image_index}/{len(all_files)} {input_image_file}") + try: + input_image = Image.open(input_image_file) + output_directory = os.path.join(output_path, os.path.splitext(os.path.basename(input_image_file))[0]) + from pathlib import Path + Path(output_directory).mkdir(exist_ok=True) + except: + logger.warn(f"File {input_image_file} not image, skipped.") + continue + self.layout_single_image(input_image, output_directory) + else: + self.layout_single_image(input_image_or_path, output_path) + return "Done" + + + def semantic_segmentation(self, input_image: Image.Image, annotator_name: str, processor_res: int, + use_pixel_perfect: bool, resize_mode: int, target_W: int, target_H: int) -> Tuple[List[Image.Image], str]: + """Semantic segmentation enhanced by segment anything. + + Args: + input_image (Image.Image): input image. + annotator_name (str): annotator name. 
Should be one of "seg_ufade20k"|"seg_ofade20k"|"seg_ofcoco". + processor_res (int): processor resolution. Support 64-2048. + use_pixel_perfect (bool): whether to use pixel perfect written by lllyasviel. + resize_mode (int): resize mode for pixel perfect, should be 0|1|2. + target_W (int): target width for pixel perfect. + target_H (int): target height for pixel perfect. + + Raises: + ModuleNotFoundError: ControlNet not installed. + + Returns: + Tuple[List[Image.Image], str]: list of 4 displayed images and message. + """ + assert input_image is not None, "No input image." + if "seg" in annotator_name: + try: + from scripts.processor import uniformer, oneformer_coco, oneformer_ade20k + from scripts.external_code import pixel_perfect_resolution + oneformers = { + "ade20k": oneformer_ade20k, + "coco": oneformer_coco + } + except: + raise ModuleNotFoundError("ControlNet extension not found.") + input_image_np = np.array(input_image) + if use_pixel_perfect: + processor_res = pixel_perfect_resolution(input_image_np, resize_mode, target_W, target_H) + logger.info("Generating semantic segmentation without SAM") + if annotator_name == "seg_ufade20k": + original_semantic = uniformer(input_image_np, processor_res) + else: + dataset = annotator_name.split('_')[-1][2:] + original_semantic = oneformers[dataset](input_image_np, processor_res) + logger.info("Generating semantic segmentation with SAM") + sam_semantic = self.strengthen_semantic_seg(np.array(original_semantic), input_image_np) + output_gallery = [original_semantic, sam_semantic, blend_image_and_seg(input_image, original_semantic), blend_image_and_seg(input_image, sam_semantic)] + return output_gallery, "Done. Left is segmentation before SAM, right is segmentation after SAM." + else: + return self.random_segmentation(input_image) + + + def categorical_mask_image(self, annotator_name: str, processor_res: int, category_input: List[int], input_image: Image.Image, + use_pixel_perfect: bool, resize_mode: int, target_W: int, target_H: int) -> Tuple[np.ndarray, Image.Image]: + """Single image categorical mask. + + Args: + annotator_name (str): annotator name. Should be one of "seg_ufade20k"|"seg_ofade20k"|"seg_ofcoco". + processor_res (int): processor resolution. Support 64-2048. + category_input (List[int]): category input. + input_image (Image.Image): input image. + use_pixel_perfect (bool): whether to use pixel perfect written by lllyasviel. + resize_mode (int): resize mode for pixel perfect, should be 0|1|2. + target_W (int): target width for pixel perfect. + target_H (int): target height for pixel perfect. + + Raises: + ModuleNotFoundError: ControlNet not installed. + AssertionError: Illegal class id. + + Returns: + Tuple[np.ndarray, Image.Image]: mask in resized shape and resized input image. + """ + assert input_image is not None, "No input image." + try: + from scripts.processor import uniformer, oneformer_coco, oneformer_ade20k + from scripts.external_code import pixel_perfect_resolution + oneformers = { + "ade20k": oneformer_ade20k, + "coco": oneformer_coco + } + except: + raise ModuleNotFoundError("ControlNet extension not found.") + filter_classes = category_input + assert len(filter_classes) > 0, "No class selected." + try: + filter_classes = [int(i) for i in filter_classes] + except: + raise AssertionError("Illegal class id. 
You may have input some string.") + input_image_np = np.array(input_image) + if use_pixel_perfect: + processor_res = pixel_perfect_resolution(input_image_np, resize_mode, target_W, target_H) + crop_input_image_copy = copy.deepcopy(input_image) + logger.info(f"Generating categories with processor {annotator_name}") + if annotator_name == "seg_ufade20k": + original_semantic = uniformer(input_image_np, processor_res) + else: + dataset = annotator_name.split('_')[-1][2:] + original_semantic = oneformers[dataset](input_image_np, processor_res) + sam_semantic = self.strengthen_semantic_seg(np.array(original_semantic), input_image_np) + mask = np.zeros(sam_semantic.shape, dtype=np.bool_) + from sam_utils.config import SEMANTIC_CATEGORIES + for i in filter_classes: + mask[np.equal(sam_semantic, SEMANTIC_CATEGORIES[annotator_name][i])] = True + return mask, crop_input_image_copy + + + def register( + self, + sam_model_name: str, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + fastsam_conf: float = 0.4, + fastsam_iou: float = 0.9) -> None: + """Register AutoSAM module.""" + self.sam.load_sam_model(sam_model_name) + assert type(self.sam.sam_model) in [FastSAM, Sam], f"{sam_model_name} does not support auto segmentation." + if type(self.sam.sam_model) == FastSAM: + self.fastsam_conf = fastsam_conf + self.fastsam_iou = fastsam_iou + self.auto_sam = self.sam.sam_model + else: + self.auto_sam = SamAutomaticMaskGeneratorHQ( + self.sam.sam_model, points_per_side, points_per_batch, pred_iou_thresh, + stability_score_thresh, stability_score_offset, box_nms_thresh, + crop_n_layers, crop_nms_thresh, crop_overlap_ratio, crop_n_points_downscale_factor, None, + min_mask_region_area, output_mode) diff --git a/sam_utils/detect.py b/sam_utils/detect.py index 5653220..d509b96 100644 --- a/sam_utils/detect.py +++ b/sam_utils/detect.py @@ -1,22 +1,24 @@ -from typing import Tuple +from typing import List import os -import gc from PIL import Image import torch from modules import shared -from modules.devices import device, torch_gc -from scripts.sam_log import logger +from modules.devices import device +from sam_utils.logger import logger + -# TODO: support YOLO models class Detection: + """Detection related process. + """ def __init__(self) -> None: + """Initialize detection related process. + """ self.dino_model = None self.dino_model_type = "" from scripts.sam_state import sam_extension_dir self.dino_model_dir = os.path.join(sam_extension_dir, "models/grounding-dino") - self.dino_model_list = ["GroundingDINO_SwinT_OGC (694MB)", "GroundingDINO_SwinB (938MB)"] self.dino_model_info = { "GroundingDINO_SwinT_OGC (694MB)": { "checkpoint": "groundingdino_swint_ogc.pth", @@ -29,68 +31,26 @@ class Detection: "url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth" }, } - self.dino_install_issue_text = "Please permanently switch to local GroundingDINO on Settings/Segment Anything or submit an issue to https://github.com/IDEA-Research/Grounded-Segment-Anything/issues." 
- def _install_goundingdino(self) -> bool: - if shared.opts.data.get("sam_use_local_groundingdino", False): - logger.info("Using local groundingdino.") - return False + def _load_dino_model(self, dino_checkpoint: str, use_pip_dino: bool) -> None: + """Load GroundignDINO model to device. - def verify_dll(install_local=True): - try: - from groundingdino import _C - logger.info("GroundingDINO dynamic library have been successfully built.") - return True - except Exception: - import traceback - traceback.print_exc() - def run_pip_uninstall(command, desc=None): - from launch import python, run - default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1") - return run(f'"{python}" -m pip uninstall -y {command}', desc=f"Uninstalling {desc}", errdesc=f"Couldn't uninstall {desc}", live=default_command_live) - if install_local: - logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and fall back to local GroundingDINO this time. {self.dino_install_issue_text}") - run_pip_uninstall( - f"groundingdino", - f"sd-webui-segment-anything requirement: groundingdino") - else: - logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and re-try installing from GitHub source code. {self.dino_install_issue_text}") - run_pip_uninstall( - f"uninstall groundingdino", - f"sd-webui-segment-anything requirement: groundingdino") - return False - - import launch - if launch.is_installed("groundingdino"): - logger.info("Found GroundingDINO in pip. Verifying if dynamic library build success.") - if verify_dll(install_local=False): - return True - try: - launch.run_pip( - f"install git+https://github.com/IDEA-Research/GroundingDINO", - f"sd-webui-segment-anything requirement: groundingdino") - logger.info("GroundingDINO install success. Verifying if dynamic library build success.") - return verify_dll() - except Exception: - import traceback - traceback.print_exc() - logger.warn(f"GroundingDINO install failed. Will fall back to local groundingdino this time. {self.dino_install_issue_text}") - return False - - - def _load_dino_model(self, dino_checkpoint: str, dino_install_success: bool) -> torch.nn.Module: + Args: + dino_checkpoint (str): GroundingDINO checkpoint name. + use_pip_dino (bool): If True, use pip installed GroundingDINO. If False, use local GroundingDINO. 
+ """ logger.info(f"Initializing GroundingDINO {dino_checkpoint}") if self.dino_model is None or dino_checkpoint != self.dino_model_type: self.clear() - if dino_install_success: + if use_pip_dino: from groundingdino.models import build_model from groundingdino.util.slconfig import SLConfig from groundingdino.util.utils import clean_state_dict else: - from local_groundingdino.models import build_model - from local_groundingdino.util.slconfig import SLConfig - from local_groundingdino.util.utils import clean_state_dict + from thirdparty.groundingdino.models import build_model + from thirdparty.groundingdino.util.slconfig import SLConfig + from thirdparty.groundingdino.util.utils import clean_state_dict args = SLConfig.fromfile(self.dino_model_info[dino_checkpoint]["config"]) dino = build_model(args) checkpoint = torch.hub.load_state_dict_from_url( @@ -103,11 +63,20 @@ class Detection: self.dino_model.to(device=device) - def _load_dino_image(self, image_pil: Image.Image, dino_install_success: bool) -> torch.Tensor: - if dino_install_success: + def _load_dino_image(self, image_pil: Image.Image, use_pip_dino: bool) -> torch.Tensor: + """Transform image to make the image applicable to GroundingDINO. + + Args: + image_pil (Image.Image): Input image in PIL format. + use_pip_dino (bool): If True, use pip installed GroundingDINO. If False, use local GroundingDINO. + + Returns: + torch.Tensor: Transformed image in torch.Tensor format. + """ + if use_pip_dino: import groundingdino.datasets.transforms as T else: - from local_groundingdino.datasets import transforms as T + from thirdparty.groundingdino.datasets import transforms as T transform = T.Compose( [ T.RandomResize([800], max_size=1333), @@ -119,7 +88,17 @@ class Detection: return image - def _get_grounding_output(self, image: torch.Tensor, caption: str, box_threshold: float): + def _get_grounding_output(self, image: torch.Tensor, caption: str, box_threshold: float) -> torch.Tensor: + """Inference GroundingDINO model. + + Args: + image (torch.Tensor): transformed input image. + caption (str): string caption. + box_threshold (float): bbox threshold. + + Returns: + torch.Tensor: generated bounding boxes. + """ caption = caption.lower() caption = caption.strip() if not caption.endswith("."): @@ -128,7 +107,7 @@ class Detection: with torch.no_grad(): outputs = self.dino_model(image[None], captions=[caption]) if shared.cmd_opts.lowvram: - self.dino_model.cpu() + self.unload_model() logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256) boxes = outputs["pred_boxes"][0] # (nq, 4) @@ -142,12 +121,23 @@ class Detection: return boxes_filt.cpu() - def dino_predict(self, input_image: Image.Image, dino_model_name: str, text_prompt: str, box_threshold: float) -> Tuple[torch.Tensor, bool]: - install_success = self._install_goundingdino() + def dino_predict(self, input_image: Image.Image, dino_model_name: str, text_prompt: str, box_threshold: float) -> List[List[float]]: + """Exposed API for GroundingDINO inference. + + Args: + input_image (Image.Image): input image. + dino_model_name (str): GroundingDINO model name. + text_prompt (str): string prompt. + box_threshold (float): bbox threshold. + + Returns: + List[List[float]]: generated N * xyxy bounding boxes. 
+ """ + from sam_utils.util import install_goundingdino + install_success = install_goundingdino() logger.info("Running GroundingDINO Inference") dino_image = self._load_dino_image(input_image.convert("RGB"), install_success) self._load_dino_model(dino_model_name, install_success) - using_groundingdino = install_success or shared.opts.data.get("sam_use_local_groundingdino", False) boxes_filt = self._get_grounding_output( dino_image, text_prompt, box_threshold @@ -158,16 +148,74 @@ class Detection: boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H]) boxes_filt[i][:2] -= boxes_filt[i][2:] / 2 boxes_filt[i][2:] += boxes_filt[i][:2] - gc.collect() - torch_gc() - return boxes_filt, using_groundingdino + return boxes_filt.tolist() + + + def check_yolo_availability(self) -> List[str]: + """Check if YOLO models are available. Do not check if YOLO not enabled. + + Returns: + List[str]: available YOLO models. + """ + if shared.opts.data.get("sam_use_local_yolo", False): + from modules.paths import models_path + sd_yolo_model_dir = os.path.join(models_path, "ultralytics") + return [name for name in os.listdir(sd_yolo_model_dir) if (".pth" in name or ".pt" in name)] + else: + return [] + + + def yolo_predict(self, input_image: Image.Image, yolo_model_name: str, conf=0.4) -> List[List[float]]: + """Run detection inference with models based on YOLO. + + Args: + input_image (np.ndarray): input image, expect shape HW3. + conf (float, optional): object confidence threshold. Defaults to 0.4. + + Raises: + RuntimeError: not getting any bbox. Might be caused by high conf or non-detection/segmentation model. + + Returns: + np.ndarray: generated N * xyxy bounding boxes. + """ + from ultralytics import YOLO + assert shared.opts.data.get("sam_use_yolo_models", False), "YOLO models are not enabled. Please enable in settings/Segment Anything." + logger.info("Loading YOLO model.") + from modules.paths import models_path + sd_yolo_model_dir = os.path.join(models_path, "ultralytics") + self.dino_model = YOLO(os.path.join(sd_yolo_model_dir, yolo_model_name)).to(device) + self.dino_model_type = yolo_model_name + logger.info("Running YOLO inference.") + pred = self.dino_model(input_image, conf=conf) + bboxes = pred[0].boxes.xyxy.cpu().numpy() + if bboxes.size == 0: + error_msg = "You are not getting any bbox. There are 2 possible reasons. "\ + "1. You set up a high conf which means that you should lower the conf. "\ + "2. You are using a non-detection model which means that you should check your model type." + raise RuntimeError(error_msg) + else: + return bboxes.tolist() + + + def __call__(self, input_image: Image.Image, model_name: str, text_prompt: str, box_threshold: float, conf=0.4) -> List[List[float]]: + """Exposed API for detection inference.""" + if model_name in self.dino_model_info.keys(): + return self.dino_predict(input_image, model_name, text_prompt, box_threshold) + elif model_name in self.check_yolo_availability(): + return self.yolo_predict(input_image, model_name, conf) + else: + raise ValueError(f"Detection model {model_name} not found.") def clear(self) -> None: + """Clear detection model from any memory. + """ del self.dino_model self.dino_model = None def unload_model(self) -> None: + """Unload detection model from GPU to CPU. 
+ """ if self.dino_model is not None: self.dino_model.cpu() diff --git a/scripts/sam_log.py b/sam_utils/logger.py similarity index 100% rename from scripts/sam_log.py rename to sam_utils/logger.py diff --git a/sam_utils/segment.py b/sam_utils/segment.py index bc8e827..f7a7a89 100644 --- a/sam_utils/segment.py +++ b/sam_utils/segment.py @@ -1,23 +1,27 @@ from typing import List +import gc import os import torch import numpy as np from ultralytics import YOLO +from thirdparty.fastsam import FastSAM, FastSAMPrompt from modules import shared from modules.safe import unsafe_torch_load, load -from modules.devices import get_device_for, cpu +from modules.devices import get_device_for, cpu, torch_gc from modules.paths import models_path from scripts.sam_state import sam_extension_dir -from scripts.sam_log import logger +from sam_utils.logger import logger from sam_utils.util import ModelInfo -from sam_hq.build_sam_hq import sam_model_registry -from sam_hq.predictor import SamPredictorHQ -from mam.m2m import SamM2M +from thirdparty.sam_hq.build_sam_hq import sam_model_registry +from thirdparty.sam_hq.predictor import SamPredictorHQ +from thirdparty.mam.m2m import SamM2M class Segmentation: + """Segmentation related process.""" def __init__(self) -> None: + """Initialize segmentation related process.""" self.sam_model_info = { "sam_vit_h_4b8939.pth" : ModelInfo("SAM", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "Meta", "2.56GB"), "sam_vit_l_0b3195.pth" : ModelInfo("SAM", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "Meta", "1.25GB"), @@ -37,7 +41,11 @@ class Segmentation: def check_model_availability(self) -> List[str]: - # retrieve all models in all the model directories + """retrieve all models in all the model directories + + Returns: + List[str]: Model information displayed on the UI + """ user_sam_model_dir = shared.opts.data.get("sam_model_path", "") sd_sam_model_dir = os.path.join(models_path, "sam") scripts_sam_model_dir = os.path.join(sam_extension_dir, "models/sam") @@ -45,7 +53,7 @@ class Segmentation: sam_model_dirs = [sd_sam_model_dir, scripts_sam_model_dir] if user_sam_model_dir != "": sam_model_dirs.append(user_sam_model_dir) - if shared.opts.data.get("use_yolo_models", False): + if shared.opts.data.get("sam_use_yolo_models", False): sam_model_dirs.append(sd_yolo_model_dir) for dir in sam_model_dirs: if os.path.isdir(dir): @@ -54,19 +62,27 @@ class Segmentation: for name in sam_model_names: if name in self.sam_model_info.keys(): self.sam_model_info[name].local_path(os.path.join(dir, name)) - elif shared.opts.data.get("use_yolo_models", False): + elif shared.opts.data.get("sam_use_yolo_models", False): logger.warn(f"Model {name} not found in support list, default to use YOLO as initializer") self.sam_model_info[name] = ModelInfo("YOLO", os.path.join(dir, name), "?", "?", "downloaded") return [val.get_info(key) for key, val in self.sam_model_info.items()] def load_sam_model(self, sam_checkpoint_name: str) -> None: + """Load segmentation model. + + Args: + sam_checkpoint_name (str): The model filename. Do not change. + + Raises: + RuntimeError: Model file not found in either support list or local model directory. + RuntimeError: Cannot automatically download model from remote server. 
+ """ sam_checkpoint_name = sam_checkpoint_name.split(" ")[0] if self.sam_model is None or self.sam_model_name != sam_checkpoint_name: if sam_checkpoint_name not in self.sam_model_info.keys(): error_msg = f"{sam_checkpoint_name} not found and cannot be auto downloaded" - logger.error(f"{error_msg}") - raise Exception(error_msg) + raise RuntimeError(error_msg) if "http" in self.sam_model_info[sam_checkpoint_name].url: sam_url = self.sam_model_info[sam_checkpoint_name].url user_dir = shared.opts.data.get("sam_model_path", "") @@ -89,158 +105,343 @@ class Segmentation: logger.info(f"Loading SAM model from {model_path}") self.sam_model = sam_model_registry[sam_checkpoint_name](checkpoint=model_path) self.sam_model_wrapper = SamPredictorHQ(self.sam_model, 'HQ' in model_type) - else: + elif "SAM" in model_type: + logger.info(f"Loading FastSAM model from {model_path}") + self.sam_model = FastSAM(model_path) + self.sam_model_wrapper = self.sam_model + elif shared.opts.data.get("sam_use_yolo_models", False): logger.info(f"Loading YOLO model from {model_path}") self.sam_model = YOLO(model_path) self.sam_model_wrapper = self.sam_model + else: + error_msg = f"Unsupported model type {model_type}" + raise RuntimeError(error_msg) self.sam_model_name = sam_checkpoint_name torch.load = load self.sam_model.to(self.sam_device) def change_device(self, use_cpu: bool) -> None: + """Change the device of the segmentation model. + + Args: + use_cpu (bool): Whether to use CPU for SAM inference. + """ self.sam_device = cpu if use_cpu else get_device_for("sam") - def __call__(self, - input_image: np.ndarray, - point_coords: List[List[int]]=None, - point_labels: List[List[int]]=None, - boxes: torch.Tensor=None, - multimask_output=True, - use_mam=False): - pass - - def sam_predict(self, - input_image: np.ndarray, - positive_points: List[List[int]]=None, - negative_points: List[List[int]]=None, - boxes_coords: torch.Tensor=None, - boxes_labels: List[bool]=None, - multimask_output=True, - merge_point_and_box=True, - point_with_box=False, - use_mam=False, - use_mam_for_each_infer=False, - mam_guidance_mode: str="mask") -> np.ndarray: + input_image: np.ndarray, + positive_points: List[List[int]]=None, + negative_points: List[List[int]]=None, + positive_bbox: List[List[float]]=None, + negative_bbox: List[List[float]]=None, + merge_positive=True, + multimask_output=True, + point_with_box=False, + use_mam=False, + mam_guidance_mode: str="mask") -> np.ndarray: """Run segmentation inference with models based on segment anything. Args: input_image (np.ndarray): input image, expect shape HW3. - positive_points (List[List[int]], optional): positive point prompts. Defaults to None. - negative_points (List[List[int]], optional): negative point prompts. Defaults to None. - boxes_coords (torch.Tensor, optional): bbox inputs, expect shape xyxy. Defaults to None. - boxes_labels (List[bool], optional): bbox labels, support positive & negative bboxes. Defaults to None. + positive_points (List[List[int]], optional): positive point prompts, expect N * xy. Defaults to None. + negative_points (List[List[int]], optional): negative point prompts, expect N * xy. Defaults to None. + positive_bbox (List[List[float]], optional): positive bbox prompts, expect N * xyxy. Defaults to None. + negative_bbox (List[List[float]], optional): negative bbox prompts, expect N * xyxy. Defaults to None. + merge_positive (bool, optional): OR all positive masks. Defaults to True. Valid only if point_with_box is False. 
multimask_output (bool, optional): output 3 masks or not. Defaults to True. - merge_point_and_box (bool, optional): if True, output point masks || bbox masks; otherwise, output point masks && bbox masks. Valid only if point_with_box is False. Defaults to True. point_with_box (bool, optional): always send bboxes and points to the model at the same time. Defaults to False. use_mam (bool, optional): use Matting-Anything. Defaults to False. - use_mam_for_each_infer (bool, optional): use Matting-Anything for each SAM inference. Valid only if use_mam is True. Defaults to False. - mam_guidance_mode (str, optional): guidance model for Matting-Anything. Expect "mask" or "bbox". Valid only if use_mam is True. Defaults to "mask". + mam_guidance_mode (str, optional): guidance model for Matting-Anything. Expect "mask" or "bbox". Defaults to "mask". Valid only if use_mam is True. Returns: np.ndarray: mask output, expect shape 11HW or 31HW. """ - masks_for_points, masks_for_boxes, masks = None, None, None + has_points = positive_points is not None or negative_points is not None + has_bbox = positive_bbox is not None or negative_bbox is not None + assert has_points or has_bbox, "No input provided. Please provide at least one point or one bbox." + assert type(self.sam_model_wrapper) == SamPredictorHQ, "Incorrect SAM model wrapper. Expect SamPredictorHQ here." + mask_shape = ((3, 1) if multimask_output else (1, 1)) + input_image.shape[:2] self.sam_model_wrapper.set_image(input_image) + + # If use Matting-Anything, load mam model. if use_mam: - self.sam_m2m.load_m2m() # TODO: raise exception if network problem - # When always send bboxes and points to the model at the same time. - if point_coords is not None and point_labels is not None and boxes_coords is not None and point_with_box: + try: + self.sam_m2m.load_m2m() + except: + use_mam = False + + # Matting-Anything inference for each SAM inference. + def _mam_infer(mask: np.ndarray, low_res_mask: np.ndarray) -> np.ndarray: + low_res_mask_logits = low_res_mask > self.sam_model_wrapper.model.mask_threshold + if use_mam: + mask = self.sam_m2m.forward( + self.sam_model_wrapper.features, torch.tensor(input_image), low_res_mask_logits, mask, + self.sam_model_wrapper.original_size, self.sam_model_wrapper.input_size, mam_guidance_mode) + return mask + + + # When always send bboxes and points to SAM at the same time. + if has_points and has_bbox and point_with_box: + logger.info(f"SAM {self.sam_model_name} inference with " + f"{len(positive_points)} positive points, {len(negative_points)} negative points, " + f"{len(positive_bbox)} positive bboxes, {len(negative_bbox)} negative bboxes. " + f"For each bbox, all point prompts are affective. Masks for each bbox will be merged.") point_coords = np.array(positive_points + negative_points) point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points)) point_labels_neg = np.array([0] * len(positive_points) + [1] * len(negative_points)) - positive_masks, negative_masks, positive_low_res_masks, negative_low_res_masks = [], [], [], [] - # Inference for each positive bbox. - for box in boxes_coords[boxes_labels]: - mask, _, low_res_mask = self.sam_model_wrapper.predict( - point_coords=point_coords, - point_labels=point_labels, - box=box.numpy(), - multimask_output=multimask_output) - mask = mask[:, None, ...] - low_res_mask = low_res_mask[:, None, ...] 
- if use_mam: - low_res_mask_logits = low_res_mask > self.sam_model_wrapper.model.mask_threshold - if use_mam_for_each_infer: - mask = self.sam_m2m.forward( - self.sam_model_wrapper.features, torch.tensor(input_image), low_res_mask_logits, mask, - self.sam_model_wrapper.original_size, self.sam_model_wrapper.input_size, mam_guidance_mode) - else: - positive_low_res_masks.append(low_res_mask_logits) - positive_masks.append(mask) - positive_masks = np.logical_or(np.stack(positive_masks, 0)) - # Inference for each negative bbox. - for box in boxes_coords[[not i for i in boxes_labels]]: - mask, _, low_res_mask = self.sam_model_wrapper.predict( - point_coords=point_coords, - point_labels=point_labels_neg, - box=box.numpy(), - multimask_output=multimask_output) - mask = mask[:, None, ...] - low_res_mask = low_res_mask[:, None, ...] - if use_mam: - low_res_mask_logits = low_res_mask > self.sam_model_wrapper.model.mask_threshold - if use_mam_for_each_infer: - mask = self.sam_m2m.forward( - self.sam_model_wrapper.features, torch.tensor(input_image), low_res_mask_logits, mask, - self.sam_model_wrapper.original_size, self.sam_model_wrapper.input_size, mam_guidance_mode) - else: - negative_low_res_masks.append(low_res_mask_logits) - negative_masks.append(mask) - negative_masks = np.logical_or(np.stack(negative_masks, 0)) - masks = np.logical_and(positive_masks, ~negative_masks) - # Matting-Anything inference if not for each inference. - if use_mam and not use_mam_for_each_infer: - positive_low_res_masks = np.logical_or(np.stack(positive_low_res_masks, 0)) - negative_low_res_masks = np.logical_or(np.stack(negative_low_res_masks, 0)) - low_res_masks = np.logical_and(positive_low_res_masks, ~negative_low_res_masks) - masks = self.sam_m2m.forward( - self.sam_model_wrapper.features, torch.tensor(input_image), low_res_masks, masks, - self.sam_model_wrapper.original_size, self.sam_model_wrapper.input_size, mam_guidance_mode) - return masks - # When separate bbox inference from point inference. - if point_coords is not None and point_labels is not None: + def _box_infer(_bbox, _point_labels): + _masks = [] + for box in _bbox: + mask, _, low_res_mask = self.sam_model_wrapper.predict( + point_coords=point_coords, + point_labels=_point_labels, + box=np.array(box), + multimask_output=multimask_output) + mask = mask[:, None, ...] + low_res_mask = low_res_mask[:, None, ...] + _masks.append(_mam_infer(mask, low_res_mask)) + return np.logical_or.reduce(_masks) + + mask_bbox_positive = _box_infer(positive_bbox, point_labels) if positive_bbox is not None else np.ones(shape=mask_shape, dtype=np.bool_) + mask_bbox_negative = _box_infer(negative_bbox, point_labels_neg) if negative_bbox is not None else np.zeros(shape=mask_shape, dtype=np.bool_) + + return mask_bbox_positive & ~mask_bbox_negative + + # When separate bbox inference from point inference. + if has_points: + logger.info(f"SAM {self.sam_model_name} inference with " + f"{len(positive_points)} positive points, {len(negative_points)} negative points") point_coords = np.array(positive_points + negative_points) point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points)) - masks_for_points, _, _ = self.sam_model_wrapper.predict( + mask_points_positive, _, low_res_masks_points_positive = self.sam_model_wrapper.predict( point_coords=point_coords, point_labels=point_labels, box=None, multimask_output=multimask_output) - masks_for_points = masks_for_points[:, None, ...] 
- # TODO: m2m - if boxes_coords is not None: - transformed_boxes = self.sam_model_wrapper.transform.apply_boxes_torch(boxes_coords, input_image.shape[:2]) - masks_for_boxes, _, _ = self.sam_model_wrapper.predict_torch( + mask_points_positive = mask_points_positive[:, None, ...] + low_res_masks_points_positive = low_res_masks_points_positive[:, None, ...] + mask_points_positive = _mam_infer(mask_points_positive, low_res_masks_points_positive) + else: + mask_points_positive = np.ones(shape=mask_shape, dtype=np.bool_) + + def _box_infer(_bbox, _character): + logger.info(f"SAM {self.sam_model_name} inference with {len(positive_bbox)} {_character} bboxes") + transformed_boxes = self.sam_model_wrapper.transform.apply_boxes_torch(torch.tensor(_bbox), input_image.shape[:2]) + mask, _, low_res_mask = self.sam_model_wrapper.predict_torch( point_coords=None, point_labels=None, boxes=transformed_boxes.to(self.sam_device), multimask_output=multimask_output) - masks_for_boxes = masks_for_boxes.permute(1, 0, 2, 3).cpu().numpy() - # TODO: m2m - if masks_for_boxes is not None and masks_for_points is not None: - if merge_point_and_box: - masks = np.logical_or(masks_for_points, masks_for_boxes) - else: - masks = np.logical_and(masks_for_points, masks_for_boxes) - elif masks_for_boxes is not None: - masks = masks_for_boxes - elif masks_for_points is not None: - masks = masks_for_points - # TODO: m2m - return masks + mask = mask.permute(1, 0, 2, 3).cpu().numpy() + low_res_mask = low_res_mask.permute(1, 0, 2, 3).cpu().numpy() + return _mam_infer(mask, low_res_mask) + mask_bbox_positive = _box_infer(positive_bbox, "positive") if positive_bbox is not None else np.ones(shape=mask_shape, dtype=np.bool_) + mask_bbox_negative = _box_infer(negative_bbox, "negative") if negative_bbox is not None else np.zeros(shape=mask_shape, dtype=np.bool_) + + if merge_positive: + return (mask_points_positive | mask_bbox_positive) & ~mask_bbox_negative + else: + return (mask_points_positive & mask_bbox_positive) & ~mask_bbox_negative - def yolo_predict(): - pass + def fastsam_predict(self, + input_image: np.ndarray, + positive_points: List[List[int]]=None, + negative_points: List[List[int]]=None, + positive_bbox: List[List[float]]=None, + negative_bbox: List[List[float]]=None, + positive_text: str="", + negative_text: str="", + merge_positive=True, + merge_negative=True, + conf=0.4, iou=0.9,) -> np.ndarray: + """Run segmentation inference with models based on FastSAM. (This is a special kind of YOLO model) + + Args: + input_image (np.ndarray): input image, expect shape HW3. + positive_points (List[List[int]], optional): positive point prompts, expect N * xy Defaults to None. + negative_points (List[List[int]], optional): negative point prompts, expect N * xy Defaults to None. + positive_bbox (List[List[float]], optional): positive bbox prompts, expect N * xyxy. Defaults to None. + negative_bbox (List[List[float]], optional): negative bbox prompts, expect N * xyxy. Defaults to None. + positive_text (str, optional): positive text prompts. Defaults to "". + negative_text (str, optional): negative text prompts. Defaults to "". + merge_positive (bool, optional): OR all positive masks. Defaults to True. + merge_negative (bool, optional): OR all negative masks. Defaults to True. + conf (float, optional): object confidence threshold. Defaults to 0.4. + iou (float, optional): iou threshold for filtering the annotations. Defaults to 0.9. + + Returns: + np.ndarray: mask output, expect shape 11HW. FastSAM does not support multi-mask selection. 
+ """ + assert type(self.sam_model_wrapper) == FastSAM, "Incorrect SAM model wrapper. Expect FastSAM here." + logger.info(f"Running FastSAM {self.sam_model_name} inference.") + annotation = self.sam_model_wrapper( + input_image, device=self.sam_device, retina_masks=True, imgsz=1024, conf=conf, iou=iou) + has_points = positive_points is not None or negative_points is not None + has_bbox = positive_bbox is not None or negative_bbox is not None + has_text = positive_text != "" or negative_text != "" + assert has_points or has_bbox or has_text, "No input provided. Please provide at least one point or one bbox or one text." + logger.info("Post-processing FastSAM inference.") + prompt_process = FastSAMPrompt(input_image, annotation, device=self.sam_device) + mask_shape = (1, 1) + input_image.shape[:2] + mask_bbox_positive = prompt_process.box_prompt(bboxes=positive_bbox) if positive_bbox is not None else np.ones(shape=mask_shape, dtype=np.bool_) + mask_bbox_negative = prompt_process.box_prompt(bboxes=negative_bbox) if negative_bbox is not None else np.zeros(shape=mask_shape, dtype=np.bool_) + + mask_text_positive = prompt_process.text_prompt(text=positive_text) if positive_text != "" else np.ones(shape=mask_shape, dtype=np.bool_) + mask_text_negative = prompt_process.text_prompt(text=negative_text) if negative_text != "" else np.zeros(shape=mask_shape, dtype=np.bool_) + + point_coords = np.array(positive_points + negative_points) + point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points)) + mask_points_positive = prompt_process.point_prompt(points=point_coords, pointlabel=point_labels) if has_points else np.ones(shape=mask_shape, dtype=np.bool_) + + if merge_positive: + mask_positive = mask_bbox_positive | mask_text_positive | mask_points_positive + else: + mask_positive = mask_bbox_positive & mask_text_positive & mask_points_positive + if merge_negative: + mask_negative = mask_bbox_negative | mask_text_negative + else: + mask_negative = mask_bbox_negative & mask_text_negative + return mask_positive & ~mask_negative + def yolo_predict(self, input_image: np.ndarray, conf=0.4) -> np.ndarray: + """Run segmentation inference with models based on YOLO. + + Args: + input_image (np.ndarray): input image, expect shape HW3. + conf (float, optional): object confidence threshold. Defaults to 0.4. + + Raises: + RuntimeError: not getting any bbox. Might be caused by high conf or non-detection/segmentation model. + + Returns: + np.ndarray: mask output, expect shape 11HW. YOLO does not support multi-mask selection. + """ + assert shared.opts.data.get("sam_use_yolo_models", False), "YOLO models are not enabled. Please enable in settings/Segment Anything." + assert type(self.sam_model_wrapper) == YOLO, "Incorrect SAM model wrapper. Expect YOLO here." + logger.info("Running YOLO inference.") + pred = self.sam_model_wrapper(input_image, conf=conf) + bboxes = pred[0].boxes.xyxy.cpu().numpy() + if bboxes.size == 0: + error_msg = "You are not getting any bbox. There are 2 possible reasons. "\ + "1. You set up a high conf which means that you should lower the conf. "\ + "2. You are using a non-detection/segmentation model which means that you should check your model type." + logger.error(error_msg) + raise RuntimeError(error_msg) + + if pred[0].masks is None: + logger.warn("You are not using a segmentation model. 
Will use bbox to create masks.") + masks = [] + for bbox in bboxes: + mask_shape = (1, 1) + input_image.shape[:2] + mask = np.zeros(mask_shape, dtype=bool) + x1, y1, x2, y2 = bbox + mask[:, :, y1:y2+1, x1:x2+1] = True + return np.logical_or.reduce(masks, axis=0) + else: + return np.logical_or.reduce(pred[0].masks.data, axis=0) + + + def clear(self): + """Clear segmentation model from CPU & GPU.""" + del self.sam_model + self.sam_model = None + self.sam_model_name = "" + self.sam_model_wrapper = None + self.sam_m2m.clear() + + + def unload_model(self): + """Move all segmentation models to CPU.""" + if self.sam_model is not None: + self.sam_model.cpu() + self.sam_m2m.unload_model() + + + def __call__(self, + sam_checkpoint_name: str, + input_image: np.ndarray, + positive_points: List[List[int]]=None, + negative_points: List[List[int]]=None, + positive_bbox: List[List[float]]=None, + negative_bbox: List[List[float]]=None, + positive_text: str="", + negative_text: str="", + merge_positive=True, + merge_negative=True, + multimask_output=True, + point_with_box=False, + use_mam=False, + mam_guidance_mode: str="mask", + conf=0.4, iou=0.9,) -> np.ndarray: + # use_cpu: bool=False,) -> np.ndarray: + """Entry for segmentation inference. Load model, run inference, unload model if lowvram. + + Args: + sam_checkpoint_name (str): The model filename. Do not change. + input_image (np.ndarray): input image, expect shape HW3. + positive_points (List[List[int]], optional): positive point prompts, expect N * xy. Defaults to None. Valid for SAM & FastSAM. + negative_points (List[List[int]], optional): negative point prompts, expect N * xy. Defaults to None. Valid for SAM & FastSAM. + positive_bbox (List[List[float]], optional): positive bbox prompts, expect N * xyxy. Defaults to None. Valid for SAM & FastSAM. + negative_bbox (List[List[float]], optional): negative bbox prompts, expect N * xyxy. Defaults to None. Valid for SAM & FastSAM. + positive_text (str, optional): positive text prompts. Defaults to "". Valid for FastSAM. + negative_text (str, optional): negative text prompts. Defaults to "". Valid for FastSAM. + merge_positive (bool, optional): OR all positive masks. Defaults to True. Valid for SAM (point_with_box is True) & FastSAM. + merge_negative (bool, optional): OR all negative masks. Defaults to True. Valid for FastSAM. + multimask_output (bool, optional): output 3 masks or not. Defaults to True. Valid for SAM. + point_with_box (bool, optional): always send bboxes and points to the model at the same time. Defaults to False. Valid for SAM. + use_mam (bool, optional): use Matting-Anything. Defaults to False. Valid for SAM. + mam_guidance_mode (str, optional): guidance model for Matting-Anything. Expect "mask" or "bbox". Defaults to "mask". Valid for SAM and use_mam is True. + conf (float, optional): object confidence threshold. Defaults to 0.4. Valid for FastSAM & YOLO. + iou (float, optional): iou threshold for filtering the annotations. Defaults to 0.9. Valid for FastSAM. + use_cpu (bool, optional): use CPU for SAM inference. Defaults to False. 
+ + Returns: + np.ndarray: _description_ + """ + # self.change_device(use_cpu) + self.load_sam_model(sam_checkpoint_name) + if type(self.sam_model_wrapper) == SamPredictorHQ: + masks = self.sam_predict( + input_image=input_image, + positive_points=positive_points, + negative_points=negative_points, + positive_bbox=positive_bbox, + negative_bbox=negative_bbox, + merge_positive=merge_positive, + multimask_output=multimask_output, + point_with_box=point_with_box, + use_mam=use_mam, + mam_guidance_mode=mam_guidance_mode) + elif type(self.sam_model_wrapper) == FastSAM: + masks = self.fastsam_predict( + input_image=input_image, + positive_points=positive_points, + negative_points=negative_points, + positive_bbox=positive_bbox, + negative_bbox=negative_bbox, + positive_text=positive_text, + negative_text=negative_text, + merge_positive=merge_positive, + merge_negative=merge_negative, + conf=conf, iou=iou) + else: + masks = self.yolo_predict( + input_image=input_image, + conf=conf) + if shared.cmd_opts.lowvram: + self.unload_model() + gc.collect() + torch_gc() + return masks + -# check device for each model type # how to use yolo for auto -# category name dropdown +# category name dropdown and dynamic ui # yolo model for segmentation and detection -# zoom in and unify box+point +# zoom in, unify box+point +# make masks smaller diff --git a/sam_utils/util.py b/sam_utils/util.py index 967df4e..629218d 100644 --- a/sam_utils/util.py +++ b/sam_utils/util.py @@ -3,7 +3,6 @@ import os import numpy as np import cv2 import copy -from scipy.ndimage import binary_dilation from PIL import Image @@ -51,6 +50,7 @@ def show_masks(image_np: np.ndarray, masks: np.ndarray, alpha=0.5) -> np.ndarray def dilate_mask(mask: np.ndarray, dilation_amt: int) -> Tuple[Image.Image, np.ndarray]: + from scipy.ndimage import binary_dilation x, y = np.meshgrid(np.arange(dilation_amt), np.arange(dilation_amt)) center = dilation_amt // 2 dilation_kernel = ((x - center)**2 + (y - center)**2 <= center**2).astype(np.uint8) @@ -110,3 +110,95 @@ def create_mask_batch_output( if save_image_with_mask: output_blend = Image.fromarray(blended_image) output_blend.save(os.path.join(dest_dir, f"{filename}_{idx}_blend{ext}")) + + +def blend_image_and_seg(image: np.ndarray, seg: np.ndarray, alpha=0.5) -> Image.Image: + image_blend = image * (1 - alpha) + np.array(seg) * alpha + return Image.fromarray(image_blend.astype(np.uint8)) + + +def install_pycocotools(): + # install pycocotools if needed + from sam_utils.logger import logger + try: + import pycocotools.mask as maskUtils + except: + logger.warn("pycocotools not found, will try installing C++ based pycocotools") + try: + from launch import run_pip + run_pip(f"install pycocotools", f"AutoSAM requirement: pycocotools") + import pycocotools.mask as maskUtils + except: + import traceback + traceback.print_exc() + import sys + if sys.platform == "win32": + logger.warn("Unable to install pycocotools, will try installing pycocotools-windows") + try: + run_pip("install pycocotools-windows", "AutoSAM requirement: pycocotools-windows") + import pycocotools.mask as maskUtils + except: + error_msg = "Unable to install pycocotools-windows" + logger.error(error_msg) + traceback.print_exc() + raise RuntimeError(error_msg) + else: + error_msg = "Unable to install pycocotools" + logger.error(error_msg) + traceback.print_exc() + raise RuntimeError(error_msg) + + +def install_goundingdino() -> bool: + """Automatically install GroundingDINO. 
+ + Returns: + bool: False if use local GroundingDINO, True if use pip installed GroundingDINO. + """ + from sam_utils.logger import logger + from modules import shared + dino_install_issue_text = "Please permanently switch to local GroundingDINO on Settings/Segment Anything or submit an issue to https://github.com/IDEA-Research/Grounded-Segment-Anything/issues." + if shared.opts.data.get("sam_use_local_groundingdino", False): + logger.info("Using local groundingdino.") + return False + + def verify_dll(install_local=True): + try: + from groundingdino import _C + logger.info("GroundingDINO dynamic library have been successfully built.") + return True + except Exception: + import traceback + traceback.print_exc() + def run_pip_uninstall(command, desc=None): + from launch import python, run + default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1") + return run(f'"{python}" -m pip uninstall -y {command}', desc=f"Uninstalling {desc}", errdesc=f"Couldn't uninstall {desc}", live=default_command_live) + if install_local: + logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and fall back to local GroundingDINO this time. {dino_install_issue_text}") + run_pip_uninstall( + f"groundingdino", + f"sd-webui-segment-anything requirement: groundingdino") + else: + logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and re-try installing from GitHub source code. {dino_install_issue_text}") + run_pip_uninstall( + f"uninstall groundingdino", + f"sd-webui-segment-anything requirement: groundingdino") + return False + + import launch + if launch.is_installed("groundingdino"): + logger.info("Found GroundingDINO in pip. Verifying if dynamic library build success.") + if verify_dll(install_local=False): + return True + try: + launch.run_pip( + f"install git+https://github.com/IDEA-Research/GroundingDINO", + f"sd-webui-segment-anything requirement: groundingdino") + logger.info("GroundingDINO install success. Verifying if dynamic library build success.") + return verify_dll() + except Exception: + import traceback + traceback.print_exc() + logger.warn(f"GroundingDINO install failed. Will fall back to local groundingdino this time. 
{dino_install_issue_text}") + return False diff --git a/scripts/sam_auto.py b/scripts/sam_auto.py deleted file mode 100644 index e2afe9f..0000000 --- a/scripts/sam_auto.py +++ /dev/null @@ -1,183 +0,0 @@ -from typing import List, Tuple -import os -import glob -import copy -from PIL import Image -import numpy as np -import torch -from sam_hq.automatic import SamAutomaticMaskGeneratorHQ -from scripts.sam_log import logger - - -class AutoSAM: - - def __init__(self) -> None: - self.auto_sam: SamAutomaticMaskGeneratorHQ = None - - - def blend_image_and_seg(self, image: np.ndarray, seg: np.ndarray, alpha=0.5) -> Image.Image: - image_blend = image * (1 - alpha) + np.array(seg) * alpha - return Image.fromarray(image_blend.astype(np.uint8)) - - - def strengthen_semmantic_seg(self, class_ids: np.ndarray, img: np.ndarray) -> np.ndarray: - logger.info("AutoSAM strengthening semantic segmentation") - import pycocotools.mask as maskUtils - semantc_mask = copy.deepcopy(class_ids) - annotations = self.auto_sam.generate(img) - annotations = sorted(annotations, key=lambda x: x['area'], reverse=True) - logger.info(f"AutoSAM generated {len(annotations)} masks") - for ann in annotations: - valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool() - propose_classes_ids = torch.tensor(class_ids[valid_mask]) - num_class_proposals = len(torch.unique(propose_classes_ids)) - if num_class_proposals == 1: - semantc_mask[valid_mask] = propose_classes_ids[0].numpy() - continue - top_1_propose_class_ids = torch.bincount(propose_classes_ids.flatten()).topk(1).indices - semantc_mask[valid_mask] = top_1_propose_class_ids.numpy() - logger.info("AutoSAM strengthening process end") - return semantc_mask - - - def random_segmentation(self, img: Image.Image) -> Tuple[List[Image.Image], str]: - logger.info("AutoSAM generating random segmentation for EditAnything") - img_np = np.array(img.convert("RGB")) - annotations = self.auto_sam.generate(img_np) - logger.info(f"AutoSAM generated {len(annotations)} masks") - H, W, _ = img_np.shape - color_map = np.zeros((H, W, 3), dtype=np.uint8) - detected_map_tmp = np.zeros((H, W), dtype=np.uint16) - for idx, annotation in enumerate(annotations): - current_seg = annotation['segmentation'] - color_map[current_seg] = np.random.randint(0, 255, (3)) - detected_map_tmp[current_seg] = idx + 1 - detected_map = np.zeros((detected_map_tmp.shape[0], detected_map_tmp.shape[1], 3)) - detected_map[:, :, 0] = detected_map_tmp % 256 - detected_map[:, :, 1] = detected_map_tmp // 256 - try: - from scripts.processor import HWC3 - except: - return [], "ControlNet extension not found." - detected_map = HWC3(detected_map.astype(np.uint8)) - logger.info("AutoSAM generation process end") - return [self.blend_image_and_seg(img_np, color_map), Image.fromarray(color_map), Image.fromarray(detected_map)], \ - "Random segmentation done. Left above (0) is blended image, right above (1) is random segmentation, left below (2) is Edit-Anything control input." 
- - - def layer_single_image(self, layout_input_image: Image.Image, layout_output_path: str) -> None: - img_np = np.array(layout_input_image.convert("RGB")) - annotations = self.auto_sam.generate(img_np) - logger.info(f"AutoSAM generated {len(annotations)} annotations") - annotations = sorted(annotations, key=lambda x: x['area']) - for idx, annotation in enumerate(annotations): - img_tmp = np.zeros((img_np.shape[0], img_np.shape[1], 3)) - img_tmp[annotation['segmentation']] = img_np[annotation['segmentation']] - img_np[annotation['segmentation']] = np.array([0, 0, 0]) - img_tmp = Image.fromarray(img_tmp.astype(np.uint8)) - img_tmp.save(os.path.join(layout_output_path, f"{idx}.png")) - img_np = Image.fromarray(img_np.astype(np.uint8)) - img_np.save(os.path.join(layout_output_path, "leftover.png")) - - - def image_layer(self, layout_input_image_or_path, layout_output_path: str) -> str: - if isinstance(layout_input_image_or_path, str): - logger.info("Image layer division batch processing") - all_files = glob.glob(os.path.join(layout_input_image_or_path, "*")) - for image_index, input_image_file in enumerate(all_files): - logger.info(f"Processing {image_index}/{len(all_files)} {input_image_file}") - try: - input_image = Image.open(input_image_file) - output_directory = os.path.join(layout_output_path, os.path.splitext(os.path.basename(input_image_file))[0]) - from pathlib import Path - Path(output_directory).mkdir(exist_ok=True) - except: - logger.warn(f"File {input_image_file} not image, skipped.") - continue - self.layer_single_image(input_image, output_directory) - else: - self.layer_single_image(layout_input_image_or_path, layout_output_path) - return "Done" - - - def semantic_segmentation(self, input_image: Image.Image, annotator_name: str, processor_res: int, - use_pixel_perfect: bool, resize_mode: int, target_W: int, target_H: int) -> Tuple[List[Image.Image], str]: - if input_image is None: - return [], "No input image." - if "seg" in annotator_name: - try: - from scripts.processor import uniformer, oneformer_coco, oneformer_ade20k - from scripts.external_code import pixel_perfect_resolution - oneformers = { - "ade20k": oneformer_ade20k, - "coco": oneformer_coco - } - except: - return [], "ControlNet extension not found." - input_image_np = np.array(input_image) - if use_pixel_perfect: - processor_res = pixel_perfect_resolution(input_image_np, resize_mode, target_W, target_H) - logger.info("Generating semantic segmentation without SAM") - if annotator_name == "seg_ufade20k": - original_semantic = uniformer(input_image_np, processor_res) - else: - dataset = annotator_name.split('_')[-1][2:] - original_semantic = oneformers[dataset](input_image_np, processor_res) - logger.info("Generating semantic segmentation with SAM") - sam_semantic = self.strengthen_semmantic_seg(np.array(original_semantic), input_image_np) - output_gallery = [original_semantic, sam_semantic, self.blend_image_and_seg(input_image, original_semantic), self.blend_image_and_seg(input_image, sam_semantic)] - return output_gallery, f"Done. Left is segmentation before SAM, right is segmentation after SAM." - else: - return self.random_segmentation(input_image) - - - def categorical_mask_image(self, crop_processor: str, crop_processor_res: int, crop_category_input: List[int], crop_input_image: Image.Image, - crop_pixel_perfect: bool, crop_resize_mode: int, target_W: int, target_H: int) -> Tuple[np.ndarray, Image.Image]: - if crop_input_image is None: - return "No input image." 
- try: - from scripts.processor import uniformer, oneformer_coco, oneformer_ade20k - from scripts.external_code import pixel_perfect_resolution - oneformers = { - "ade20k": oneformer_ade20k, - "coco": oneformer_coco - } - except: - return [], "ControlNet extension not found." - filter_classes = crop_category_input - if len(filter_classes) == 0: - return "No class selected." - try: - filter_classes = [int(i) for i in filter_classes] - except: - return "Illegal class id. You may have input some string." - crop_input_image_np = np.array(crop_input_image) - if crop_pixel_perfect: - crop_processor_res = pixel_perfect_resolution(crop_input_image_np, crop_resize_mode, target_W, target_H) - crop_input_image_copy = copy.deepcopy(crop_input_image) - logger.info(f"Generating categories with processor {crop_processor}") - if crop_processor == "seg_ufade20k": - original_semantic = uniformer(crop_input_image_np, crop_processor_res) - else: - dataset = crop_processor.split('_')[-1][2:] - original_semantic = oneformers[dataset](crop_input_image_np, crop_processor_res) - sam_semantic = self.strengthen_semmantic_seg(np.array(original_semantic), crop_input_image_np) - mask = np.zeros(sam_semantic.shape, dtype=np.bool_) - from scripts.sam_config import SEMANTIC_CATEGORIES - for i in filter_classes: - mask[np.equal(sam_semantic, SEMANTIC_CATEGORIES[crop_processor][i])] = True - return mask, crop_input_image_copy - - - def register_auto_sam(self, sam, - auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh, - auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh, - auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio, - auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, auto_sam_output_mode): - self.auto_sam = SamAutomaticMaskGeneratorHQ( - sam, auto_sam_points_per_side, auto_sam_points_per_batch, - auto_sam_pred_iou_thresh, auto_sam_stability_score_thresh, - auto_sam_stability_score_offset, auto_sam_box_nms_thresh, - auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio, - auto_sam_crop_n_points_downscale_factor, None, - auto_sam_min_mask_region_area, auto_sam_output_mode) diff --git a/thirdparty/fastsam/__init__.py b/thirdparty/fastsam/__init__.py new file mode 100644 index 0000000..7f99d8d --- /dev/null +++ b/thirdparty/fastsam/__init__.py @@ -0,0 +1,9 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from .model import FastSAM +from .predict import FastSAMPredictor +from .prompt import FastSAMPrompt +# from .val import FastSAMValidator +from .decoder import FastSAMDecoder + +__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder' diff --git a/thirdparty/fastsam/decoder.py b/thirdparty/fastsam/decoder.py new file mode 100644 index 0000000..5b92ed8 --- /dev/null +++ b/thirdparty/fastsam/decoder.py @@ -0,0 +1,131 @@ +from .model import FastSAM +import numpy as np +from PIL import Image +import clip +from typing import Optional, List, Tuple, Union + + +class FastSAMDecoder: + def __init__( + self, + model: FastSAM, + device: str='cpu', + conf: float=0.4, + iou: float=0.9, + imgsz: int=1024, + retina_masks: bool=True, + ): + self.model = model + self.device = device + self.retina_masks = retina_masks + self.imgsz = imgsz + self.conf = conf + self.iou = iou + self.image = None + self.image_embedding = None + + def run_encoder(self, image): + if isinstance(image,str): + image = np.array(Image.open(image)) + self.image = image + image_embedding = self.model( + self.image, 
+ device=self.device, + retina_masks=self.retina_masks, + imgsz=self.imgsz, + conf=self.conf, + iou=self.iou + ) + return image_embedding[0].numpy() + + def run_decoder( + self, + image_embedding, + point_prompt: Optional[np.ndarray]=None, + point_label: Optional[np.ndarray]=None, + box_prompt: Optional[np.ndarray]=None, + text_prompt: Optional[str]=None, + )->np.ndarray: + self.image_embedding = image_embedding + if point_prompt is not None: + ann = self.point_prompt(points=point_prompt, pointlabel=point_label) + return ann + elif box_prompt is not None: + ann = self.box_prompt(bbox=box_prompt) + return ann + elif text_prompt is not None: + ann = self.text_prompt(text=text_prompt) + return ann + else: + return None + + def box_prompt(self, bbox): + assert (bbox[2] != 0 and bbox[3] != 0) + masks = self.image_embedding.masks.data + target_height = self.ori_img.shape[0] + target_width = self.ori_img.shape[1] + h = masks.shape[1] + w = masks.shape[2] + if h != target_height or w != target_width: + bbox = [ + int(bbox[0] * w / target_width), + int(bbox[1] * h / target_height), + int(bbox[2] * w / target_width), + int(bbox[3] * h / target_height), ] + bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 + bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 + bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w + bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h + + # IoUs = torch.zeros(len(masks), dtype=torch.float32) + bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) + + masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2)) + orig_masks_area = np.sum(masks, axis=(1, 2)) + + union = bbox_area + orig_masks_area - masks_area + IoUs = masks_area / union + max_iou_index = np.argmax(IoUs) + + return np.array([masks[max_iou_index].cpu().numpy()]) + + def point_prompt(self, points, pointlabel): # numpy + + masks = self._format_results(self.results[0], 0) + target_height = self.ori_img.shape[0] + target_width = self.ori_img.shape[1] + h = masks[0]['segmentation'].shape[0] + w = masks[0]['segmentation'].shape[1] + if h != target_height or w != target_width: + points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] + onemask = np.zeros((h, w)) + masks = sorted(masks, key=lambda x: x['area'], reverse=True) + for i, annotation in enumerate(masks): + if type(annotation) == dict: + mask = annotation['segmentation'] + else: + mask = annotation + for i, point in enumerate(points): + if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: + onemask[mask] = 1 + if mask[point[1], point[0]] == 1 and pointlabel[i] == 0: + onemask[mask] = 0 + onemask = onemask >= 1 + return np.array([onemask]) + + def _format_results(self, result, filter=0): + annotations = [] + n = len(result.masks.data) + for i in range(n): + annotation = {} + mask = result.masks.data[i] == 1.0 + + if np.sum(mask) < filter: + continue + annotation['id'] = i + annotation['segmentation'] = mask + annotation['bbox'] = result.boxes.data[i] + annotation['score'] = result.boxes.conf[i] + annotation['area'] = annotation['segmentation'].sum() + annotations.append(annotation) + return annotations \ No newline at end of file diff --git a/thirdparty/fastsam/model.py b/thirdparty/fastsam/model.py new file mode 100644 index 0000000..5450cb3 --- /dev/null +++ b/thirdparty/fastsam/model.py @@ -0,0 +1,104 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +""" +YOLO-NAS model interface. 
+ +Usage - Predict: + from ultralytics import FastSAM + + model = FastSAM('last.pt') + results = model.predict('ultralytics/assets/bus.jpg') +""" + +from ultralytics.yolo.cfg import get_cfg +from ultralytics.yolo.engine.exporter import Exporter +from ultralytics.yolo.engine.model import YOLO +from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir +from ultralytics.yolo.utils.checks import check_imgsz + +from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode +from .predict import FastSAMPredictor + + +class FastSAM(YOLO): + + @smart_inference_mode() + def predict(self, source=None, stream=False, **kwargs): + """ + Perform prediction using the YOLO model. + + Args: + source (str | int | PIL | np.ndarray): The source of the image to make predictions on. + Accepts all source types accepted by the YOLO model. + stream (bool): Whether to stream the predictions or not. Defaults to False. + **kwargs : Additional keyword arguments passed to the predictor. + Check the 'configuration' section in the documentation for all available options. + + Returns: + (List[ultralytics.yolo.engine.results.Results]): The prediction results. + """ + if source is None: + source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' + LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") + overrides = self.overrides.copy() + overrides['conf'] = 0.25 + overrides.update(kwargs) # prefer kwargs + overrides['mode'] = kwargs.get('mode', 'predict') + assert overrides['mode'] in ['track', 'predict'] + overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python + self.predictor = FastSAMPredictor(overrides=overrides) + self.predictor.setup_model(model=self.model, verbose=False) + + return self.predictor(source, stream=stream) + + def train(self, **kwargs): + """Function trains models but raises an error as FastSAM models do not support training.""" + raise NotImplementedError("FastSAM models don't support training") + + def val(self, **kwargs): + """Run validation given dataset.""" + overrides = dict(task='segment', mode='val') + overrides.update(kwargs) # prefer kwargs + args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) + args.imgsz = check_imgsz(args.imgsz, max_dim=1) + validator = FastSAM(args=args) + validator(model=self.model) + self.metrics = validator.metrics + return validator.metrics + + @smart_inference_mode() + def export(self, **kwargs): + """ + Export model. + + Args: + **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs + """ + overrides = dict(task='detect') + overrides.update(kwargs) + overrides['mode'] = 'export' + args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) + args.task = self.task + if args.imgsz == DEFAULT_CFG.imgsz: + args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed + if args.batch == DEFAULT_CFG.batch: + args.batch = 1 # default to 1 if not modified + return Exporter(overrides=args)(model=self.model) + + def info(self, detailed=False, verbose=True): + """ + Logs model info. + + Args: + detailed (bool): Show detailed information about model. + verbose (bool): Controls verbosity. 
+ """ + return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) + + def __call__(self, source=None, stream=False, **kwargs): + """Calls the 'predict' function with given arguments to perform object detection.""" + return self.predict(source, stream, **kwargs) + + def __getattr__(self, attr): + """Raises error if object has no requested attribute.""" + name = self.__class__.__name__ + raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") diff --git a/thirdparty/fastsam/predict.py b/thirdparty/fastsam/predict.py new file mode 100644 index 0000000..5e3ca40 --- /dev/null +++ b/thirdparty/fastsam/predict.py @@ -0,0 +1,52 @@ +import torch + +from ultralytics.yolo.engine.results import Results +from ultralytics.yolo.utils import DEFAULT_CFG, ops +from ultralytics.yolo.v8.detect.predict import DetectionPredictor +from .utils import bbox_iou + +class FastSAMPredictor(DetectionPredictor): + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + super().__init__(cfg, overrides, _callbacks) + self.args.task = 'segment' + + def postprocess(self, preds, img, orig_imgs): + """TODO: filter by classes.""" + p = ops.non_max_suppression(preds[0], + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + nc=len(self.model.names), + classes=self.args.classes) + + full_box = torch.zeros_like(p[0][0]) + full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 + full_box = full_box.view(1, -1) + critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:]) + if critical_iou_index.numel() != 0: + full_box[0][4] = p[0][critical_iou_index][:,4] + full_box[0][6:] = p[0][critical_iou_index][:,6:] + p[0][critical_iou_index] = full_box + + results = [] + proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported + for i, pred in enumerate(p): + orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs + path = self.batch[0] + img_path = path[i] if isinstance(path, list) else path + if not len(pred): # save empty boxes + results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) + continue + if self.args.retina_masks: + if not isinstance(orig_imgs, torch.Tensor): + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC + else: + masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC + if not isinstance(orig_imgs, torch.Tensor): + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + results.append( + Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) + return results diff --git a/thirdparty/fastsam/prompt.py b/thirdparty/fastsam/prompt.py new file mode 100644 index 0000000..e67143a --- /dev/null +++ b/thirdparty/fastsam/prompt.py @@ -0,0 +1,417 @@ +import os +import sys +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +from PIL import Image + + + +try: + import clip # for linear_assignment + +except (ImportError, AssertionError, AttributeError): + from ultralytics.yolo.utils.checks import check_requirements + + check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source + import clip + + +class 
FastSAMPrompt: + + def __init__(self, img_path, results, device='cuda') -> None: + # self.img_path = img_path + self.device = device + self.results = results + self.img_path = img_path + self.ori_img = cv2.imread(img_path) + + def _segment_image(self, image, bbox): + image_array = np.array(image) + segmented_image_array = np.zeros_like(image_array) + x1, y1, x2, y2 = bbox + segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2] + segmented_image = Image.fromarray(segmented_image_array) + black_image = Image.new('RGB', image.size, (255, 255, 255)) + # transparency_mask = np.zeros_like((), dtype=np.uint8) + transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8) + transparency_mask[y1:y2, x1:x2] = 255 + transparency_mask_image = Image.fromarray(transparency_mask, mode='L') + black_image.paste(segmented_image, mask=transparency_mask_image) + return black_image + + def _format_results(self, result, filter=0): + annotations = [] + n = len(result.masks.data) + for i in range(n): + annotation = {} + mask = result.masks.data[i] == 1.0 + + if torch.sum(mask) < filter: + continue + annotation['id'] = i + annotation['segmentation'] = mask.cpu().numpy() + annotation['bbox'] = result.boxes.data[i] + annotation['score'] = result.boxes.conf[i] + annotation['area'] = annotation['segmentation'].sum() + annotations.append(annotation) + return annotations + + def filter_masks(annotations): # filte the overlap mask + annotations.sort(key=lambda x: x['area'], reverse=True) + to_remove = set() + for i in range(0, len(annotations)): + a = annotations[i] + for j in range(i + 1, len(annotations)): + b = annotations[j] + if i != j and j not in to_remove: + # check if + if b['area'] < a['area']: + if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8: + to_remove.add(j) + + return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove + + def _get_bbox_from_mask(self, mask): + mask = mask.astype(np.uint8) + contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + x1, y1, w, h = cv2.boundingRect(contours[0]) + x2, y2 = x1 + w, y1 + h + if len(contours) > 1: + for b in contours: + x_t, y_t, w_t, h_t = cv2.boundingRect(b) + # Merge multiple bounding boxes into one. + x1 = min(x1, x_t) + y1 = min(y1, y_t) + x2 = max(x2, x_t + w_t) + y2 = max(y2, y_t + h_t) + h = y2 - y1 + w = x2 - x1 + return [x1, y1, x2, y2] + + def plot(self, + annotations, + output, + bboxes=None, + points=None, + point_label=None, + mask_random_color=True, + better_quality=True, + retina=False, + withContours=True): + if isinstance(annotations[0], dict): + annotations = [annotation['segmentation'] for annotation in annotations] + result_name = os.path.basename(self.img_path) + image = self.ori_img + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + original_h = image.shape[0] + original_w = image.shape[1] + if sys.platform == "darwin": + plt.switch_backend("TkAgg") + plt.figure(figsize=(original_w / 100, original_h / 100)) + # Add subplot with no margin. 
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) + plt.margins(0, 0) + plt.gca().xaxis.set_major_locator(plt.NullLocator()) + plt.gca().yaxis.set_major_locator(plt.NullLocator()) + + plt.imshow(image) + if better_quality: + if isinstance(annotations[0], torch.Tensor): + annotations = np.array(annotations.cpu()) + for i, mask in enumerate(annotations): + mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)) + annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)) + if self.device == 'cpu': + annotations = np.array(annotations) + self.fast_show_mask( + annotations, + plt.gca(), + random_color=mask_random_color, + bboxes=bboxes, + points=points, + pointlabel=point_label, + retinamask=retina, + target_height=original_h, + target_width=original_w, + ) + else: + if isinstance(annotations[0], np.ndarray): + annotations = torch.from_numpy(annotations) + self.fast_show_mask_gpu( + annotations, + plt.gca(), + random_color=mask_random_color, + bboxes=bboxes, + points=points, + pointlabel=point_label, + retinamask=retina, + target_height=original_h, + target_width=original_w, + ) + if isinstance(annotations, torch.Tensor): + annotations = annotations.cpu().numpy() + if withContours: + contour_all = [] + temp = np.zeros((original_h, original_w, 1)) + for i, mask in enumerate(annotations): + if type(mask) == dict: + mask = mask['segmentation'] + annotation = mask.astype(np.uint8) + if not retina: + annotation = cv2.resize( + annotation, + (original_w, original_h), + interpolation=cv2.INTER_NEAREST, + ) + contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + contour_all.append(contour) + cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2) + color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8]) + contour_mask = temp / 255 * color.reshape(1, 1, -1) + plt.imshow(contour_mask) + + save_path = output + if not os.path.exists(save_path): + os.makedirs(save_path) + plt.axis('off') + fig = plt.gcf() + plt.draw() + + try: + buf = fig.canvas.tostring_rgb() + except AttributeError: + fig.canvas.draw() + buf = fig.canvas.tostring_rgb() + cols, rows = fig.canvas.get_width_height() + img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3) + cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)) + + # CPU post process + def fast_show_mask( + self, + annotation, + ax, + random_color=False, + bboxes=None, + points=None, + pointlabel=None, + retinamask=True, + target_height=960, + target_width=960, + ): + msak_sum = annotation.shape[0] + height = annotation.shape[1] + weight = annotation.shape[2] + #Sort annotations based on area. + areas = np.sum(annotation, axis=(1, 2)) + sorted_indices = np.argsort(areas) + annotation = annotation[sorted_indices] + + index = (annotation != 0).argmax(axis=0) + if random_color: + color = np.random.random((msak_sum, 1, 1, 3)) + else: + color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255]) + transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6 + visual = np.concatenate([color, transparency], axis=-1) + mask_image = np.expand_dims(annotation, -1) * visual + + show = np.zeros((height, weight, 4)) + h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij') + indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) + # Use vectorized indexing to update the values of 'show'. 
+ show[h_indices, w_indices, :] = mask_image[indices] + if bboxes is not None: + for bbox in bboxes: + x1, y1, x2, y2 = bbox + ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) + # draw point + if points is not None: + plt.scatter( + [point[0] for i, point in enumerate(points) if pointlabel[i] == 1], + [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], + s=20, + c='y', + ) + plt.scatter( + [point[0] for i, point in enumerate(points) if pointlabel[i] == 0], + [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], + s=20, + c='m', + ) + + if not retinamask: + show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST) + ax.imshow(show) + + def fast_show_mask_gpu( + self, + annotation, + ax, + random_color=False, + bboxes=None, + points=None, + pointlabel=None, + retinamask=True, + target_height=960, + target_width=960, + ): + msak_sum = annotation.shape[0] + height = annotation.shape[1] + weight = annotation.shape[2] + areas = torch.sum(annotation, dim=(1, 2)) + sorted_indices = torch.argsort(areas, descending=False) + annotation = annotation[sorted_indices] + # Find the index of the first non-zero value at each position. + index = (annotation != 0).to(torch.long).argmax(dim=0) + if random_color: + color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device) + else: + color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([ + 30 / 255, 144 / 255, 255 / 255]).to(annotation.device) + transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6 + visual = torch.cat([color, transparency], dim=-1) + mask_image = torch.unsqueeze(annotation, -1) * visual + # Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form. + show = torch.zeros((height, weight, 4)).to(annotation.device) + h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij') + indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) + # Use vectorized indexing to update the values of 'show'. 
+ show[h_indices, w_indices, :] = mask_image[indices] + show_cpu = show.cpu().numpy() + if bboxes is not None: + for bbox in bboxes: + x1, y1, x2, y2 = bbox + ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) + # draw point + if points is not None: + plt.scatter( + [point[0] for i, point in enumerate(points) if pointlabel[i] == 1], + [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], + s=20, + c='y', + ) + plt.scatter( + [point[0] for i, point in enumerate(points) if pointlabel[i] == 0], + [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], + s=20, + c='m', + ) + if not retinamask: + show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST) + ax.imshow(show_cpu) + + # clip + @torch.no_grad() + def retrieve(self, model, preprocess, elements, search_text: str, device) -> int: + preprocessed_images = [preprocess(image).to(device) for image in elements] + tokenized_text = clip.tokenize([search_text]).to(device) + stacked_images = torch.stack(preprocessed_images) + image_features = model.encode_image(stacked_images) + text_features = model.encode_text(tokenized_text) + image_features /= image_features.norm(dim=-1, keepdim=True) + text_features /= text_features.norm(dim=-1, keepdim=True) + probs = 100.0 * image_features @ text_features.T + return probs[:, 0].softmax(dim=0) + + def _crop_image(self, format_results): + + image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB)) + ori_w, ori_h = image.size + annotations = format_results + mask_h, mask_w = annotations[0]['segmentation'].shape + if ori_w != mask_w or ori_h != mask_h: + image = image.resize((mask_w, mask_h)) + cropped_boxes = [] + cropped_images = [] + not_crop = [] + filter_id = [] + # annotations, _ = filter_masks(annotations) + # filter_id = list(_) + for _, mask in enumerate(annotations): + if np.sum(mask['segmentation']) <= 100: + filter_id.append(_) + continue + bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox + cropped_boxes.append(self._segment_image(image, bbox)) + # cropped_boxes.append(segment_image(image,mask["segmentation"])) + cropped_images.append(bbox) # Save the bounding box of the cropped image. 
+ + return cropped_boxes, cropped_images, not_crop, filter_id, annotations + + def box_prompt(self, bbox=None, bboxes=None): + + assert bbox or bboxes + if bboxes is None: + bboxes = [bbox] + max_iou_index = [] + for bbox in bboxes: + assert (bbox[2] != 0 and bbox[3] != 0) + masks = self.results[0].masks.data + target_height = self.ori_img.shape[0] + target_width = self.ori_img.shape[1] + h = masks.shape[1] + w = masks.shape[2] + if h != target_height or w != target_width: + bbox = [ + int(bbox[0] * w / target_width), + int(bbox[1] * h / target_height), + int(bbox[2] * w / target_width), + int(bbox[3] * h / target_height), ] + bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 + bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 + bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w + bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h + + # IoUs = torch.zeros(len(masks), dtype=torch.float32) + bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) + + masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2)) + orig_masks_area = torch.sum(masks, dim=(1, 2)) + + union = bbox_area + orig_masks_area - masks_area + IoUs = masks_area / union + max_iou_index.append(int(torch.argmax(IoUs))) + max_iou_index = list(set(max_iou_index)) + return np.array(masks[max_iou_index].cpu().numpy()) + + def point_prompt(self, points, pointlabel): # numpy + + masks = self._format_results(self.results[0], 0) + target_height = self.ori_img.shape[0] + target_width = self.ori_img.shape[1] + h = masks[0]['segmentation'].shape[0] + w = masks[0]['segmentation'].shape[1] + if h != target_height or w != target_width: + points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] + onemask = np.zeros((h, w)) + masks = sorted(masks, key=lambda x: x['area'], reverse=True) + for i, annotation in enumerate(masks): + if type(annotation) == dict: + mask = annotation['segmentation'] + else: + mask = annotation + for i, point in enumerate(points): + if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: + onemask[mask] = 1 + if mask[point[1], point[0]] == 1 and pointlabel[i] == 0: + onemask[mask] = 0 + onemask = onemask >= 1 + return np.array([onemask]) + + def text_prompt(self, text): + format_results = self._format_results(self.results[0], 0) + cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results) + clip_model, preprocess = clip.load('ViT-B/32', device=self.device) + scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device) + max_idx = scores.argsort() + max_idx = max_idx[-1] + max_idx += sum(np.array(filter_id) <= int(max_idx)) + return np.array([annotations[max_idx]['segmentation']]) + + def everything_prompt(self): + return self.results[0].masks.data + \ No newline at end of file diff --git a/thirdparty/fastsam/utils.py b/thirdparty/fastsam/utils.py new file mode 100644 index 0000000..d828c2c --- /dev/null +++ b/thirdparty/fastsam/utils.py @@ -0,0 +1,66 @@ +import torch + +def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20): + '''Adjust bounding boxes to stick to image border if they are within a certain threshold. 
+ Args: + boxes: (n, 4) + image_shape: (height, width) + threshold: pixel threshold + Returns: + adjusted_boxes: adjusted bounding boxes + ''' + + # Image dimensions + h, w = image_shape + + # Adjust boxes + boxes[:, 0] = torch.where(boxes[:, 0] < threshold, 0, boxes[:, 0]) # x1 + boxes[:, 1] = torch.where(boxes[:, 1] < threshold, 0, boxes[:, 1]) # y1 + boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, w, boxes[:, 2]) # x2 + boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, h, boxes[:, 3]) # y2 + + return boxes + +def convert_box_xywh_to_xyxy(box): + x1 = box[0] + y1 = box[1] + x2 = box[0] + box[2] + y2 = box[1] + box[3] + return [x1, y1, x2, y2] + +def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False): + '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes. + Args: + box1: (4, ) + boxes: (n, 4) + Returns: + high_iou_indices: Indices of boxes with IoU > thres + ''' + boxes = adjust_bboxes_to_image_border(boxes, image_shape) + # obtain coordinates for intersections + x1 = torch.max(box1[0], boxes[:, 0]) + y1 = torch.max(box1[1], boxes[:, 1]) + x2 = torch.min(box1[2], boxes[:, 2]) + y2 = torch.min(box1[3], boxes[:, 3]) + + # compute the area of intersection + intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0) + + # compute the area of both individual boxes + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + # compute the area of union + union = box1_area + box2_area - intersection + + # compute the IoU + iou = intersection / union # Should be shape (n, ) + if raw_output: + if iou.numel() == 0: + return 0 + return iou + + # get indices of boxes with IoU > thres + high_iou_indices = torch.nonzero(iou > iou_thres).flatten() + + return high_iou_indices \ No newline at end of file diff --git a/local_groundingdino/datasets/__init__.py b/thirdparty/groundingdino/datasets/__init__.py similarity index 100% rename from local_groundingdino/datasets/__init__.py rename to thirdparty/groundingdino/datasets/__init__.py diff --git a/local_groundingdino/datasets/transforms.py b/thirdparty/groundingdino/datasets/transforms.py similarity index 100% rename from local_groundingdino/datasets/transforms.py rename to thirdparty/groundingdino/datasets/transforms.py diff --git a/local_groundingdino/models/GroundingDINO/__init__.py b/thirdparty/groundingdino/models/GroundingDINO/__init__.py similarity index 100% rename from local_groundingdino/models/GroundingDINO/__init__.py rename to thirdparty/groundingdino/models/GroundingDINO/__init__.py diff --git a/local_groundingdino/models/GroundingDINO/backbone/__init__.py b/thirdparty/groundingdino/models/GroundingDINO/backbone/__init__.py similarity index 100% rename from local_groundingdino/models/GroundingDINO/backbone/__init__.py rename to thirdparty/groundingdino/models/GroundingDINO/backbone/__init__.py diff --git a/local_groundingdino/models/GroundingDINO/backbone/backbone.py b/thirdparty/groundingdino/models/GroundingDINO/backbone/backbone.py similarity index 100% rename from local_groundingdino/models/GroundingDINO/backbone/backbone.py rename to thirdparty/groundingdino/models/GroundingDINO/backbone/backbone.py diff --git a/local_groundingdino/models/GroundingDINO/backbone/position_encoding.py b/thirdparty/groundingdino/models/GroundingDINO/backbone/position_encoding.py similarity index 100% rename from local_groundingdino/models/GroundingDINO/backbone/position_encoding.py rename to 
thirdparty/groundingdino/models/GroundingDINO/backbone/position_encoding.py diff --git a/local_groundingdino/models/GroundingDINO/backbone/swin_transformer.py b/thirdparty/groundingdino/models/GroundingDINO/backbone/swin_transformer.py similarity index 100% rename from local_groundingdino/models/GroundingDINO/backbone/swin_transformer.py rename to thirdparty/groundingdino/models/GroundingDINO/backbone/swin_transformer.py diff --git a/local_groundingdino/models/GroundingDINO/bertwarper.py b/thirdparty/groundingdino/models/GroundingDINO/bertwarper.py similarity index 100% rename from local_groundingdino/models/GroundingDINO/bertwarper.py rename to thirdparty/groundingdino/models/GroundingDINO/bertwarper.py diff --git a/local_groundingdino/models/GroundingDINO/fuse_modules.py b/thirdparty/groundingdino/models/GroundingDINO/fuse_modules.py similarity index 100% rename from local_groundingdino/models/GroundingDINO/fuse_modules.py rename to thirdparty/groundingdino/models/GroundingDINO/fuse_modules.py diff --git a/local_groundingdino/models/GroundingDINO/groundingdino.py b/thirdparty/groundingdino/models/GroundingDINO/groundingdino.py similarity index 100% rename from local_groundingdino/models/GroundingDINO/groundingdino.py rename to thirdparty/groundingdino/models/GroundingDINO/groundingdino.py diff --git a/local_groundingdino/models/GroundingDINO/ms_deform_attn.py b/thirdparty/groundingdino/models/GroundingDINO/ms_deform_attn.py similarity index 100% rename from local_groundingdino/models/GroundingDINO/ms_deform_attn.py rename to thirdparty/groundingdino/models/GroundingDINO/ms_deform_attn.py diff --git a/local_groundingdino/models/GroundingDINO/transformer.py b/thirdparty/groundingdino/models/GroundingDINO/transformer.py similarity index 100% rename from local_groundingdino/models/GroundingDINO/transformer.py rename to thirdparty/groundingdino/models/GroundingDINO/transformer.py diff --git a/local_groundingdino/models/GroundingDINO/transformer_vanilla.py b/thirdparty/groundingdino/models/GroundingDINO/transformer_vanilla.py similarity index 100% rename from local_groundingdino/models/GroundingDINO/transformer_vanilla.py rename to thirdparty/groundingdino/models/GroundingDINO/transformer_vanilla.py diff --git a/local_groundingdino/models/GroundingDINO/utils.py b/thirdparty/groundingdino/models/GroundingDINO/utils.py similarity index 100% rename from local_groundingdino/models/GroundingDINO/utils.py rename to thirdparty/groundingdino/models/GroundingDINO/utils.py diff --git a/local_groundingdino/models/__init__.py b/thirdparty/groundingdino/models/__init__.py similarity index 100% rename from local_groundingdino/models/__init__.py rename to thirdparty/groundingdino/models/__init__.py diff --git a/local_groundingdino/models/registry.py b/thirdparty/groundingdino/models/registry.py similarity index 100% rename from local_groundingdino/models/registry.py rename to thirdparty/groundingdino/models/registry.py diff --git a/local_groundingdino/util/__init__.py b/thirdparty/groundingdino/util/__init__.py similarity index 100% rename from local_groundingdino/util/__init__.py rename to thirdparty/groundingdino/util/__init__.py diff --git a/local_groundingdino/util/box_ops.py b/thirdparty/groundingdino/util/box_ops.py similarity index 100% rename from local_groundingdino/util/box_ops.py rename to thirdparty/groundingdino/util/box_ops.py diff --git a/local_groundingdino/util/get_tokenlizer.py b/thirdparty/groundingdino/util/get_tokenlizer.py similarity index 100% rename from 
local_groundingdino/util/get_tokenlizer.py rename to thirdparty/groundingdino/util/get_tokenlizer.py diff --git a/local_groundingdino/util/inference.py b/thirdparty/groundingdino/util/inference.py similarity index 100% rename from local_groundingdino/util/inference.py rename to thirdparty/groundingdino/util/inference.py diff --git a/local_groundingdino/util/misc.py b/thirdparty/groundingdino/util/misc.py similarity index 100% rename from local_groundingdino/util/misc.py rename to thirdparty/groundingdino/util/misc.py diff --git a/local_groundingdino/util/slconfig.py b/thirdparty/groundingdino/util/slconfig.py similarity index 100% rename from local_groundingdino/util/slconfig.py rename to thirdparty/groundingdino/util/slconfig.py diff --git a/local_groundingdino/util/slio.py b/thirdparty/groundingdino/util/slio.py similarity index 100% rename from local_groundingdino/util/slio.py rename to thirdparty/groundingdino/util/slio.py diff --git a/local_groundingdino/util/utils.py b/thirdparty/groundingdino/util/utils.py similarity index 100% rename from local_groundingdino/util/utils.py rename to thirdparty/groundingdino/util/utils.py diff --git a/mam/__init__.py b/thirdparty/mam/__init__.py similarity index 100% rename from mam/__init__.py rename to thirdparty/mam/__init__.py diff --git a/mam/m2m.py b/thirdparty/mam/m2m.py similarity index 90% rename from mam/m2m.py rename to thirdparty/mam/m2m.py index 670907c..d3ad32a 100644 --- a/mam/m2m.py +++ b/thirdparty/mam/m2m.py @@ -31,7 +31,12 @@ class SamM2M(Module): except: mam_url = "https://huggingface.co/conrevo/SAM4WebUI-Extension-Models/resolve/main/mam.pth" logger.info(f"Loading mam from url: {mam_url} to path: {ckpt_path}, device: {self.m2m_device}") - state_dict = torch.hub.load_state_dict_from_url(mam_url, ckpt_path, self.m2m_device) + try: + state_dict = torch.hub.load_state_dict_from_url(mam_url, ckpt_path, self.m2m_device) + except: + error_msg = f"Unable to connect to {mam_url}, thus unable to download Matting-Anything model." 
+ logger.error(error_msg) + raise RuntimeError(error_msg) self.m2m.load_state_dict(state_dict) self.m2m.eval() @@ -39,6 +44,7 @@ class SamM2M(Module): def forward(self, features: torch.Tensor, image: torch.Tensor, low_res_masks: torch.Tensor, masks: torch.Tensor, ori_shape: torch.Tensor, pad_shape: torch.Tensor, guidance_mode: str): + logger.info("Applying Matting-Anything.") self.m2m.to(self.m2m_device) pred = self.m2m(features, image, low_res_masks) alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred['alpha_os1'], pred['alpha_os4'], pred['alpha_os8'] diff --git a/mam/m2ms/__init__.py b/thirdparty/mam/m2ms/__init__.py similarity index 100% rename from mam/m2ms/__init__.py rename to thirdparty/mam/m2ms/__init__.py diff --git a/mam/m2ms/conv_sam.py b/thirdparty/mam/m2ms/conv_sam.py similarity index 100% rename from mam/m2ms/conv_sam.py rename to thirdparty/mam/m2ms/conv_sam.py diff --git a/mam/ops.py b/thirdparty/mam/ops.py similarity index 100% rename from mam/ops.py rename to thirdparty/mam/ops.py diff --git a/mam/utils.py b/thirdparty/mam/utils.py similarity index 100% rename from mam/utils.py rename to thirdparty/mam/utils.py diff --git a/sam_hq/automatic.py b/thirdparty/sam_hq/automatic.py similarity index 100% rename from sam_hq/automatic.py rename to thirdparty/sam_hq/automatic.py diff --git a/sam_hq/build_sam_hq.py b/thirdparty/sam_hq/build_sam_hq.py similarity index 100% rename from sam_hq/build_sam_hq.py rename to thirdparty/sam_hq/build_sam_hq.py diff --git a/sam_hq/modeling/image_encoder.py b/thirdparty/sam_hq/modeling/image_encoder.py similarity index 100% rename from sam_hq/modeling/image_encoder.py rename to thirdparty/sam_hq/modeling/image_encoder.py diff --git a/sam_hq/modeling/mask_decoder_hq.py b/thirdparty/sam_hq/modeling/mask_decoder_hq.py similarity index 100% rename from sam_hq/modeling/mask_decoder_hq.py rename to thirdparty/sam_hq/modeling/mask_decoder_hq.py diff --git a/sam_hq/predictor.py b/thirdparty/sam_hq/predictor.py similarity index 100% rename from sam_hq/predictor.py rename to thirdparty/sam_hq/predictor.py
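Usage note on the vendored FastSAM package added above: `AutoSAM.auto_generate` now calls the FastSAM model directly with `retina_masks=True, imgsz=1024` and the configured `conf`/`iou` thresholds, and `thirdparty/fastsam/prompt.py` layers box/point/text selection on top of the raw results. The sketch below shows that call path in isolation; it is only an illustration, not part of this diff — the checkpoint name `FastSAM-x.pt` and the image path `example.jpg` are hypothetical placeholders, and `conf=0.4` / `iou=0.9` simply mirror the defaults in `FastSAMDecoder`.

    from thirdparty.fastsam import FastSAM, FastSAMPrompt

    # Hypothetical checkpoint and image paths; substitute real files.
    model = FastSAM("FastSAM-x.pt")
    results = model(
        "example.jpg",
        device="cpu",            # the extension passes self.sam.sam_device here
        retina_masks=True,
        imgsz=1024,
        conf=0.4,
        iou=0.9,
    )

    # Prompt-based selection over the raw results, using the FastSAMPrompt API as added in this diff.
    prompt = FastSAMPrompt("example.jpg", results, device="cpu")
    all_masks = prompt.everything_prompt()                 # every predicted mask
    box_mask = prompt.box_prompt(bbox=[50, 50, 300, 300])  # mask with highest IoU against the box (xyxy)

Text prompting (`prompt.text_prompt("a dog")`) additionally requires the OpenAI CLIP package, which `prompt.py` installs on demand via `check_requirements` if the import fails.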