another save
parent
a4058c4854
commit
6488b4e2dd
|
|
@ -0,0 +1,284 @@
|
|||
from typing import List, Tuple, Union, Optional
|
||||
import os
|
||||
import glob
|
||||
import copy
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from segment_anything.modeling import Sam
|
||||
from thirdparty.fastsam import FastSAM
|
||||
from thirdparty.sam_hq.automatic import SamAutomaticMaskGeneratorHQ
|
||||
from sam_utils.segment import Segmentation
|
||||
from sam_utils.logger import logger
|
||||
from sam_utils.util import blend_image_and_seg
|
||||
|
||||
|
||||
class AutoSAM:
|
||||
"""Automatic segmentation."""
|
||||
|
||||
def __init__(self, sam: Segmentation) -> None:
|
||||
"""AutoSAM initialization.
|
||||
|
||||
Args:
|
||||
sam (Segmentation): global Segmentation instance.
|
||||
"""
|
||||
self.sam = sam
|
||||
self.auto_sam: Union[SamAutomaticMaskGeneratorHQ, FastSAM] = None
|
||||
self.fastsam_conf = None
|
||||
self.fastsam_iou = None
|
||||
|
||||
|
||||
def auto_generate(self, img: np.ndarray) -> List[dict]:
|
||||
"""Generate segmentation.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): input image.
|
||||
|
||||
Returns:
|
||||
List[dict]: list of segmentation masks.
|
||||
"""
|
||||
return self.auto_sam.generate(img) if type(self.auto_sam) == SamAutomaticMaskGeneratorHQ else \
|
||||
self.auto_sam(img, device=self.sam.sam_device, retina_masks=True, imgsz=1024, conf=self.fastsam_conf, iou=self.fastsam_iou)
|
||||
|
||||
|
||||
def strengthen_semantic_seg(self, class_ids: np.ndarray, img: np.ndarray) -> np.ndarray:
|
||||
# TODO: class_ids use multiple dimensions, categorical mask single and batch
|
||||
logger.info("AutoSAM strengthening semantic segmentation")
|
||||
from sam_utils.util import install_pycocotools
|
||||
install_pycocotools()
|
||||
import pycocotools.mask as maskUtils
|
||||
semantc_mask = copy.deepcopy(class_ids)
|
||||
annotations = self.auto_generate(img)
|
||||
annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
|
||||
logger.info(f"AutoSAM generated {len(annotations)} masks")
|
||||
for ann in annotations:
|
||||
valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool()
|
||||
propose_classes_ids = torch.tensor(class_ids[valid_mask])
|
||||
num_class_proposals = len(torch.unique(propose_classes_ids))
|
||||
if num_class_proposals == 1:
|
||||
semantc_mask[valid_mask] = propose_classes_ids[0].numpy()
|
||||
continue
|
||||
top_1_propose_class_ids = torch.bincount(propose_classes_ids.flatten()).topk(1).indices
|
||||
semantc_mask[valid_mask] = top_1_propose_class_ids.numpy()
|
||||
logger.info("AutoSAM strengthening process end")
|
||||
return semantc_mask
|
||||
|
||||
|
||||
def random_segmentation(self, img: Image.Image) -> Tuple[List[Image.Image], str]:
|
||||
"""Random segmentation for EditAnything
|
||||
|
||||
Args:
|
||||
img (Image.Image): input image.
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: ControlNet not installed.
|
||||
|
||||
Returns:
|
||||
Tuple[List[Image.Image], str]: List of 3 displayed images and output message.
|
||||
"""
|
||||
logger.info("AutoSAM generating random segmentation for EditAnything")
|
||||
img_np = np.array(img.convert("RGB"))
|
||||
annotations = self.auto_generate(img_np)
|
||||
logger.info(f"AutoSAM generated {len(annotations)} masks")
|
||||
H, W, _ = img_np.shape
|
||||
color_map = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
detected_map_tmp = np.zeros((H, W), dtype=np.uint16)
|
||||
for idx, annotation in enumerate(annotations):
|
||||
current_seg = annotation['segmentation']
|
||||
color_map[current_seg] = np.random.randint(0, 255, (3))
|
||||
detected_map_tmp[current_seg] = idx + 1
|
||||
detected_map = np.zeros((detected_map_tmp.shape[0], detected_map_tmp.shape[1], 3))
|
||||
detected_map[:, :, 0] = detected_map_tmp % 256
|
||||
detected_map[:, :, 1] = detected_map_tmp // 256
|
||||
try:
|
||||
from scripts.processor import HWC3
|
||||
except:
|
||||
raise ModuleNotFoundError("ControlNet extension not found.")
|
||||
detected_map = HWC3(detected_map.astype(np.uint8))
|
||||
logger.info("AutoSAM generation process end")
|
||||
return [blend_image_and_seg(img_np, color_map), Image.fromarray(color_map), Image.fromarray(detected_map)], \
|
||||
"Random segmentation done. Left above (0) is blended image, right above (1) is random segmentation, left below (2) is Edit-Anything control input."
|
||||
|
||||
|
||||
def layout_single_image(self, input_image: Image.Image, output_path: str) -> None:
|
||||
"""Single image layout generation.
|
||||
|
||||
Args:
|
||||
input_image (Image.Image): input image.
|
||||
output_path (str): output path.
|
||||
"""
|
||||
img_np = np.array(input_image.convert("RGB"))
|
||||
annotations = self.auto_generate(img_np)
|
||||
logger.info(f"AutoSAM generated {len(annotations)} annotations")
|
||||
annotations = sorted(annotations, key=lambda x: x['area'])
|
||||
for idx, annotation in enumerate(annotations):
|
||||
img_tmp = np.zeros((img_np.shape[0], img_np.shape[1], 3))
|
||||
img_tmp[annotation['segmentation']] = img_np[annotation['segmentation']]
|
||||
img_np[annotation['segmentation']] = np.array([0, 0, 0])
|
||||
img_tmp = Image.fromarray(img_tmp.astype(np.uint8))
|
||||
img_tmp.save(os.path.join(output_path, f"{idx}.png"))
|
||||
img_np = Image.fromarray(img_np.astype(np.uint8))
|
||||
img_np.save(os.path.join(output_path, "leftover.png"))
|
||||
|
||||
|
||||
def layout(self, input_image_or_path: Union[str, Image.Image], output_path: str) -> str:
|
||||
"""Single or bath layout generation.
|
||||
|
||||
Args:
|
||||
input_image_or_path (Union[str, Image.Image]): input imag or path.
|
||||
output_path (str): output path.
|
||||
|
||||
Returns:
|
||||
str: generation message.
|
||||
"""
|
||||
if isinstance(input_image_or_path, str):
|
||||
logger.info("Image layer division batch processing")
|
||||
all_files = glob.glob(os.path.join(input_image_or_path, "*"))
|
||||
for image_index, input_image_file in enumerate(all_files):
|
||||
logger.info(f"Processing {image_index}/{len(all_files)} {input_image_file}")
|
||||
try:
|
||||
input_image = Image.open(input_image_file)
|
||||
output_directory = os.path.join(output_path, os.path.splitext(os.path.basename(input_image_file))[0])
|
||||
from pathlib import Path
|
||||
Path(output_directory).mkdir(exist_ok=True)
|
||||
except:
|
||||
logger.warn(f"File {input_image_file} not image, skipped.")
|
||||
continue
|
||||
self.layout_single_image(input_image, output_directory)
|
||||
else:
|
||||
self.layout_single_image(input_image_or_path, output_path)
|
||||
return "Done"
|
||||
|
||||
|
||||
def semantic_segmentation(self, input_image: Image.Image, annotator_name: str, processor_res: int,
|
||||
use_pixel_perfect: bool, resize_mode: int, target_W: int, target_H: int) -> Tuple[List[Image.Image], str]:
|
||||
"""Semantic segmentation enhanced by segment anything.
|
||||
|
||||
Args:
|
||||
input_image (Image.Image): input image.
|
||||
annotator_name (str): annotator name. Should be one of "seg_ufade20k"|"seg_ofade20k"|"seg_ofcoco".
|
||||
processor_res (int): processor resolution. Support 64-2048.
|
||||
use_pixel_perfect (bool): whether to use pixel perfect written by lllyasviel.
|
||||
resize_mode (int): resize mode for pixel perfect, should be 0|1|2.
|
||||
target_W (int): target width for pixel perfect.
|
||||
target_H (int): target height for pixel perfect.
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: ControlNet not installed.
|
||||
|
||||
Returns:
|
||||
Tuple[List[Image.Image], str]: list of 4 displayed images and message.
|
||||
"""
|
||||
assert input_image is not None, "No input image."
|
||||
if "seg" in annotator_name:
|
||||
try:
|
||||
from scripts.processor import uniformer, oneformer_coco, oneformer_ade20k
|
||||
from scripts.external_code import pixel_perfect_resolution
|
||||
oneformers = {
|
||||
"ade20k": oneformer_ade20k,
|
||||
"coco": oneformer_coco
|
||||
}
|
||||
except:
|
||||
raise ModuleNotFoundError("ControlNet extension not found.")
|
||||
input_image_np = np.array(input_image)
|
||||
if use_pixel_perfect:
|
||||
processor_res = pixel_perfect_resolution(input_image_np, resize_mode, target_W, target_H)
|
||||
logger.info("Generating semantic segmentation without SAM")
|
||||
if annotator_name == "seg_ufade20k":
|
||||
original_semantic = uniformer(input_image_np, processor_res)
|
||||
else:
|
||||
dataset = annotator_name.split('_')[-1][2:]
|
||||
original_semantic = oneformers[dataset](input_image_np, processor_res)
|
||||
logger.info("Generating semantic segmentation with SAM")
|
||||
sam_semantic = self.strengthen_semantic_seg(np.array(original_semantic), input_image_np)
|
||||
output_gallery = [original_semantic, sam_semantic, blend_image_and_seg(input_image, original_semantic), blend_image_and_seg(input_image, sam_semantic)]
|
||||
return output_gallery, "Done. Left is segmentation before SAM, right is segmentation after SAM."
|
||||
else:
|
||||
return self.random_segmentation(input_image)
|
||||
|
||||
|
||||
def categorical_mask_image(self, annotator_name: str, processor_res: int, category_input: List[int], input_image: Image.Image,
|
||||
use_pixel_perfect: bool, resize_mode: int, target_W: int, target_H: int) -> Tuple[np.ndarray, Image.Image]:
|
||||
"""Single image categorical mask.
|
||||
|
||||
Args:
|
||||
annotator_name (str): annotator name. Should be one of "seg_ufade20k"|"seg_ofade20k"|"seg_ofcoco".
|
||||
processor_res (int): processor resolution. Support 64-2048.
|
||||
category_input (List[int]): category input.
|
||||
input_image (Image.Image): input image.
|
||||
use_pixel_perfect (bool): whether to use pixel perfect written by lllyasviel.
|
||||
resize_mode (int): resize mode for pixel perfect, should be 0|1|2.
|
||||
target_W (int): target width for pixel perfect.
|
||||
target_H (int): target height for pixel perfect.
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: ControlNet not installed.
|
||||
AssertionError: Illegal class id.
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, Image.Image]: mask in resized shape and resized input image.
|
||||
"""
|
||||
assert input_image is not None, "No input image."
|
||||
try:
|
||||
from scripts.processor import uniformer, oneformer_coco, oneformer_ade20k
|
||||
from scripts.external_code import pixel_perfect_resolution
|
||||
oneformers = {
|
||||
"ade20k": oneformer_ade20k,
|
||||
"coco": oneformer_coco
|
||||
}
|
||||
except:
|
||||
raise ModuleNotFoundError("ControlNet extension not found.")
|
||||
filter_classes = category_input
|
||||
assert len(filter_classes) > 0, "No class selected."
|
||||
try:
|
||||
filter_classes = [int(i) for i in filter_classes]
|
||||
except:
|
||||
raise AssertionError("Illegal class id. You may have input some string.")
|
||||
input_image_np = np.array(input_image)
|
||||
if use_pixel_perfect:
|
||||
processor_res = pixel_perfect_resolution(input_image_np, resize_mode, target_W, target_H)
|
||||
crop_input_image_copy = copy.deepcopy(input_image)
|
||||
logger.info(f"Generating categories with processor {annotator_name}")
|
||||
if annotator_name == "seg_ufade20k":
|
||||
original_semantic = uniformer(input_image_np, processor_res)
|
||||
else:
|
||||
dataset = annotator_name.split('_')[-1][2:]
|
||||
original_semantic = oneformers[dataset](input_image_np, processor_res)
|
||||
sam_semantic = self.strengthen_semantic_seg(np.array(original_semantic), input_image_np)
|
||||
mask = np.zeros(sam_semantic.shape, dtype=np.bool_)
|
||||
from sam_utils.config import SEMANTIC_CATEGORIES
|
||||
for i in filter_classes:
|
||||
mask[np.equal(sam_semantic, SEMANTIC_CATEGORIES[annotator_name][i])] = True
|
||||
return mask, crop_input_image_copy
|
||||
|
||||
|
||||
def register(
|
||||
self,
|
||||
sam_model_name: str,
|
||||
points_per_side: Optional[int] = 32,
|
||||
points_per_batch: int = 64,
|
||||
pred_iou_thresh: float = 0.88,
|
||||
stability_score_thresh: float = 0.95,
|
||||
stability_score_offset: float = 1.0,
|
||||
box_nms_thresh: float = 0.7,
|
||||
crop_n_layers: int = 0,
|
||||
crop_nms_thresh: float = 0.7,
|
||||
crop_overlap_ratio: float = 512 / 1500,
|
||||
crop_n_points_downscale_factor: int = 1,
|
||||
min_mask_region_area: int = 0,
|
||||
output_mode: str = "binary_mask",
|
||||
fastsam_conf: float = 0.4,
|
||||
fastsam_iou: float = 0.9) -> None:
|
||||
"""Register AutoSAM module."""
|
||||
self.sam.load_sam_model(sam_model_name)
|
||||
assert type(self.sam.sam_model) in [FastSAM, Sam], f"{sam_model_name} does not support auto segmentation."
|
||||
if type(self.sam.sam_model) == FastSAM:
|
||||
self.fastsam_conf = fastsam_conf
|
||||
self.fastsam_iou = fastsam_iou
|
||||
self.auto_sam = self.sam.sam_model
|
||||
else:
|
||||
self.auto_sam = SamAutomaticMaskGeneratorHQ(
|
||||
self.sam.sam_model, points_per_side, points_per_batch, pred_iou_thresh,
|
||||
stability_score_thresh, stability_score_offset, box_nms_thresh,
|
||||
crop_n_layers, crop_nms_thresh, crop_overlap_ratio, crop_n_points_downscale_factor, None,
|
||||
min_mask_region_area, output_mode)
|
||||
|
|
@ -1,22 +1,24 @@
|
|||
from typing import Tuple
|
||||
from typing import List
|
||||
import os
|
||||
import gc
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
from modules.devices import device, torch_gc
|
||||
from scripts.sam_log import logger
|
||||
from modules.devices import device
|
||||
from sam_utils.logger import logger
|
||||
|
||||
|
||||
# TODO: support YOLO models
|
||||
class Detection:
|
||||
"""Detection related process.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize detection related process.
|
||||
"""
|
||||
self.dino_model = None
|
||||
self.dino_model_type = ""
|
||||
from scripts.sam_state import sam_extension_dir
|
||||
self.dino_model_dir = os.path.join(sam_extension_dir, "models/grounding-dino")
|
||||
self.dino_model_list = ["GroundingDINO_SwinT_OGC (694MB)", "GroundingDINO_SwinB (938MB)"]
|
||||
self.dino_model_info = {
|
||||
"GroundingDINO_SwinT_OGC (694MB)": {
|
||||
"checkpoint": "groundingdino_swint_ogc.pth",
|
||||
|
|
@ -29,68 +31,26 @@ class Detection:
|
|||
"url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth"
|
||||
},
|
||||
}
|
||||
self.dino_install_issue_text = "Please permanently switch to local GroundingDINO on Settings/Segment Anything or submit an issue to https://github.com/IDEA-Research/Grounded-Segment-Anything/issues."
|
||||
|
||||
|
||||
def _install_goundingdino(self) -> bool:
|
||||
if shared.opts.data.get("sam_use_local_groundingdino", False):
|
||||
logger.info("Using local groundingdino.")
|
||||
return False
|
||||
def _load_dino_model(self, dino_checkpoint: str, use_pip_dino: bool) -> None:
|
||||
"""Load GroundignDINO model to device.
|
||||
|
||||
def verify_dll(install_local=True):
|
||||
try:
|
||||
from groundingdino import _C
|
||||
logger.info("GroundingDINO dynamic library have been successfully built.")
|
||||
return True
|
||||
except Exception:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
def run_pip_uninstall(command, desc=None):
|
||||
from launch import python, run
|
||||
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
||||
return run(f'"{python}" -m pip uninstall -y {command}', desc=f"Uninstalling {desc}", errdesc=f"Couldn't uninstall {desc}", live=default_command_live)
|
||||
if install_local:
|
||||
logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and fall back to local GroundingDINO this time. {self.dino_install_issue_text}")
|
||||
run_pip_uninstall(
|
||||
f"groundingdino",
|
||||
f"sd-webui-segment-anything requirement: groundingdino")
|
||||
else:
|
||||
logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and re-try installing from GitHub source code. {self.dino_install_issue_text}")
|
||||
run_pip_uninstall(
|
||||
f"uninstall groundingdino",
|
||||
f"sd-webui-segment-anything requirement: groundingdino")
|
||||
return False
|
||||
|
||||
import launch
|
||||
if launch.is_installed("groundingdino"):
|
||||
logger.info("Found GroundingDINO in pip. Verifying if dynamic library build success.")
|
||||
if verify_dll(install_local=False):
|
||||
return True
|
||||
try:
|
||||
launch.run_pip(
|
||||
f"install git+https://github.com/IDEA-Research/GroundingDINO",
|
||||
f"sd-webui-segment-anything requirement: groundingdino")
|
||||
logger.info("GroundingDINO install success. Verifying if dynamic library build success.")
|
||||
return verify_dll()
|
||||
except Exception:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.warn(f"GroundingDINO install failed. Will fall back to local groundingdino this time. {self.dino_install_issue_text}")
|
||||
return False
|
||||
|
||||
|
||||
def _load_dino_model(self, dino_checkpoint: str, dino_install_success: bool) -> torch.nn.Module:
|
||||
Args:
|
||||
dino_checkpoint (str): GroundingDINO checkpoint name.
|
||||
use_pip_dino (bool): If True, use pip installed GroundingDINO. If False, use local GroundingDINO.
|
||||
"""
|
||||
logger.info(f"Initializing GroundingDINO {dino_checkpoint}")
|
||||
if self.dino_model is None or dino_checkpoint != self.dino_model_type:
|
||||
self.clear()
|
||||
if dino_install_success:
|
||||
if use_pip_dino:
|
||||
from groundingdino.models import build_model
|
||||
from groundingdino.util.slconfig import SLConfig
|
||||
from groundingdino.util.utils import clean_state_dict
|
||||
else:
|
||||
from local_groundingdino.models import build_model
|
||||
from local_groundingdino.util.slconfig import SLConfig
|
||||
from local_groundingdino.util.utils import clean_state_dict
|
||||
from thirdparty.groundingdino.models import build_model
|
||||
from thirdparty.groundingdino.util.slconfig import SLConfig
|
||||
from thirdparty.groundingdino.util.utils import clean_state_dict
|
||||
args = SLConfig.fromfile(self.dino_model_info[dino_checkpoint]["config"])
|
||||
dino = build_model(args)
|
||||
checkpoint = torch.hub.load_state_dict_from_url(
|
||||
|
|
@ -103,11 +63,20 @@ class Detection:
|
|||
self.dino_model.to(device=device)
|
||||
|
||||
|
||||
def _load_dino_image(self, image_pil: Image.Image, dino_install_success: bool) -> torch.Tensor:
|
||||
if dino_install_success:
|
||||
def _load_dino_image(self, image_pil: Image.Image, use_pip_dino: bool) -> torch.Tensor:
|
||||
"""Transform image to make the image applicable to GroundingDINO.
|
||||
|
||||
Args:
|
||||
image_pil (Image.Image): Input image in PIL format.
|
||||
use_pip_dino (bool): If True, use pip installed GroundingDINO. If False, use local GroundingDINO.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed image in torch.Tensor format.
|
||||
"""
|
||||
if use_pip_dino:
|
||||
import groundingdino.datasets.transforms as T
|
||||
else:
|
||||
from local_groundingdino.datasets import transforms as T
|
||||
from thirdparty.groundingdino.datasets import transforms as T
|
||||
transform = T.Compose(
|
||||
[
|
||||
T.RandomResize([800], max_size=1333),
|
||||
|
|
@ -119,7 +88,17 @@ class Detection:
|
|||
return image
|
||||
|
||||
|
||||
def _get_grounding_output(self, image: torch.Tensor, caption: str, box_threshold: float):
|
||||
def _get_grounding_output(self, image: torch.Tensor, caption: str, box_threshold: float) -> torch.Tensor:
|
||||
"""Inference GroundingDINO model.
|
||||
|
||||
Args:
|
||||
image (torch.Tensor): transformed input image.
|
||||
caption (str): string caption.
|
||||
box_threshold (float): bbox threshold.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: generated bounding boxes.
|
||||
"""
|
||||
caption = caption.lower()
|
||||
caption = caption.strip()
|
||||
if not caption.endswith("."):
|
||||
|
|
@ -128,7 +107,7 @@ class Detection:
|
|||
with torch.no_grad():
|
||||
outputs = self.dino_model(image[None], captions=[caption])
|
||||
if shared.cmd_opts.lowvram:
|
||||
self.dino_model.cpu()
|
||||
self.unload_model()
|
||||
logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256)
|
||||
boxes = outputs["pred_boxes"][0] # (nq, 4)
|
||||
|
||||
|
|
@ -142,12 +121,23 @@ class Detection:
|
|||
return boxes_filt.cpu()
|
||||
|
||||
|
||||
def dino_predict(self, input_image: Image.Image, dino_model_name: str, text_prompt: str, box_threshold: float) -> Tuple[torch.Tensor, bool]:
|
||||
install_success = self._install_goundingdino()
|
||||
def dino_predict(self, input_image: Image.Image, dino_model_name: str, text_prompt: str, box_threshold: float) -> List[List[float]]:
|
||||
"""Exposed API for GroundingDINO inference.
|
||||
|
||||
Args:
|
||||
input_image (Image.Image): input image.
|
||||
dino_model_name (str): GroundingDINO model name.
|
||||
text_prompt (str): string prompt.
|
||||
box_threshold (float): bbox threshold.
|
||||
|
||||
Returns:
|
||||
List[List[float]]: generated N * xyxy bounding boxes.
|
||||
"""
|
||||
from sam_utils.util import install_goundingdino
|
||||
install_success = install_goundingdino()
|
||||
logger.info("Running GroundingDINO Inference")
|
||||
dino_image = self._load_dino_image(input_image.convert("RGB"), install_success)
|
||||
self._load_dino_model(dino_model_name, install_success)
|
||||
using_groundingdino = install_success or shared.opts.data.get("sam_use_local_groundingdino", False)
|
||||
|
||||
boxes_filt = self._get_grounding_output(
|
||||
dino_image, text_prompt, box_threshold
|
||||
|
|
@ -158,16 +148,74 @@ class Detection:
|
|||
boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
|
||||
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
|
||||
boxes_filt[i][2:] += boxes_filt[i][:2]
|
||||
gc.collect()
|
||||
torch_gc()
|
||||
return boxes_filt, using_groundingdino
|
||||
return boxes_filt.tolist()
|
||||
|
||||
|
||||
def check_yolo_availability(self) -> List[str]:
|
||||
"""Check if YOLO models are available. Do not check if YOLO not enabled.
|
||||
|
||||
Returns:
|
||||
List[str]: available YOLO models.
|
||||
"""
|
||||
if shared.opts.data.get("sam_use_local_yolo", False):
|
||||
from modules.paths import models_path
|
||||
sd_yolo_model_dir = os.path.join(models_path, "ultralytics")
|
||||
return [name for name in os.listdir(sd_yolo_model_dir) if (".pth" in name or ".pt" in name)]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def yolo_predict(self, input_image: Image.Image, yolo_model_name: str, conf=0.4) -> List[List[float]]:
|
||||
"""Run detection inference with models based on YOLO.
|
||||
|
||||
Args:
|
||||
input_image (np.ndarray): input image, expect shape HW3.
|
||||
conf (float, optional): object confidence threshold. Defaults to 0.4.
|
||||
|
||||
Raises:
|
||||
RuntimeError: not getting any bbox. Might be caused by high conf or non-detection/segmentation model.
|
||||
|
||||
Returns:
|
||||
np.ndarray: generated N * xyxy bounding boxes.
|
||||
"""
|
||||
from ultralytics import YOLO
|
||||
assert shared.opts.data.get("sam_use_yolo_models", False), "YOLO models are not enabled. Please enable in settings/Segment Anything."
|
||||
logger.info("Loading YOLO model.")
|
||||
from modules.paths import models_path
|
||||
sd_yolo_model_dir = os.path.join(models_path, "ultralytics")
|
||||
self.dino_model = YOLO(os.path.join(sd_yolo_model_dir, yolo_model_name)).to(device)
|
||||
self.dino_model_type = yolo_model_name
|
||||
logger.info("Running YOLO inference.")
|
||||
pred = self.dino_model(input_image, conf=conf)
|
||||
bboxes = pred[0].boxes.xyxy.cpu().numpy()
|
||||
if bboxes.size == 0:
|
||||
error_msg = "You are not getting any bbox. There are 2 possible reasons. "\
|
||||
"1. You set up a high conf which means that you should lower the conf. "\
|
||||
"2. You are using a non-detection model which means that you should check your model type."
|
||||
raise RuntimeError(error_msg)
|
||||
else:
|
||||
return bboxes.tolist()
|
||||
|
||||
|
||||
def __call__(self, input_image: Image.Image, model_name: str, text_prompt: str, box_threshold: float, conf=0.4) -> List[List[float]]:
|
||||
"""Exposed API for detection inference."""
|
||||
if model_name in self.dino_model_info.keys():
|
||||
return self.dino_predict(input_image, model_name, text_prompt, box_threshold)
|
||||
elif model_name in self.check_yolo_availability():
|
||||
return self.yolo_predict(input_image, model_name, conf)
|
||||
else:
|
||||
raise ValueError(f"Detection model {model_name} not found.")
|
||||
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear detection model from any memory.
|
||||
"""
|
||||
del self.dino_model
|
||||
self.dino_model = None
|
||||
|
||||
|
||||
def unload_model(self) -> None:
|
||||
"""Unload detection model from GPU to CPU.
|
||||
"""
|
||||
if self.dino_model is not None:
|
||||
self.dino_model.cpu()
|
||||
|
|
|
|||
|
|
@ -1,23 +1,27 @@
|
|||
from typing import List
|
||||
import gc
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
from thirdparty.fastsam import FastSAM, FastSAMPrompt
|
||||
from modules import shared
|
||||
from modules.safe import unsafe_torch_load, load
|
||||
from modules.devices import get_device_for, cpu
|
||||
from modules.devices import get_device_for, cpu, torch_gc
|
||||
from modules.paths import models_path
|
||||
from scripts.sam_state import sam_extension_dir
|
||||
from scripts.sam_log import logger
|
||||
from sam_utils.logger import logger
|
||||
from sam_utils.util import ModelInfo
|
||||
from sam_hq.build_sam_hq import sam_model_registry
|
||||
from sam_hq.predictor import SamPredictorHQ
|
||||
from mam.m2m import SamM2M
|
||||
from thirdparty.sam_hq.build_sam_hq import sam_model_registry
|
||||
from thirdparty.sam_hq.predictor import SamPredictorHQ
|
||||
from thirdparty.mam.m2m import SamM2M
|
||||
|
||||
|
||||
class Segmentation:
|
||||
"""Segmentation related process."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize segmentation related process."""
|
||||
self.sam_model_info = {
|
||||
"sam_vit_h_4b8939.pth" : ModelInfo("SAM", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "Meta", "2.56GB"),
|
||||
"sam_vit_l_0b3195.pth" : ModelInfo("SAM", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "Meta", "1.25GB"),
|
||||
|
|
@ -37,7 +41,11 @@ class Segmentation:
|
|||
|
||||
|
||||
def check_model_availability(self) -> List[str]:
|
||||
# retrieve all models in all the model directories
|
||||
"""retrieve all models in all the model directories
|
||||
|
||||
Returns:
|
||||
List[str]: Model information displayed on the UI
|
||||
"""
|
||||
user_sam_model_dir = shared.opts.data.get("sam_model_path", "")
|
||||
sd_sam_model_dir = os.path.join(models_path, "sam")
|
||||
scripts_sam_model_dir = os.path.join(sam_extension_dir, "models/sam")
|
||||
|
|
@ -45,7 +53,7 @@ class Segmentation:
|
|||
sam_model_dirs = [sd_sam_model_dir, scripts_sam_model_dir]
|
||||
if user_sam_model_dir != "":
|
||||
sam_model_dirs.append(user_sam_model_dir)
|
||||
if shared.opts.data.get("use_yolo_models", False):
|
||||
if shared.opts.data.get("sam_use_yolo_models", False):
|
||||
sam_model_dirs.append(sd_yolo_model_dir)
|
||||
for dir in sam_model_dirs:
|
||||
if os.path.isdir(dir):
|
||||
|
|
@ -54,19 +62,27 @@ class Segmentation:
|
|||
for name in sam_model_names:
|
||||
if name in self.sam_model_info.keys():
|
||||
self.sam_model_info[name].local_path(os.path.join(dir, name))
|
||||
elif shared.opts.data.get("use_yolo_models", False):
|
||||
elif shared.opts.data.get("sam_use_yolo_models", False):
|
||||
logger.warn(f"Model {name} not found in support list, default to use YOLO as initializer")
|
||||
self.sam_model_info[name] = ModelInfo("YOLO", os.path.join(dir, name), "?", "?", "downloaded")
|
||||
return [val.get_info(key) for key, val in self.sam_model_info.items()]
|
||||
|
||||
|
||||
def load_sam_model(self, sam_checkpoint_name: str) -> None:
|
||||
"""Load segmentation model.
|
||||
|
||||
Args:
|
||||
sam_checkpoint_name (str): The model filename. Do not change.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Model file not found in either support list or local model directory.
|
||||
RuntimeError: Cannot automatically download model from remote server.
|
||||
"""
|
||||
sam_checkpoint_name = sam_checkpoint_name.split(" ")[0]
|
||||
if self.sam_model is None or self.sam_model_name != sam_checkpoint_name:
|
||||
if sam_checkpoint_name not in self.sam_model_info.keys():
|
||||
error_msg = f"{sam_checkpoint_name} not found and cannot be auto downloaded"
|
||||
logger.error(f"{error_msg}")
|
||||
raise Exception(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
if "http" in self.sam_model_info[sam_checkpoint_name].url:
|
||||
sam_url = self.sam_model_info[sam_checkpoint_name].url
|
||||
user_dir = shared.opts.data.get("sam_model_path", "")
|
||||
|
|
@ -89,158 +105,343 @@ class Segmentation:
|
|||
logger.info(f"Loading SAM model from {model_path}")
|
||||
self.sam_model = sam_model_registry[sam_checkpoint_name](checkpoint=model_path)
|
||||
self.sam_model_wrapper = SamPredictorHQ(self.sam_model, 'HQ' in model_type)
|
||||
else:
|
||||
elif "SAM" in model_type:
|
||||
logger.info(f"Loading FastSAM model from {model_path}")
|
||||
self.sam_model = FastSAM(model_path)
|
||||
self.sam_model_wrapper = self.sam_model
|
||||
elif shared.opts.data.get("sam_use_yolo_models", False):
|
||||
logger.info(f"Loading YOLO model from {model_path}")
|
||||
self.sam_model = YOLO(model_path)
|
||||
self.sam_model_wrapper = self.sam_model
|
||||
else:
|
||||
error_msg = f"Unsupported model type {model_type}"
|
||||
raise RuntimeError(error_msg)
|
||||
self.sam_model_name = sam_checkpoint_name
|
||||
torch.load = load
|
||||
self.sam_model.to(self.sam_device)
|
||||
|
||||
|
||||
def change_device(self, use_cpu: bool) -> None:
|
||||
"""Change the device of the segmentation model.
|
||||
|
||||
Args:
|
||||
use_cpu (bool): Whether to use CPU for SAM inference.
|
||||
"""
|
||||
self.sam_device = cpu if use_cpu else get_device_for("sam")
|
||||
|
||||
|
||||
def __call__(self,
|
||||
input_image: np.ndarray,
|
||||
point_coords: List[List[int]]=None,
|
||||
point_labels: List[List[int]]=None,
|
||||
boxes: torch.Tensor=None,
|
||||
multimask_output=True,
|
||||
use_mam=False):
|
||||
pass
|
||||
|
||||
|
||||
def sam_predict(self,
|
||||
input_image: np.ndarray,
|
||||
positive_points: List[List[int]]=None,
|
||||
negative_points: List[List[int]]=None,
|
||||
boxes_coords: torch.Tensor=None,
|
||||
boxes_labels: List[bool]=None,
|
||||
multimask_output=True,
|
||||
merge_point_and_box=True,
|
||||
point_with_box=False,
|
||||
use_mam=False,
|
||||
use_mam_for_each_infer=False,
|
||||
mam_guidance_mode: str="mask") -> np.ndarray:
|
||||
input_image: np.ndarray,
|
||||
positive_points: List[List[int]]=None,
|
||||
negative_points: List[List[int]]=None,
|
||||
positive_bbox: List[List[float]]=None,
|
||||
negative_bbox: List[List[float]]=None,
|
||||
merge_positive=True,
|
||||
multimask_output=True,
|
||||
point_with_box=False,
|
||||
use_mam=False,
|
||||
mam_guidance_mode: str="mask") -> np.ndarray:
|
||||
"""Run segmentation inference with models based on segment anything.
|
||||
|
||||
Args:
|
||||
input_image (np.ndarray): input image, expect shape HW3.
|
||||
positive_points (List[List[int]], optional): positive point prompts. Defaults to None.
|
||||
negative_points (List[List[int]], optional): negative point prompts. Defaults to None.
|
||||
boxes_coords (torch.Tensor, optional): bbox inputs, expect shape xyxy. Defaults to None.
|
||||
boxes_labels (List[bool], optional): bbox labels, support positive & negative bboxes. Defaults to None.
|
||||
positive_points (List[List[int]], optional): positive point prompts, expect N * xy. Defaults to None.
|
||||
negative_points (List[List[int]], optional): negative point prompts, expect N * xy. Defaults to None.
|
||||
positive_bbox (List[List[float]], optional): positive bbox prompts, expect N * xyxy. Defaults to None.
|
||||
negative_bbox (List[List[float]], optional): negative bbox prompts, expect N * xyxy. Defaults to None.
|
||||
merge_positive (bool, optional): OR all positive masks. Defaults to True. Valid only if point_with_box is False.
|
||||
multimask_output (bool, optional): output 3 masks or not. Defaults to True.
|
||||
merge_point_and_box (bool, optional): if True, output point masks || bbox masks; otherwise, output point masks && bbox masks. Valid only if point_with_box is False. Defaults to True.
|
||||
point_with_box (bool, optional): always send bboxes and points to the model at the same time. Defaults to False.
|
||||
use_mam (bool, optional): use Matting-Anything. Defaults to False.
|
||||
use_mam_for_each_infer (bool, optional): use Matting-Anything for each SAM inference. Valid only if use_mam is True. Defaults to False.
|
||||
mam_guidance_mode (str, optional): guidance model for Matting-Anything. Expect "mask" or "bbox". Valid only if use_mam is True. Defaults to "mask".
|
||||
mam_guidance_mode (str, optional): guidance model for Matting-Anything. Expect "mask" or "bbox". Defaults to "mask". Valid only if use_mam is True.
|
||||
|
||||
Returns:
|
||||
np.ndarray: mask output, expect shape 11HW or 31HW.
|
||||
"""
|
||||
masks_for_points, masks_for_boxes, masks = None, None, None
|
||||
has_points = positive_points is not None or negative_points is not None
|
||||
has_bbox = positive_bbox is not None or negative_bbox is not None
|
||||
assert has_points or has_bbox, "No input provided. Please provide at least one point or one bbox."
|
||||
assert type(self.sam_model_wrapper) == SamPredictorHQ, "Incorrect SAM model wrapper. Expect SamPredictorHQ here."
|
||||
mask_shape = ((3, 1) if multimask_output else (1, 1)) + input_image.shape[:2]
|
||||
self.sam_model_wrapper.set_image(input_image)
|
||||
|
||||
# If use Matting-Anything, load mam model.
|
||||
if use_mam:
|
||||
self.sam_m2m.load_m2m() # TODO: raise exception if network problem
|
||||
# When always send bboxes and points to the model at the same time.
|
||||
if point_coords is not None and point_labels is not None and boxes_coords is not None and point_with_box:
|
||||
try:
|
||||
self.sam_m2m.load_m2m()
|
||||
except:
|
||||
use_mam = False
|
||||
|
||||
# Matting-Anything inference for each SAM inference.
|
||||
def _mam_infer(mask: np.ndarray, low_res_mask: np.ndarray) -> np.ndarray:
|
||||
low_res_mask_logits = low_res_mask > self.sam_model_wrapper.model.mask_threshold
|
||||
if use_mam:
|
||||
mask = self.sam_m2m.forward(
|
||||
self.sam_model_wrapper.features, torch.tensor(input_image), low_res_mask_logits, mask,
|
||||
self.sam_model_wrapper.original_size, self.sam_model_wrapper.input_size, mam_guidance_mode)
|
||||
return mask
|
||||
|
||||
|
||||
# When always send bboxes and points to SAM at the same time.
|
||||
if has_points and has_bbox and point_with_box:
|
||||
logger.info(f"SAM {self.sam_model_name} inference with "
|
||||
f"{len(positive_points)} positive points, {len(negative_points)} negative points, "
|
||||
f"{len(positive_bbox)} positive bboxes, {len(negative_bbox)} negative bboxes. "
|
||||
f"For each bbox, all point prompts are affective. Masks for each bbox will be merged.")
|
||||
point_coords = np.array(positive_points + negative_points)
|
||||
point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
|
||||
point_labels_neg = np.array([0] * len(positive_points) + [1] * len(negative_points))
|
||||
positive_masks, negative_masks, positive_low_res_masks, negative_low_res_masks = [], [], [], []
|
||||
# Inference for each positive bbox.
|
||||
for box in boxes_coords[boxes_labels]:
|
||||
mask, _, low_res_mask = self.sam_model_wrapper.predict(
|
||||
point_coords=point_coords,
|
||||
point_labels=point_labels,
|
||||
box=box.numpy(),
|
||||
multimask_output=multimask_output)
|
||||
mask = mask[:, None, ...]
|
||||
low_res_mask = low_res_mask[:, None, ...]
|
||||
if use_mam:
|
||||
low_res_mask_logits = low_res_mask > self.sam_model_wrapper.model.mask_threshold
|
||||
if use_mam_for_each_infer:
|
||||
mask = self.sam_m2m.forward(
|
||||
self.sam_model_wrapper.features, torch.tensor(input_image), low_res_mask_logits, mask,
|
||||
self.sam_model_wrapper.original_size, self.sam_model_wrapper.input_size, mam_guidance_mode)
|
||||
else:
|
||||
positive_low_res_masks.append(low_res_mask_logits)
|
||||
positive_masks.append(mask)
|
||||
positive_masks = np.logical_or(np.stack(positive_masks, 0))
|
||||
# Inference for each negative bbox.
|
||||
for box in boxes_coords[[not i for i in boxes_labels]]:
|
||||
mask, _, low_res_mask = self.sam_model_wrapper.predict(
|
||||
point_coords=point_coords,
|
||||
point_labels=point_labels_neg,
|
||||
box=box.numpy(),
|
||||
multimask_output=multimask_output)
|
||||
mask = mask[:, None, ...]
|
||||
low_res_mask = low_res_mask[:, None, ...]
|
||||
if use_mam:
|
||||
low_res_mask_logits = low_res_mask > self.sam_model_wrapper.model.mask_threshold
|
||||
if use_mam_for_each_infer:
|
||||
mask = self.sam_m2m.forward(
|
||||
self.sam_model_wrapper.features, torch.tensor(input_image), low_res_mask_logits, mask,
|
||||
self.sam_model_wrapper.original_size, self.sam_model_wrapper.input_size, mam_guidance_mode)
|
||||
else:
|
||||
negative_low_res_masks.append(low_res_mask_logits)
|
||||
negative_masks.append(mask)
|
||||
negative_masks = np.logical_or(np.stack(negative_masks, 0))
|
||||
masks = np.logical_and(positive_masks, ~negative_masks)
|
||||
# Matting-Anything inference if not for each inference.
|
||||
if use_mam and not use_mam_for_each_infer:
|
||||
positive_low_res_masks = np.logical_or(np.stack(positive_low_res_masks, 0))
|
||||
negative_low_res_masks = np.logical_or(np.stack(negative_low_res_masks, 0))
|
||||
low_res_masks = np.logical_and(positive_low_res_masks, ~negative_low_res_masks)
|
||||
masks = self.sam_m2m.forward(
|
||||
self.sam_model_wrapper.features, torch.tensor(input_image), low_res_masks, masks,
|
||||
self.sam_model_wrapper.original_size, self.sam_model_wrapper.input_size, mam_guidance_mode)
|
||||
return masks
|
||||
|
||||
# When separate bbox inference from point inference.
|
||||
if point_coords is not None and point_labels is not None:
|
||||
def _box_infer(_bbox, _point_labels):
|
||||
_masks = []
|
||||
for box in _bbox:
|
||||
mask, _, low_res_mask = self.sam_model_wrapper.predict(
|
||||
point_coords=point_coords,
|
||||
point_labels=_point_labels,
|
||||
box=np.array(box),
|
||||
multimask_output=multimask_output)
|
||||
mask = mask[:, None, ...]
|
||||
low_res_mask = low_res_mask[:, None, ...]
|
||||
_masks.append(_mam_infer(mask, low_res_mask))
|
||||
return np.logical_or.reduce(_masks)
|
||||
|
||||
mask_bbox_positive = _box_infer(positive_bbox, point_labels) if positive_bbox is not None else np.ones(shape=mask_shape, dtype=np.bool_)
|
||||
mask_bbox_negative = _box_infer(negative_bbox, point_labels_neg) if negative_bbox is not None else np.zeros(shape=mask_shape, dtype=np.bool_)
|
||||
|
||||
return mask_bbox_positive & ~mask_bbox_negative
|
||||
|
||||
# When separate bbox inference from point inference.
|
||||
if has_points:
|
||||
logger.info(f"SAM {self.sam_model_name} inference with "
|
||||
f"{len(positive_points)} positive points, {len(negative_points)} negative points")
|
||||
point_coords = np.array(positive_points + negative_points)
|
||||
point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
|
||||
masks_for_points, _, _ = self.sam_model_wrapper.predict(
|
||||
mask_points_positive, _, low_res_masks_points_positive = self.sam_model_wrapper.predict(
|
||||
point_coords=point_coords,
|
||||
point_labels=point_labels,
|
||||
box=None,
|
||||
multimask_output=multimask_output)
|
||||
masks_for_points = masks_for_points[:, None, ...]
|
||||
# TODO: m2m
|
||||
if boxes_coords is not None:
|
||||
transformed_boxes = self.sam_model_wrapper.transform.apply_boxes_torch(boxes_coords, input_image.shape[:2])
|
||||
masks_for_boxes, _, _ = self.sam_model_wrapper.predict_torch(
|
||||
mask_points_positive = mask_points_positive[:, None, ...]
|
||||
low_res_masks_points_positive = low_res_masks_points_positive[:, None, ...]
|
||||
mask_points_positive = _mam_infer(mask_points_positive, low_res_masks_points_positive)
|
||||
else:
|
||||
mask_points_positive = np.ones(shape=mask_shape, dtype=np.bool_)
|
||||
|
||||
def _box_infer(_bbox, _character):
|
||||
logger.info(f"SAM {self.sam_model_name} inference with {len(positive_bbox)} {_character} bboxes")
|
||||
transformed_boxes = self.sam_model_wrapper.transform.apply_boxes_torch(torch.tensor(_bbox), input_image.shape[:2])
|
||||
mask, _, low_res_mask = self.sam_model_wrapper.predict_torch(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
boxes=transformed_boxes.to(self.sam_device),
|
||||
multimask_output=multimask_output)
|
||||
masks_for_boxes = masks_for_boxes.permute(1, 0, 2, 3).cpu().numpy()
|
||||
# TODO: m2m
|
||||
if masks_for_boxes is not None and masks_for_points is not None:
|
||||
if merge_point_and_box:
|
||||
masks = np.logical_or(masks_for_points, masks_for_boxes)
|
||||
else:
|
||||
masks = np.logical_and(masks_for_points, masks_for_boxes)
|
||||
elif masks_for_boxes is not None:
|
||||
masks = masks_for_boxes
|
||||
elif masks_for_points is not None:
|
||||
masks = masks_for_points
|
||||
# TODO: m2m
|
||||
return masks
|
||||
mask = mask.permute(1, 0, 2, 3).cpu().numpy()
|
||||
low_res_mask = low_res_mask.permute(1, 0, 2, 3).cpu().numpy()
|
||||
return _mam_infer(mask, low_res_mask)
|
||||
mask_bbox_positive = _box_infer(positive_bbox, "positive") if positive_bbox is not None else np.ones(shape=mask_shape, dtype=np.bool_)
|
||||
mask_bbox_negative = _box_infer(negative_bbox, "negative") if negative_bbox is not None else np.zeros(shape=mask_shape, dtype=np.bool_)
|
||||
|
||||
if merge_positive:
|
||||
return (mask_points_positive | mask_bbox_positive) & ~mask_bbox_negative
|
||||
else:
|
||||
return (mask_points_positive & mask_bbox_positive) & ~mask_bbox_negative
|
||||
|
||||
|
||||
def yolo_predict():
|
||||
pass
|
||||
def fastsam_predict(self,
|
||||
input_image: np.ndarray,
|
||||
positive_points: List[List[int]]=None,
|
||||
negative_points: List[List[int]]=None,
|
||||
positive_bbox: List[List[float]]=None,
|
||||
negative_bbox: List[List[float]]=None,
|
||||
positive_text: str="",
|
||||
negative_text: str="",
|
||||
merge_positive=True,
|
||||
merge_negative=True,
|
||||
conf=0.4, iou=0.9,) -> np.ndarray:
|
||||
"""Run segmentation inference with models based on FastSAM. (This is a special kind of YOLO model)
|
||||
|
||||
Args:
|
||||
input_image (np.ndarray): input image, expect shape HW3.
|
||||
positive_points (List[List[int]], optional): positive point prompts, expect N * xy Defaults to None.
|
||||
negative_points (List[List[int]], optional): negative point prompts, expect N * xy Defaults to None.
|
||||
positive_bbox (List[List[float]], optional): positive bbox prompts, expect N * xyxy. Defaults to None.
|
||||
negative_bbox (List[List[float]], optional): negative bbox prompts, expect N * xyxy. Defaults to None.
|
||||
positive_text (str, optional): positive text prompts. Defaults to "".
|
||||
negative_text (str, optional): negative text prompts. Defaults to "".
|
||||
merge_positive (bool, optional): OR all positive masks. Defaults to True.
|
||||
merge_negative (bool, optional): OR all negative masks. Defaults to True.
|
||||
conf (float, optional): object confidence threshold. Defaults to 0.4.
|
||||
iou (float, optional): iou threshold for filtering the annotations. Defaults to 0.9.
|
||||
|
||||
Returns:
|
||||
np.ndarray: mask output, expect shape 11HW. FastSAM does not support multi-mask selection.
|
||||
"""
|
||||
assert type(self.sam_model_wrapper) == FastSAM, "Incorrect SAM model wrapper. Expect FastSAM here."
|
||||
logger.info(f"Running FastSAM {self.sam_model_name} inference.")
|
||||
annotation = self.sam_model_wrapper(
|
||||
input_image, device=self.sam_device, retina_masks=True, imgsz=1024, conf=conf, iou=iou)
|
||||
has_points = positive_points is not None or negative_points is not None
|
||||
has_bbox = positive_bbox is not None or negative_bbox is not None
|
||||
has_text = positive_text != "" or negative_text != ""
|
||||
assert has_points or has_bbox or has_text, "No input provided. Please provide at least one point or one bbox or one text."
|
||||
logger.info("Post-processing FastSAM inference.")
|
||||
prompt_process = FastSAMPrompt(input_image, annotation, device=self.sam_device)
|
||||
mask_shape = (1, 1) + input_image.shape[:2]
|
||||
mask_bbox_positive = prompt_process.box_prompt(bboxes=positive_bbox) if positive_bbox is not None else np.ones(shape=mask_shape, dtype=np.bool_)
|
||||
mask_bbox_negative = prompt_process.box_prompt(bboxes=negative_bbox) if negative_bbox is not None else np.zeros(shape=mask_shape, dtype=np.bool_)
|
||||
|
||||
mask_text_positive = prompt_process.text_prompt(text=positive_text) if positive_text != "" else np.ones(shape=mask_shape, dtype=np.bool_)
|
||||
mask_text_negative = prompt_process.text_prompt(text=negative_text) if negative_text != "" else np.zeros(shape=mask_shape, dtype=np.bool_)
|
||||
|
||||
point_coords = np.array(positive_points + negative_points)
|
||||
point_labels = np.array([1] * len(positive_points) + [0] * len(negative_points))
|
||||
mask_points_positive = prompt_process.point_prompt(points=point_coords, pointlabel=point_labels) if has_points else np.ones(shape=mask_shape, dtype=np.bool_)
|
||||
|
||||
if merge_positive:
|
||||
mask_positive = mask_bbox_positive | mask_text_positive | mask_points_positive
|
||||
else:
|
||||
mask_positive = mask_bbox_positive & mask_text_positive & mask_points_positive
|
||||
if merge_negative:
|
||||
mask_negative = mask_bbox_negative | mask_text_negative
|
||||
else:
|
||||
mask_negative = mask_bbox_negative & mask_text_negative
|
||||
return mask_positive & ~mask_negative
|
||||
|
||||
|
||||
def yolo_predict(self, input_image: np.ndarray, conf=0.4) -> np.ndarray:
|
||||
"""Run segmentation inference with models based on YOLO.
|
||||
|
||||
Args:
|
||||
input_image (np.ndarray): input image, expect shape HW3.
|
||||
conf (float, optional): object confidence threshold. Defaults to 0.4.
|
||||
|
||||
Raises:
|
||||
RuntimeError: not getting any bbox. Might be caused by high conf or non-detection/segmentation model.
|
||||
|
||||
Returns:
|
||||
np.ndarray: mask output, expect shape 11HW. YOLO does not support multi-mask selection.
|
||||
"""
|
||||
assert shared.opts.data.get("sam_use_yolo_models", False), "YOLO models are not enabled. Please enable in settings/Segment Anything."
|
||||
assert type(self.sam_model_wrapper) == YOLO, "Incorrect SAM model wrapper. Expect YOLO here."
|
||||
logger.info("Running YOLO inference.")
|
||||
pred = self.sam_model_wrapper(input_image, conf=conf)
|
||||
bboxes = pred[0].boxes.xyxy.cpu().numpy()
|
||||
if bboxes.size == 0:
|
||||
error_msg = "You are not getting any bbox. There are 2 possible reasons. "\
|
||||
"1. You set up a high conf which means that you should lower the conf. "\
|
||||
"2. You are using a non-detection/segmentation model which means that you should check your model type."
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
if pred[0].masks is None:
|
||||
logger.warn("You are not using a segmentation model. Will use bbox to create masks.")
|
||||
masks = []
|
||||
for bbox in bboxes:
|
||||
mask_shape = (1, 1) + input_image.shape[:2]
|
||||
mask = np.zeros(mask_shape, dtype=bool)
|
||||
x1, y1, x2, y2 = bbox
|
||||
mask[:, :, y1:y2+1, x1:x2+1] = True
|
||||
return np.logical_or.reduce(masks, axis=0)
|
||||
else:
|
||||
return np.logical_or.reduce(pred[0].masks.data, axis=0)
|
||||
|
||||
|
||||
def clear(self):
|
||||
"""Clear segmentation model from CPU & GPU."""
|
||||
del self.sam_model
|
||||
self.sam_model = None
|
||||
self.sam_model_name = ""
|
||||
self.sam_model_wrapper = None
|
||||
self.sam_m2m.clear()
|
||||
|
||||
|
||||
def unload_model(self):
|
||||
"""Move all segmentation models to CPU."""
|
||||
if self.sam_model is not None:
|
||||
self.sam_model.cpu()
|
||||
self.sam_m2m.unload_model()
|
||||
|
||||
|
||||
def __call__(self,
|
||||
sam_checkpoint_name: str,
|
||||
input_image: np.ndarray,
|
||||
positive_points: List[List[int]]=None,
|
||||
negative_points: List[List[int]]=None,
|
||||
positive_bbox: List[List[float]]=None,
|
||||
negative_bbox: List[List[float]]=None,
|
||||
positive_text: str="",
|
||||
negative_text: str="",
|
||||
merge_positive=True,
|
||||
merge_negative=True,
|
||||
multimask_output=True,
|
||||
point_with_box=False,
|
||||
use_mam=False,
|
||||
mam_guidance_mode: str="mask",
|
||||
conf=0.4, iou=0.9,) -> np.ndarray:
|
||||
# use_cpu: bool=False,) -> np.ndarray:
|
||||
"""Entry for segmentation inference. Load model, run inference, unload model if lowvram.
|
||||
|
||||
Args:
|
||||
sam_checkpoint_name (str): The model filename. Do not change.
|
||||
input_image (np.ndarray): input image, expect shape HW3.
|
||||
positive_points (List[List[int]], optional): positive point prompts, expect N * xy. Defaults to None. Valid for SAM & FastSAM.
|
||||
negative_points (List[List[int]], optional): negative point prompts, expect N * xy. Defaults to None. Valid for SAM & FastSAM.
|
||||
positive_bbox (List[List[float]], optional): positive bbox prompts, expect N * xyxy. Defaults to None. Valid for SAM & FastSAM.
|
||||
negative_bbox (List[List[float]], optional): negative bbox prompts, expect N * xyxy. Defaults to None. Valid for SAM & FastSAM.
|
||||
positive_text (str, optional): positive text prompts. Defaults to "". Valid for FastSAM.
|
||||
negative_text (str, optional): negative text prompts. Defaults to "". Valid for FastSAM.
|
||||
merge_positive (bool, optional): OR all positive masks. Defaults to True. Valid for SAM (point_with_box is True) & FastSAM.
|
||||
merge_negative (bool, optional): OR all negative masks. Defaults to True. Valid for FastSAM.
|
||||
multimask_output (bool, optional): output 3 masks or not. Defaults to True. Valid for SAM.
|
||||
point_with_box (bool, optional): always send bboxes and points to the model at the same time. Defaults to False. Valid for SAM.
|
||||
use_mam (bool, optional): use Matting-Anything. Defaults to False. Valid for SAM.
|
||||
mam_guidance_mode (str, optional): guidance model for Matting-Anything. Expect "mask" or "bbox". Defaults to "mask". Valid for SAM and use_mam is True.
|
||||
conf (float, optional): object confidence threshold. Defaults to 0.4. Valid for FastSAM & YOLO.
|
||||
iou (float, optional): iou threshold for filtering the annotations. Defaults to 0.9. Valid for FastSAM.
|
||||
use_cpu (bool, optional): use CPU for SAM inference. Defaults to False.
|
||||
|
||||
Returns:
|
||||
np.ndarray: _description_
|
||||
"""
|
||||
# self.change_device(use_cpu)
|
||||
self.load_sam_model(sam_checkpoint_name)
|
||||
if type(self.sam_model_wrapper) == SamPredictorHQ:
|
||||
masks = self.sam_predict(
|
||||
input_image=input_image,
|
||||
positive_points=positive_points,
|
||||
negative_points=negative_points,
|
||||
positive_bbox=positive_bbox,
|
||||
negative_bbox=negative_bbox,
|
||||
merge_positive=merge_positive,
|
||||
multimask_output=multimask_output,
|
||||
point_with_box=point_with_box,
|
||||
use_mam=use_mam,
|
||||
mam_guidance_mode=mam_guidance_mode)
|
||||
elif type(self.sam_model_wrapper) == FastSAM:
|
||||
masks = self.fastsam_predict(
|
||||
input_image=input_image,
|
||||
positive_points=positive_points,
|
||||
negative_points=negative_points,
|
||||
positive_bbox=positive_bbox,
|
||||
negative_bbox=negative_bbox,
|
||||
positive_text=positive_text,
|
||||
negative_text=negative_text,
|
||||
merge_positive=merge_positive,
|
||||
merge_negative=merge_negative,
|
||||
conf=conf, iou=iou)
|
||||
else:
|
||||
masks = self.yolo_predict(
|
||||
input_image=input_image,
|
||||
conf=conf)
|
||||
if shared.cmd_opts.lowvram:
|
||||
self.unload_model()
|
||||
gc.collect()
|
||||
torch_gc()
|
||||
return masks
|
||||
|
||||
|
||||
# check device for each model type
|
||||
# how to use yolo for auto
|
||||
# category name dropdown
|
||||
# category name dropdown and dynamic ui
|
||||
# yolo model for segmentation and detection
|
||||
# zoom in and unify box+point
|
||||
# zoom in, unify box+point
|
||||
# make masks smaller
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import os
|
|||
import numpy as np
|
||||
import cv2
|
||||
import copy
|
||||
from scipy.ndimage import binary_dilation
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
|
@ -51,6 +50,7 @@ def show_masks(image_np: np.ndarray, masks: np.ndarray, alpha=0.5) -> np.ndarray
|
|||
|
||||
|
||||
def dilate_mask(mask: np.ndarray, dilation_amt: int) -> Tuple[Image.Image, np.ndarray]:
|
||||
from scipy.ndimage import binary_dilation
|
||||
x, y = np.meshgrid(np.arange(dilation_amt), np.arange(dilation_amt))
|
||||
center = dilation_amt // 2
|
||||
dilation_kernel = ((x - center)**2 + (y - center)**2 <= center**2).astype(np.uint8)
|
||||
|
|
@ -110,3 +110,95 @@ def create_mask_batch_output(
|
|||
if save_image_with_mask:
|
||||
output_blend = Image.fromarray(blended_image)
|
||||
output_blend.save(os.path.join(dest_dir, f"{filename}_{idx}_blend{ext}"))
|
||||
|
||||
|
||||
def blend_image_and_seg(image: np.ndarray, seg: np.ndarray, alpha=0.5) -> Image.Image:
|
||||
image_blend = image * (1 - alpha) + np.array(seg) * alpha
|
||||
return Image.fromarray(image_blend.astype(np.uint8))
|
||||
|
||||
|
||||
def install_pycocotools():
|
||||
# install pycocotools if needed
|
||||
from sam_utils.logger import logger
|
||||
try:
|
||||
import pycocotools.mask as maskUtils
|
||||
except:
|
||||
logger.warn("pycocotools not found, will try installing C++ based pycocotools")
|
||||
try:
|
||||
from launch import run_pip
|
||||
run_pip(f"install pycocotools", f"AutoSAM requirement: pycocotools")
|
||||
import pycocotools.mask as maskUtils
|
||||
except:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
import sys
|
||||
if sys.platform == "win32":
|
||||
logger.warn("Unable to install pycocotools, will try installing pycocotools-windows")
|
||||
try:
|
||||
run_pip("install pycocotools-windows", "AutoSAM requirement: pycocotools-windows")
|
||||
import pycocotools.mask as maskUtils
|
||||
except:
|
||||
error_msg = "Unable to install pycocotools-windows"
|
||||
logger.error(error_msg)
|
||||
traceback.print_exc()
|
||||
raise RuntimeError(error_msg)
|
||||
else:
|
||||
error_msg = "Unable to install pycocotools"
|
||||
logger.error(error_msg)
|
||||
traceback.print_exc()
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def install_goundingdino() -> bool:
|
||||
"""Automatically install GroundingDINO.
|
||||
|
||||
Returns:
|
||||
bool: False if use local GroundingDINO, True if use pip installed GroundingDINO.
|
||||
"""
|
||||
from sam_utils.logger import logger
|
||||
from modules import shared
|
||||
dino_install_issue_text = "Please permanently switch to local GroundingDINO on Settings/Segment Anything or submit an issue to https://github.com/IDEA-Research/Grounded-Segment-Anything/issues."
|
||||
if shared.opts.data.get("sam_use_local_groundingdino", False):
|
||||
logger.info("Using local groundingdino.")
|
||||
return False
|
||||
|
||||
def verify_dll(install_local=True):
|
||||
try:
|
||||
from groundingdino import _C
|
||||
logger.info("GroundingDINO dynamic library have been successfully built.")
|
||||
return True
|
||||
except Exception:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
def run_pip_uninstall(command, desc=None):
|
||||
from launch import python, run
|
||||
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
||||
return run(f'"{python}" -m pip uninstall -y {command}', desc=f"Uninstalling {desc}", errdesc=f"Couldn't uninstall {desc}", live=default_command_live)
|
||||
if install_local:
|
||||
logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and fall back to local GroundingDINO this time. {dino_install_issue_text}")
|
||||
run_pip_uninstall(
|
||||
f"groundingdino",
|
||||
f"sd-webui-segment-anything requirement: groundingdino")
|
||||
else:
|
||||
logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and re-try installing from GitHub source code. {dino_install_issue_text}")
|
||||
run_pip_uninstall(
|
||||
f"uninstall groundingdino",
|
||||
f"sd-webui-segment-anything requirement: groundingdino")
|
||||
return False
|
||||
|
||||
import launch
|
||||
if launch.is_installed("groundingdino"):
|
||||
logger.info("Found GroundingDINO in pip. Verifying if dynamic library build success.")
|
||||
if verify_dll(install_local=False):
|
||||
return True
|
||||
try:
|
||||
launch.run_pip(
|
||||
f"install git+https://github.com/IDEA-Research/GroundingDINO",
|
||||
f"sd-webui-segment-anything requirement: groundingdino")
|
||||
logger.info("GroundingDINO install success. Verifying if dynamic library build success.")
|
||||
return verify_dll()
|
||||
except Exception:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.warn(f"GroundingDINO install failed. Will fall back to local groundingdino this time. {dino_install_issue_text}")
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -1,183 +0,0 @@
|
|||
from typing import List, Tuple
|
||||
import os
|
||||
import glob
|
||||
import copy
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from sam_hq.automatic import SamAutomaticMaskGeneratorHQ
|
||||
from scripts.sam_log import logger
|
||||
|
||||
|
||||
class AutoSAM:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.auto_sam: SamAutomaticMaskGeneratorHQ = None
|
||||
|
||||
|
||||
def blend_image_and_seg(self, image: np.ndarray, seg: np.ndarray, alpha=0.5) -> Image.Image:
|
||||
image_blend = image * (1 - alpha) + np.array(seg) * alpha
|
||||
return Image.fromarray(image_blend.astype(np.uint8))
|
||||
|
||||
|
||||
def strengthen_semmantic_seg(self, class_ids: np.ndarray, img: np.ndarray) -> np.ndarray:
|
||||
logger.info("AutoSAM strengthening semantic segmentation")
|
||||
import pycocotools.mask as maskUtils
|
||||
semantc_mask = copy.deepcopy(class_ids)
|
||||
annotations = self.auto_sam.generate(img)
|
||||
annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
|
||||
logger.info(f"AutoSAM generated {len(annotations)} masks")
|
||||
for ann in annotations:
|
||||
valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool()
|
||||
propose_classes_ids = torch.tensor(class_ids[valid_mask])
|
||||
num_class_proposals = len(torch.unique(propose_classes_ids))
|
||||
if num_class_proposals == 1:
|
||||
semantc_mask[valid_mask] = propose_classes_ids[0].numpy()
|
||||
continue
|
||||
top_1_propose_class_ids = torch.bincount(propose_classes_ids.flatten()).topk(1).indices
|
||||
semantc_mask[valid_mask] = top_1_propose_class_ids.numpy()
|
||||
logger.info("AutoSAM strengthening process end")
|
||||
return semantc_mask
|
||||
|
||||
|
||||
def random_segmentation(self, img: Image.Image) -> Tuple[List[Image.Image], str]:
|
||||
logger.info("AutoSAM generating random segmentation for EditAnything")
|
||||
img_np = np.array(img.convert("RGB"))
|
||||
annotations = self.auto_sam.generate(img_np)
|
||||
logger.info(f"AutoSAM generated {len(annotations)} masks")
|
||||
H, W, _ = img_np.shape
|
||||
color_map = np.zeros((H, W, 3), dtype=np.uint8)
|
||||
detected_map_tmp = np.zeros((H, W), dtype=np.uint16)
|
||||
for idx, annotation in enumerate(annotations):
|
||||
current_seg = annotation['segmentation']
|
||||
color_map[current_seg] = np.random.randint(0, 255, (3))
|
||||
detected_map_tmp[current_seg] = idx + 1
|
||||
detected_map = np.zeros((detected_map_tmp.shape[0], detected_map_tmp.shape[1], 3))
|
||||
detected_map[:, :, 0] = detected_map_tmp % 256
|
||||
detected_map[:, :, 1] = detected_map_tmp // 256
|
||||
try:
|
||||
from scripts.processor import HWC3
|
||||
except:
|
||||
return [], "ControlNet extension not found."
|
||||
detected_map = HWC3(detected_map.astype(np.uint8))
|
||||
logger.info("AutoSAM generation process end")
|
||||
return [self.blend_image_and_seg(img_np, color_map), Image.fromarray(color_map), Image.fromarray(detected_map)], \
|
||||
"Random segmentation done. Left above (0) is blended image, right above (1) is random segmentation, left below (2) is Edit-Anything control input."
|
||||
|
||||
|
||||
def layer_single_image(self, layout_input_image: Image.Image, layout_output_path: str) -> None:
|
||||
img_np = np.array(layout_input_image.convert("RGB"))
|
||||
annotations = self.auto_sam.generate(img_np)
|
||||
logger.info(f"AutoSAM generated {len(annotations)} annotations")
|
||||
annotations = sorted(annotations, key=lambda x: x['area'])
|
||||
for idx, annotation in enumerate(annotations):
|
||||
img_tmp = np.zeros((img_np.shape[0], img_np.shape[1], 3))
|
||||
img_tmp[annotation['segmentation']] = img_np[annotation['segmentation']]
|
||||
img_np[annotation['segmentation']] = np.array([0, 0, 0])
|
||||
img_tmp = Image.fromarray(img_tmp.astype(np.uint8))
|
||||
img_tmp.save(os.path.join(layout_output_path, f"{idx}.png"))
|
||||
img_np = Image.fromarray(img_np.astype(np.uint8))
|
||||
img_np.save(os.path.join(layout_output_path, "leftover.png"))
|
||||
|
||||
|
||||
def image_layer(self, layout_input_image_or_path, layout_output_path: str) -> str:
|
||||
if isinstance(layout_input_image_or_path, str):
|
||||
logger.info("Image layer division batch processing")
|
||||
all_files = glob.glob(os.path.join(layout_input_image_or_path, "*"))
|
||||
for image_index, input_image_file in enumerate(all_files):
|
||||
logger.info(f"Processing {image_index}/{len(all_files)} {input_image_file}")
|
||||
try:
|
||||
input_image = Image.open(input_image_file)
|
||||
output_directory = os.path.join(layout_output_path, os.path.splitext(os.path.basename(input_image_file))[0])
|
||||
from pathlib import Path
|
||||
Path(output_directory).mkdir(exist_ok=True)
|
||||
except:
|
||||
logger.warn(f"File {input_image_file} not image, skipped.")
|
||||
continue
|
||||
self.layer_single_image(input_image, output_directory)
|
||||
else:
|
||||
self.layer_single_image(layout_input_image_or_path, layout_output_path)
|
||||
return "Done"
|
||||
|
||||
|
||||
def semantic_segmentation(self, input_image: Image.Image, annotator_name: str, processor_res: int,
|
||||
use_pixel_perfect: bool, resize_mode: int, target_W: int, target_H: int) -> Tuple[List[Image.Image], str]:
|
||||
if input_image is None:
|
||||
return [], "No input image."
|
||||
if "seg" in annotator_name:
|
||||
try:
|
||||
from scripts.processor import uniformer, oneformer_coco, oneformer_ade20k
|
||||
from scripts.external_code import pixel_perfect_resolution
|
||||
oneformers = {
|
||||
"ade20k": oneformer_ade20k,
|
||||
"coco": oneformer_coco
|
||||
}
|
||||
except:
|
||||
return [], "ControlNet extension not found."
|
||||
input_image_np = np.array(input_image)
|
||||
if use_pixel_perfect:
|
||||
processor_res = pixel_perfect_resolution(input_image_np, resize_mode, target_W, target_H)
|
||||
logger.info("Generating semantic segmentation without SAM")
|
||||
if annotator_name == "seg_ufade20k":
|
||||
original_semantic = uniformer(input_image_np, processor_res)
|
||||
else:
|
||||
dataset = annotator_name.split('_')[-1][2:]
|
||||
original_semantic = oneformers[dataset](input_image_np, processor_res)
|
||||
logger.info("Generating semantic segmentation with SAM")
|
||||
sam_semantic = self.strengthen_semmantic_seg(np.array(original_semantic), input_image_np)
|
||||
output_gallery = [original_semantic, sam_semantic, self.blend_image_and_seg(input_image, original_semantic), self.blend_image_and_seg(input_image, sam_semantic)]
|
||||
return output_gallery, f"Done. Left is segmentation before SAM, right is segmentation after SAM."
|
||||
else:
|
||||
return self.random_segmentation(input_image)
|
||||
|
||||
|
||||
def categorical_mask_image(self, crop_processor: str, crop_processor_res: int, crop_category_input: List[int], crop_input_image: Image.Image,
|
||||
crop_pixel_perfect: bool, crop_resize_mode: int, target_W: int, target_H: int) -> Tuple[np.ndarray, Image.Image]:
|
||||
if crop_input_image is None:
|
||||
return "No input image."
|
||||
try:
|
||||
from scripts.processor import uniformer, oneformer_coco, oneformer_ade20k
|
||||
from scripts.external_code import pixel_perfect_resolution
|
||||
oneformers = {
|
||||
"ade20k": oneformer_ade20k,
|
||||
"coco": oneformer_coco
|
||||
}
|
||||
except:
|
||||
return [], "ControlNet extension not found."
|
||||
filter_classes = crop_category_input
|
||||
if len(filter_classes) == 0:
|
||||
return "No class selected."
|
||||
try:
|
||||
filter_classes = [int(i) for i in filter_classes]
|
||||
except:
|
||||
return "Illegal class id. You may have input some string."
|
||||
crop_input_image_np = np.array(crop_input_image)
|
||||
if crop_pixel_perfect:
|
||||
crop_processor_res = pixel_perfect_resolution(crop_input_image_np, crop_resize_mode, target_W, target_H)
|
||||
crop_input_image_copy = copy.deepcopy(crop_input_image)
|
||||
logger.info(f"Generating categories with processor {crop_processor}")
|
||||
if crop_processor == "seg_ufade20k":
|
||||
original_semantic = uniformer(crop_input_image_np, crop_processor_res)
|
||||
else:
|
||||
dataset = crop_processor.split('_')[-1][2:]
|
||||
original_semantic = oneformers[dataset](crop_input_image_np, crop_processor_res)
|
||||
sam_semantic = self.strengthen_semmantic_seg(np.array(original_semantic), crop_input_image_np)
|
||||
mask = np.zeros(sam_semantic.shape, dtype=np.bool_)
|
||||
from scripts.sam_config import SEMANTIC_CATEGORIES
|
||||
for i in filter_classes:
|
||||
mask[np.equal(sam_semantic, SEMANTIC_CATEGORIES[crop_processor][i])] = True
|
||||
return mask, crop_input_image_copy
|
||||
|
||||
|
||||
def register_auto_sam(self, sam,
|
||||
auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
|
||||
auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
|
||||
auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
|
||||
auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, auto_sam_output_mode):
|
||||
self.auto_sam = SamAutomaticMaskGeneratorHQ(
|
||||
sam, auto_sam_points_per_side, auto_sam_points_per_batch,
|
||||
auto_sam_pred_iou_thresh, auto_sam_stability_score_thresh,
|
||||
auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
|
||||
auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
|
||||
auto_sam_crop_n_points_downscale_factor, None,
|
||||
auto_sam_min_mask_region_area, auto_sam_output_mode)
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .model import FastSAM
|
||||
from .predict import FastSAMPredictor
|
||||
from .prompt import FastSAMPrompt
|
||||
# from .val import FastSAMValidator
|
||||
from .decoder import FastSAMDecoder
|
||||
|
||||
__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder'
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
from .model import FastSAM
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import clip
|
||||
from typing import Optional, List, Tuple, Union
|
||||
|
||||
|
||||
class FastSAMDecoder:
|
||||
def __init__(
|
||||
self,
|
||||
model: FastSAM,
|
||||
device: str='cpu',
|
||||
conf: float=0.4,
|
||||
iou: float=0.9,
|
||||
imgsz: int=1024,
|
||||
retina_masks: bool=True,
|
||||
):
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.retina_masks = retina_masks
|
||||
self.imgsz = imgsz
|
||||
self.conf = conf
|
||||
self.iou = iou
|
||||
self.image = None
|
||||
self.image_embedding = None
|
||||
|
||||
def run_encoder(self, image):
|
||||
if isinstance(image,str):
|
||||
image = np.array(Image.open(image))
|
||||
self.image = image
|
||||
image_embedding = self.model(
|
||||
self.image,
|
||||
device=self.device,
|
||||
retina_masks=self.retina_masks,
|
||||
imgsz=self.imgsz,
|
||||
conf=self.conf,
|
||||
iou=self.iou
|
||||
)
|
||||
return image_embedding[0].numpy()
|
||||
|
||||
def run_decoder(
|
||||
self,
|
||||
image_embedding,
|
||||
point_prompt: Optional[np.ndarray]=None,
|
||||
point_label: Optional[np.ndarray]=None,
|
||||
box_prompt: Optional[np.ndarray]=None,
|
||||
text_prompt: Optional[str]=None,
|
||||
)->np.ndarray:
|
||||
self.image_embedding = image_embedding
|
||||
if point_prompt is not None:
|
||||
ann = self.point_prompt(points=point_prompt, pointlabel=point_label)
|
||||
return ann
|
||||
elif box_prompt is not None:
|
||||
ann = self.box_prompt(bbox=box_prompt)
|
||||
return ann
|
||||
elif text_prompt is not None:
|
||||
ann = self.text_prompt(text=text_prompt)
|
||||
return ann
|
||||
else:
|
||||
return None
|
||||
|
||||
def box_prompt(self, bbox):
|
||||
assert (bbox[2] != 0 and bbox[3] != 0)
|
||||
masks = self.image_embedding.masks.data
|
||||
target_height = self.ori_img.shape[0]
|
||||
target_width = self.ori_img.shape[1]
|
||||
h = masks.shape[1]
|
||||
w = masks.shape[2]
|
||||
if h != target_height or w != target_width:
|
||||
bbox = [
|
||||
int(bbox[0] * w / target_width),
|
||||
int(bbox[1] * h / target_height),
|
||||
int(bbox[2] * w / target_width),
|
||||
int(bbox[3] * h / target_height), ]
|
||||
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
|
||||
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
|
||||
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
|
||||
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
|
||||
|
||||
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
|
||||
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
|
||||
|
||||
masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2))
|
||||
orig_masks_area = np.sum(masks, axis=(1, 2))
|
||||
|
||||
union = bbox_area + orig_masks_area - masks_area
|
||||
IoUs = masks_area / union
|
||||
max_iou_index = np.argmax(IoUs)
|
||||
|
||||
return np.array([masks[max_iou_index].cpu().numpy()])
|
||||
|
||||
def point_prompt(self, points, pointlabel): # numpy
|
||||
|
||||
masks = self._format_results(self.results[0], 0)
|
||||
target_height = self.ori_img.shape[0]
|
||||
target_width = self.ori_img.shape[1]
|
||||
h = masks[0]['segmentation'].shape[0]
|
||||
w = masks[0]['segmentation'].shape[1]
|
||||
if h != target_height or w != target_width:
|
||||
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
|
||||
onemask = np.zeros((h, w))
|
||||
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
|
||||
for i, annotation in enumerate(masks):
|
||||
if type(annotation) == dict:
|
||||
mask = annotation['segmentation']
|
||||
else:
|
||||
mask = annotation
|
||||
for i, point in enumerate(points):
|
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
|
||||
onemask[mask] = 1
|
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
|
||||
onemask[mask] = 0
|
||||
onemask = onemask >= 1
|
||||
return np.array([onemask])
|
||||
|
||||
def _format_results(self, result, filter=0):
|
||||
annotations = []
|
||||
n = len(result.masks.data)
|
||||
for i in range(n):
|
||||
annotation = {}
|
||||
mask = result.masks.data[i] == 1.0
|
||||
|
||||
if np.sum(mask) < filter:
|
||||
continue
|
||||
annotation['id'] = i
|
||||
annotation['segmentation'] = mask
|
||||
annotation['bbox'] = result.boxes.data[i]
|
||||
annotation['score'] = result.boxes.conf[i]
|
||||
annotation['area'] = annotation['segmentation'].sum()
|
||||
annotations.append(annotation)
|
||||
return annotations
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
YOLO-NAS model interface.
|
||||
|
||||
Usage - Predict:
|
||||
from ultralytics import FastSAM
|
||||
|
||||
model = FastSAM('last.pt')
|
||||
results = model.predict('ultralytics/assets/bus.jpg')
|
||||
"""
|
||||
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
|
||||
from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
|
||||
from .predict import FastSAMPredictor
|
||||
|
||||
|
||||
class FastSAM(YOLO):
|
||||
|
||||
@smart_inference_mode()
|
||||
def predict(self, source=None, stream=False, **kwargs):
|
||||
"""
|
||||
Perform prediction using the YOLO model.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
||||
Accepts all source types accepted by the YOLO model.
|
||||
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
||||
**kwargs : Additional keyword arguments passed to the predictor.
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
||||
"""
|
||||
if source is None:
|
||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||
overrides = self.overrides.copy()
|
||||
overrides['conf'] = 0.25
|
||||
overrides.update(kwargs) # prefer kwargs
|
||||
overrides['mode'] = kwargs.get('mode', 'predict')
|
||||
assert overrides['mode'] in ['track', 'predict']
|
||||
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
|
||||
self.predictor = FastSAMPredictor(overrides=overrides)
|
||||
self.predictor.setup_model(model=self.model, verbose=False)
|
||||
|
||||
return self.predictor(source, stream=stream)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""Function trains models but raises an error as FastSAM models do not support training."""
|
||||
raise NotImplementedError("FastSAM models don't support training")
|
||||
|
||||
def val(self, **kwargs):
|
||||
"""Run validation given dataset."""
|
||||
overrides = dict(task='segment', mode='val')
|
||||
overrides.update(kwargs) # prefer kwargs
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
||||
validator = FastSAM(args=args)
|
||||
validator(model=self.model)
|
||||
self.metrics = validator.metrics
|
||||
return validator.metrics
|
||||
|
||||
@smart_inference_mode()
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
Export model.
|
||||
|
||||
Args:
|
||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
overrides = dict(task='detect')
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'export'
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.task = self.task
|
||||
if args.imgsz == DEFAULT_CFG.imgsz:
|
||||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
if args.batch == DEFAULT_CFG.batch:
|
||||
args.batch = 1 # default to 1 if not modified
|
||||
return Exporter(overrides=args)(model=self.model)
|
||||
|
||||
def info(self, detailed=False, verbose=True):
|
||||
"""
|
||||
Logs model info.
|
||||
|
||||
Args:
|
||||
detailed (bool): Show detailed information about model.
|
||||
verbose (bool): Controls verbosity.
|
||||
"""
|
||||
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
"""Raises error if object has no requested attribute."""
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ops
|
||||
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
||||
from .utils import bbox_iou
|
||||
|
||||
class FastSAMPredictor(DetectionPredictor):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = 'segment'
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""TODO: filter by classes."""
|
||||
p = ops.non_max_suppression(preds[0],
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
nc=len(self.model.names),
|
||||
classes=self.args.classes)
|
||||
|
||||
full_box = torch.zeros_like(p[0][0])
|
||||
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
|
||||
full_box = full_box.view(1, -1)
|
||||
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
|
||||
if critical_iou_index.numel() != 0:
|
||||
full_box[0][4] = p[0][critical_iou_index][:,4]
|
||||
full_box[0][6:] = p[0][critical_iou_index][:,6:]
|
||||
p[0][critical_iou_index] = full_box
|
||||
|
||||
results = []
|
||||
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
|
||||
for i, pred in enumerate(p):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
path = self.batch[0]
|
||||
img_path = path[i] if isinstance(path, list) else path
|
||||
if not len(pred): # save empty boxes
|
||||
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
|
||||
continue
|
||||
if self.args.retina_masks:
|
||||
if not isinstance(orig_imgs, torch.Tensor):
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
|
||||
else:
|
||||
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
|
||||
if not isinstance(orig_imgs, torch.Tensor):
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
results.append(
|
||||
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
|
||||
return results
|
||||
|
|
@ -0,0 +1,417 @@
|
|||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
try:
|
||||
import clip # for linear_assignment
|
||||
|
||||
except (ImportError, AssertionError, AttributeError):
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
|
||||
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
|
||||
import clip
|
||||
|
||||
|
||||
class FastSAMPrompt:
|
||||
|
||||
def __init__(self, img_path, results, device='cuda') -> None:
|
||||
# self.img_path = img_path
|
||||
self.device = device
|
||||
self.results = results
|
||||
self.img_path = img_path
|
||||
self.ori_img = cv2.imread(img_path)
|
||||
|
||||
def _segment_image(self, image, bbox):
|
||||
image_array = np.array(image)
|
||||
segmented_image_array = np.zeros_like(image_array)
|
||||
x1, y1, x2, y2 = bbox
|
||||
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
|
||||
segmented_image = Image.fromarray(segmented_image_array)
|
||||
black_image = Image.new('RGB', image.size, (255, 255, 255))
|
||||
# transparency_mask = np.zeros_like((), dtype=np.uint8)
|
||||
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
|
||||
transparency_mask[y1:y2, x1:x2] = 255
|
||||
transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
|
||||
black_image.paste(segmented_image, mask=transparency_mask_image)
|
||||
return black_image
|
||||
|
||||
def _format_results(self, result, filter=0):
|
||||
annotations = []
|
||||
n = len(result.masks.data)
|
||||
for i in range(n):
|
||||
annotation = {}
|
||||
mask = result.masks.data[i] == 1.0
|
||||
|
||||
if torch.sum(mask) < filter:
|
||||
continue
|
||||
annotation['id'] = i
|
||||
annotation['segmentation'] = mask.cpu().numpy()
|
||||
annotation['bbox'] = result.boxes.data[i]
|
||||
annotation['score'] = result.boxes.conf[i]
|
||||
annotation['area'] = annotation['segmentation'].sum()
|
||||
annotations.append(annotation)
|
||||
return annotations
|
||||
|
||||
def filter_masks(annotations): # filte the overlap mask
|
||||
annotations.sort(key=lambda x: x['area'], reverse=True)
|
||||
to_remove = set()
|
||||
for i in range(0, len(annotations)):
|
||||
a = annotations[i]
|
||||
for j in range(i + 1, len(annotations)):
|
||||
b = annotations[j]
|
||||
if i != j and j not in to_remove:
|
||||
# check if
|
||||
if b['area'] < a['area']:
|
||||
if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
|
||||
to_remove.add(j)
|
||||
|
||||
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
|
||||
|
||||
def _get_bbox_from_mask(self, mask):
|
||||
mask = mask.astype(np.uint8)
|
||||
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
x1, y1, w, h = cv2.boundingRect(contours[0])
|
||||
x2, y2 = x1 + w, y1 + h
|
||||
if len(contours) > 1:
|
||||
for b in contours:
|
||||
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
|
||||
# Merge multiple bounding boxes into one.
|
||||
x1 = min(x1, x_t)
|
||||
y1 = min(y1, y_t)
|
||||
x2 = max(x2, x_t + w_t)
|
||||
y2 = max(y2, y_t + h_t)
|
||||
h = y2 - y1
|
||||
w = x2 - x1
|
||||
return [x1, y1, x2, y2]
|
||||
|
||||
def plot(self,
|
||||
annotations,
|
||||
output,
|
||||
bboxes=None,
|
||||
points=None,
|
||||
point_label=None,
|
||||
mask_random_color=True,
|
||||
better_quality=True,
|
||||
retina=False,
|
||||
withContours=True):
|
||||
if isinstance(annotations[0], dict):
|
||||
annotations = [annotation['segmentation'] for annotation in annotations]
|
||||
result_name = os.path.basename(self.img_path)
|
||||
image = self.ori_img
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
original_h = image.shape[0]
|
||||
original_w = image.shape[1]
|
||||
if sys.platform == "darwin":
|
||||
plt.switch_backend("TkAgg")
|
||||
plt.figure(figsize=(original_w / 100, original_h / 100))
|
||||
# Add subplot with no margin.
|
||||
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
||||
plt.margins(0, 0)
|
||||
plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
||||
plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
||||
|
||||
plt.imshow(image)
|
||||
if better_quality:
|
||||
if isinstance(annotations[0], torch.Tensor):
|
||||
annotations = np.array(annotations.cpu())
|
||||
for i, mask in enumerate(annotations):
|
||||
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
|
||||
annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
|
||||
if self.device == 'cpu':
|
||||
annotations = np.array(annotations)
|
||||
self.fast_show_mask(
|
||||
annotations,
|
||||
plt.gca(),
|
||||
random_color=mask_random_color,
|
||||
bboxes=bboxes,
|
||||
points=points,
|
||||
pointlabel=point_label,
|
||||
retinamask=retina,
|
||||
target_height=original_h,
|
||||
target_width=original_w,
|
||||
)
|
||||
else:
|
||||
if isinstance(annotations[0], np.ndarray):
|
||||
annotations = torch.from_numpy(annotations)
|
||||
self.fast_show_mask_gpu(
|
||||
annotations,
|
||||
plt.gca(),
|
||||
random_color=mask_random_color,
|
||||
bboxes=bboxes,
|
||||
points=points,
|
||||
pointlabel=point_label,
|
||||
retinamask=retina,
|
||||
target_height=original_h,
|
||||
target_width=original_w,
|
||||
)
|
||||
if isinstance(annotations, torch.Tensor):
|
||||
annotations = annotations.cpu().numpy()
|
||||
if withContours:
|
||||
contour_all = []
|
||||
temp = np.zeros((original_h, original_w, 1))
|
||||
for i, mask in enumerate(annotations):
|
||||
if type(mask) == dict:
|
||||
mask = mask['segmentation']
|
||||
annotation = mask.astype(np.uint8)
|
||||
if not retina:
|
||||
annotation = cv2.resize(
|
||||
annotation,
|
||||
(original_w, original_h),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
for contour in contours:
|
||||
contour_all.append(contour)
|
||||
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
|
||||
color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
|
||||
contour_mask = temp / 255 * color.reshape(1, 1, -1)
|
||||
plt.imshow(contour_mask)
|
||||
|
||||
save_path = output
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
plt.axis('off')
|
||||
fig = plt.gcf()
|
||||
plt.draw()
|
||||
|
||||
try:
|
||||
buf = fig.canvas.tostring_rgb()
|
||||
except AttributeError:
|
||||
fig.canvas.draw()
|
||||
buf = fig.canvas.tostring_rgb()
|
||||
cols, rows = fig.canvas.get_width_height()
|
||||
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
|
||||
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
|
||||
|
||||
# CPU post process
|
||||
def fast_show_mask(
|
||||
self,
|
||||
annotation,
|
||||
ax,
|
||||
random_color=False,
|
||||
bboxes=None,
|
||||
points=None,
|
||||
pointlabel=None,
|
||||
retinamask=True,
|
||||
target_height=960,
|
||||
target_width=960,
|
||||
):
|
||||
msak_sum = annotation.shape[0]
|
||||
height = annotation.shape[1]
|
||||
weight = annotation.shape[2]
|
||||
#Sort annotations based on area.
|
||||
areas = np.sum(annotation, axis=(1, 2))
|
||||
sorted_indices = np.argsort(areas)
|
||||
annotation = annotation[sorted_indices]
|
||||
|
||||
index = (annotation != 0).argmax(axis=0)
|
||||
if random_color:
|
||||
color = np.random.random((msak_sum, 1, 1, 3))
|
||||
else:
|
||||
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
|
||||
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
|
||||
visual = np.concatenate([color, transparency], axis=-1)
|
||||
mask_image = np.expand_dims(annotation, -1) * visual
|
||||
|
||||
show = np.zeros((height, weight, 4))
|
||||
h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
|
||||
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
|
||||
# Use vectorized indexing to update the values of 'show'.
|
||||
show[h_indices, w_indices, :] = mask_image[indices]
|
||||
if bboxes is not None:
|
||||
for bbox in bboxes:
|
||||
x1, y1, x2, y2 = bbox
|
||||
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
|
||||
# draw point
|
||||
if points is not None:
|
||||
plt.scatter(
|
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
|
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
|
||||
s=20,
|
||||
c='y',
|
||||
)
|
||||
plt.scatter(
|
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
|
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
|
||||
s=20,
|
||||
c='m',
|
||||
)
|
||||
|
||||
if not retinamask:
|
||||
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
|
||||
ax.imshow(show)
|
||||
|
||||
def fast_show_mask_gpu(
|
||||
self,
|
||||
annotation,
|
||||
ax,
|
||||
random_color=False,
|
||||
bboxes=None,
|
||||
points=None,
|
||||
pointlabel=None,
|
||||
retinamask=True,
|
||||
target_height=960,
|
||||
target_width=960,
|
||||
):
|
||||
msak_sum = annotation.shape[0]
|
||||
height = annotation.shape[1]
|
||||
weight = annotation.shape[2]
|
||||
areas = torch.sum(annotation, dim=(1, 2))
|
||||
sorted_indices = torch.argsort(areas, descending=False)
|
||||
annotation = annotation[sorted_indices]
|
||||
# Find the index of the first non-zero value at each position.
|
||||
index = (annotation != 0).to(torch.long).argmax(dim=0)
|
||||
if random_color:
|
||||
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
|
||||
else:
|
||||
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
|
||||
30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
|
||||
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
|
||||
visual = torch.cat([color, transparency], dim=-1)
|
||||
mask_image = torch.unsqueeze(annotation, -1) * visual
|
||||
# Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form.
|
||||
show = torch.zeros((height, weight, 4)).to(annotation.device)
|
||||
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij')
|
||||
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
|
||||
# Use vectorized indexing to update the values of 'show'.
|
||||
show[h_indices, w_indices, :] = mask_image[indices]
|
||||
show_cpu = show.cpu().numpy()
|
||||
if bboxes is not None:
|
||||
for bbox in bboxes:
|
||||
x1, y1, x2, y2 = bbox
|
||||
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
|
||||
# draw point
|
||||
if points is not None:
|
||||
plt.scatter(
|
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
|
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
|
||||
s=20,
|
||||
c='y',
|
||||
)
|
||||
plt.scatter(
|
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
|
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
|
||||
s=20,
|
||||
c='m',
|
||||
)
|
||||
if not retinamask:
|
||||
show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
|
||||
ax.imshow(show_cpu)
|
||||
|
||||
# clip
|
||||
@torch.no_grad()
|
||||
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
|
||||
preprocessed_images = [preprocess(image).to(device) for image in elements]
|
||||
tokenized_text = clip.tokenize([search_text]).to(device)
|
||||
stacked_images = torch.stack(preprocessed_images)
|
||||
image_features = model.encode_image(stacked_images)
|
||||
text_features = model.encode_text(tokenized_text)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
probs = 100.0 * image_features @ text_features.T
|
||||
return probs[:, 0].softmax(dim=0)
|
||||
|
||||
def _crop_image(self, format_results):
|
||||
|
||||
image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB))
|
||||
ori_w, ori_h = image.size
|
||||
annotations = format_results
|
||||
mask_h, mask_w = annotations[0]['segmentation'].shape
|
||||
if ori_w != mask_w or ori_h != mask_h:
|
||||
image = image.resize((mask_w, mask_h))
|
||||
cropped_boxes = []
|
||||
cropped_images = []
|
||||
not_crop = []
|
||||
filter_id = []
|
||||
# annotations, _ = filter_masks(annotations)
|
||||
# filter_id = list(_)
|
||||
for _, mask in enumerate(annotations):
|
||||
if np.sum(mask['segmentation']) <= 100:
|
||||
filter_id.append(_)
|
||||
continue
|
||||
bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox
|
||||
cropped_boxes.append(self._segment_image(image, bbox))
|
||||
# cropped_boxes.append(segment_image(image,mask["segmentation"]))
|
||||
cropped_images.append(bbox) # Save the bounding box of the cropped image.
|
||||
|
||||
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
|
||||
|
||||
def box_prompt(self, bbox=None, bboxes=None):
|
||||
|
||||
assert bbox or bboxes
|
||||
if bboxes is None:
|
||||
bboxes = [bbox]
|
||||
max_iou_index = []
|
||||
for bbox in bboxes:
|
||||
assert (bbox[2] != 0 and bbox[3] != 0)
|
||||
masks = self.results[0].masks.data
|
||||
target_height = self.ori_img.shape[0]
|
||||
target_width = self.ori_img.shape[1]
|
||||
h = masks.shape[1]
|
||||
w = masks.shape[2]
|
||||
if h != target_height or w != target_width:
|
||||
bbox = [
|
||||
int(bbox[0] * w / target_width),
|
||||
int(bbox[1] * h / target_height),
|
||||
int(bbox[2] * w / target_width),
|
||||
int(bbox[3] * h / target_height), ]
|
||||
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
|
||||
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
|
||||
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
|
||||
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
|
||||
|
||||
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
|
||||
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
|
||||
|
||||
masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
|
||||
orig_masks_area = torch.sum(masks, dim=(1, 2))
|
||||
|
||||
union = bbox_area + orig_masks_area - masks_area
|
||||
IoUs = masks_area / union
|
||||
max_iou_index.append(int(torch.argmax(IoUs)))
|
||||
max_iou_index = list(set(max_iou_index))
|
||||
return np.array(masks[max_iou_index].cpu().numpy())
|
||||
|
||||
def point_prompt(self, points, pointlabel): # numpy
|
||||
|
||||
masks = self._format_results(self.results[0], 0)
|
||||
target_height = self.ori_img.shape[0]
|
||||
target_width = self.ori_img.shape[1]
|
||||
h = masks[0]['segmentation'].shape[0]
|
||||
w = masks[0]['segmentation'].shape[1]
|
||||
if h != target_height or w != target_width:
|
||||
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
|
||||
onemask = np.zeros((h, w))
|
||||
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
|
||||
for i, annotation in enumerate(masks):
|
||||
if type(annotation) == dict:
|
||||
mask = annotation['segmentation']
|
||||
else:
|
||||
mask = annotation
|
||||
for i, point in enumerate(points):
|
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
|
||||
onemask[mask] = 1
|
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
|
||||
onemask[mask] = 0
|
||||
onemask = onemask >= 1
|
||||
return np.array([onemask])
|
||||
|
||||
def text_prompt(self, text):
|
||||
format_results = self._format_results(self.results[0], 0)
|
||||
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
|
||||
clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
|
||||
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
|
||||
max_idx = scores.argsort()
|
||||
max_idx = max_idx[-1]
|
||||
max_idx += sum(np.array(filter_id) <= int(max_idx))
|
||||
return np.array([annotations[max_idx]['segmentation']])
|
||||
|
||||
def everything_prompt(self):
|
||||
return self.results[0].masks.data
|
||||
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
import torch
|
||||
|
||||
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
|
||||
'''Adjust bounding boxes to stick to image border if they are within a certain threshold.
|
||||
Args:
|
||||
boxes: (n, 4)
|
||||
image_shape: (height, width)
|
||||
threshold: pixel threshold
|
||||
Returns:
|
||||
adjusted_boxes: adjusted bounding boxes
|
||||
'''
|
||||
|
||||
# Image dimensions
|
||||
h, w = image_shape
|
||||
|
||||
# Adjust boxes
|
||||
boxes[:, 0] = torch.where(boxes[:, 0] < threshold, 0, boxes[:, 0]) # x1
|
||||
boxes[:, 1] = torch.where(boxes[:, 1] < threshold, 0, boxes[:, 1]) # y1
|
||||
boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, w, boxes[:, 2]) # x2
|
||||
boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, h, boxes[:, 3]) # y2
|
||||
|
||||
return boxes
|
||||
|
||||
def convert_box_xywh_to_xyxy(box):
|
||||
x1 = box[0]
|
||||
y1 = box[1]
|
||||
x2 = box[0] + box[2]
|
||||
y2 = box[1] + box[3]
|
||||
return [x1, y1, x2, y2]
|
||||
|
||||
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
|
||||
'''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
|
||||
Args:
|
||||
box1: (4, )
|
||||
boxes: (n, 4)
|
||||
Returns:
|
||||
high_iou_indices: Indices of boxes with IoU > thres
|
||||
'''
|
||||
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
|
||||
# obtain coordinates for intersections
|
||||
x1 = torch.max(box1[0], boxes[:, 0])
|
||||
y1 = torch.max(box1[1], boxes[:, 1])
|
||||
x2 = torch.min(box1[2], boxes[:, 2])
|
||||
y2 = torch.min(box1[3], boxes[:, 3])
|
||||
|
||||
# compute the area of intersection
|
||||
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
|
||||
|
||||
# compute the area of both individual boxes
|
||||
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
||||
box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
# compute the area of union
|
||||
union = box1_area + box2_area - intersection
|
||||
|
||||
# compute the IoU
|
||||
iou = intersection / union # Should be shape (n, )
|
||||
if raw_output:
|
||||
if iou.numel() == 0:
|
||||
return 0
|
||||
return iou
|
||||
|
||||
# get indices of boxes with IoU > thres
|
||||
high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
|
||||
|
||||
return high_iou_indices
|
||||
|
|
@ -31,7 +31,12 @@ class SamM2M(Module):
|
|||
except:
|
||||
mam_url = "https://huggingface.co/conrevo/SAM4WebUI-Extension-Models/resolve/main/mam.pth"
|
||||
logger.info(f"Loading mam from url: {mam_url} to path: {ckpt_path}, device: {self.m2m_device}")
|
||||
state_dict = torch.hub.load_state_dict_from_url(mam_url, ckpt_path, self.m2m_device)
|
||||
try:
|
||||
state_dict = torch.hub.load_state_dict_from_url(mam_url, ckpt_path, self.m2m_device)
|
||||
except:
|
||||
error_msg = f"Unable to connect to {mam_url}, thus unable to download Matting-Anything model."
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
self.m2m.load_state_dict(state_dict)
|
||||
self.m2m.eval()
|
||||
|
||||
|
|
@ -39,6 +44,7 @@ class SamM2M(Module):
|
|||
def forward(self, features: torch.Tensor, image: torch.Tensor,
|
||||
low_res_masks: torch.Tensor, masks: torch.Tensor,
|
||||
ori_shape: torch.Tensor, pad_shape: torch.Tensor, guidance_mode: str):
|
||||
logger.info("Applying Matting-Anything.")
|
||||
self.m2m.to(self.m2m_device)
|
||||
pred = self.m2m(features, image, low_res_masks)
|
||||
alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred['alpha_os1'], pred['alpha_os4'], pred['alpha_os8']
|
||||
Loading…
Reference in New Issue