Add mobile_sam with controlnet_aux (#3000)
* Add mobile_sam with controlnet_aux for CNXL_Unionpull/3001/head
parent
3ff69b9ea3
commit
0baecb5f09
|
|
@ -0,0 +1,49 @@
|
|||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
|
||||
from modules import devices
|
||||
from annotator.util import load_model
|
||||
from annotator.annotator_path import models_path
|
||||
|
||||
from controlnet_aux import SamDetector
|
||||
from controlnet_aux.segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
||||
|
||||
class SamDetector_Aux(SamDetector):
    """Segment-Anything annotator backed by the lightweight MobileSAM
    (``vit_t``) checkpoint, via ``controlnet_aux``.

    The wrapped model is moved to the device selected by
    ``modules.devices`` and run in eval mode; ``__call__`` returns a
    uint8 segmentation image as a numpy array.
    """

    # On-disk cache directory for the MobileSAM checkpoint.
    model_dir = os.path.join(models_path, "mobile_sam")

    def __init__(self, mask_generator: SamAutomaticMaskGenerator, sam):
        """Wrap an already-constructed SAM model and its mask generator.

        Args:
            mask_generator: Automatic mask generator built around ``sam``.
            sam: SAM model instance; moved to the active device and
                switched to eval mode here.
        """
        super().__init__(mask_generator)
        self.device = devices.device
        self.model = sam.to(self.device).eval()

    @classmethod
    def from_pretrained(cls):
        """Download (if necessary) the MobileSAM checkpoint and build a detector.

        Possible model_type : vit_h, vit_l, vit_b, vit_t
        download weights from https://huggingface.co/dhkim2810/MobileSAM

        Returns:
            SamDetector_Aux: a ready-to-use detector instance.
        """
        remote_url = os.environ.get(
            "CONTROLNET_MOBILE_SAM_MODEL_URL",
            "https://huggingface.co/dhkim2810/MobileSAM/resolve/main/mobile_sam.pt",
        )
        model_path = load_model(
            "mobile_sam.pt", remote_url=remote_url, model_dir=cls.model_dir
        )

        sam = sam_model_registry["vit_t"](checkpoint=model_path)

        # Keep the prepared model in a local variable. The previous code
        # assigned it to ``cls.model`` — a *class* attribute — which leaked
        # shared mutable state across all instances and shadowed the
        # per-instance ``self.model`` set in ``__init__``. nn.Module.to()
        # is in-place, so __init__ repeating it on the same object is cheap.
        sam = sam.to(devices.device).eval()

        mask_generator = SamAutomaticMaskGenerator(sam)

        return cls(mask_generator, sam)

    def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="cv2", **kwargs) -> np.ndarray:
        """Run segmentation on ``input_image`` and return a uint8 ndarray."""
        # Re-assert the device placement: other components (e.g. low-VRAM
        # model management) may have moved the model between calls.
        self.model.to(self.device)
        image = super().__call__(input_image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs)
        return np.array(image).astype(np.uint8)
|
||||
|
|
@ -9,3 +9,4 @@ matplotlib
|
|||
facexlib
|
||||
timm<=0.9.5
|
||||
pydantic<=1.10.17
|
||||
controlnet_aux
|
||||
|
|
@ -6,3 +6,4 @@ from .ip_adapter_auto import *
|
|||
from .normal_dsine import *
|
||||
from .model_free_preprocessors import *
|
||||
from .legacy.legacy_preprocessors import *
|
||||
from .mobile_sam import *
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
from annotator.mobile_sam import SamDetector_Aux
|
||||
from scripts.supported_preprocessor import Preprocessor
|
||||
|
||||
class PreprocessorMobileSam(Preprocessor):
    """ControlNet preprocessor that segments images with MobileSAM.

    The underlying detector is created lazily on first invocation, so
    the checkpoint is only downloaded/loaded when actually needed.
    """

    def __init__(self):
        super().__init__(name="mobile_sam")
        self.tags = ["Segmentation"]
        # Lazily-constructed SamDetector_Aux; None until first call.
        self.model = None

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        **kwargs
    ):
        """Segment ``input_image`` at ``resolution`` and return a cv2 image.

        The slider arguments are part of the common preprocessor
        interface and are not used by this preprocessor.
        """
        if self.model is None:
            # First use: build (and cache) the detector.
            self.model = SamDetector_Aux.from_pretrained()

        return self.model(
            input_image,
            detect_resolution=resolution,
            image_resolution=resolution,
            output_type="cv2",
        )
|
||||
|
||||
# Register the MobileSAM preprocessor with the global registry at import
# time so it becomes selectable in the ControlNet preprocessor list.
Preprocessor.add_supported_preprocessor(PreprocessorMobileSam())
|
||||
Loading…
Reference in New Issue