Add mobile_sam with controlnet_aux (#3000)
* Adds a MobileSAM-based segmentation preprocessor using controlnet_aux, for CNXL_Union (pull/3001/head).
parent 3ff69b9ea3
commit 0baecb5f09
|
|
@ -0,0 +1,49 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from modules import devices
|
||||||
|
from annotator.util import load_model
|
||||||
|
from annotator.annotator_path import models_path
|
||||||
|
|
||||||
|
from controlnet_aux import SamDetector
|
||||||
|
from controlnet_aux.segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
||||||
|
|
||||||
|
class SamDetector_Aux(SamDetector):
    """SAM detector backed by the lightweight MobileSAM (vit_t) checkpoint.

    Wraps controlnet_aux's ``SamDetector`` so that the model lives on the
    device selected by ``modules.devices`` and results are returned as
    uint8 numpy arrays.
    """

    # Local cache directory for the downloaded checkpoint.
    model_dir = os.path.join(models_path, "mobile_sam")

    def __init__(self, mask_generator: SamAutomaticMaskGenerator, sam):
        """
        Args:
            mask_generator: automatic mask generator built around ``sam``.
            sam: the SAM model instance; it is moved to the target device
                and switched to eval mode here.
        """
        super().__init__(mask_generator)
        self.device = devices.device
        self.model = sam.to(self.device).eval()

    @classmethod
    def from_pretrained(cls):
        """Download (if needed) the MobileSAM checkpoint and build a detector.

        Possible model_type : vit_h, vit_l, vit_b, vit_t
        download weights from https://huggingface.co/dhkim2810/MobileSAM
        """
        remote_url = os.environ.get(
            "CONTROLNET_MOBILE_SAM_MODEL_URL",
            "https://huggingface.co/dhkim2810/MobileSAM/resolve/main/mobile_sam.pt",
        )
        model_path = load_model(
            "mobile_sam.pt", remote_url=remote_url, model_dir=cls.model_dir
        )
        sam = sam_model_registry["vit_t"](checkpoint=model_path)
        # Keep the prepared model local; the original assigned it to
        # ``cls.model``, mutating shared class state that __init__'s
        # instance attribute would shadow anyway. ``.to()``/``.eval()``
        # return the module itself, so the generator and __init__ see
        # the same object.
        sam = sam.to(devices.device).eval()
        mask_generator = SamAutomaticMaskGenerator(sam)
        return cls(mask_generator, sam)

    def __call__(self, input_image: Union[np.ndarray, Image.Image] = None, detect_resolution=512, image_resolution=512, output_type="cv2", **kwargs) -> np.ndarray:
        """Run automatic mask generation and return the result as a uint8 array.

        Args:
            input_image: image to segment (numpy array or PIL Image).
            detect_resolution: resolution the detector operates at.
            image_resolution: resolution of the returned image.
            output_type: passed through to SamDetector (default "cv2").

        Returns:
            np.ndarray: segmentation visualization as uint8.
        """
        # Re-assert device placement in case the model was offloaded
        # between calls.
        self.model.to(self.device)
        image = super().__call__(input_image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs)
        return np.array(image).astype(np.uint8)
|
||||||
|
|
@ -9,3 +9,4 @@ matplotlib
|
||||||
facexlib
timm<=0.9.5
pydantic<=1.10.17
controlnet_aux
|
||||||
|
|
@ -5,4 +5,5 @@ from .lama_inpaint import *
|
||||||
from .ip_adapter_auto import *
|
from .ip_adapter_auto import *
|
||||||
from .normal_dsine import *
|
from .normal_dsine import *
|
||||||
from .model_free_preprocessors import *
|
from .model_free_preprocessors import *
|
||||||
from .legacy.legacy_preprocessors import *
|
from .legacy.legacy_preprocessors import *
|
||||||
|
from .mobile_sam import *
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
from annotator.mobile_sam import SamDetector_Aux
|
||||||
|
from scripts.supported_preprocessor import Preprocessor
|
||||||
|
|
||||||
|
class PreprocessorMobileSam(Preprocessor):
    """ControlNet preprocessor that segments images with MobileSAM."""

    def __init__(self):
        super().__init__(name="mobile_sam")
        self.tags = ["Segmentation"]
        # Created lazily on first use so importing this module does not
        # trigger a checkpoint download.
        self.model = None

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        **kwargs
    ):
        """Segment ``input_image`` and return a cv2-style uint8 array.

        The slider arguments are accepted for interface compatibility
        and are unused by this preprocessor.
        """
        if self.model is None:
            self.model = SamDetector_Aux.from_pretrained()
        return self.model(
            input_image,
            detect_resolution=resolution,
            image_resolution=resolution,
            output_type="cv2",
        )
|
||||||
|
|
||||||
|
# Register this preprocessor with the global registry at import time.
Preprocessor.add_supported_preprocessor(PreprocessorMobileSam())
|
||||||
Loading…
Reference in New Issue