Add mobile_sam with controlnet_aux (#3000)

* Add mobile_sam with controlnet_aux for CNXL_Union
pull/3001/head
青龍聖者@bdsqlsz 2024-07-15 15:33:42 +08:00 committed by GitHub
parent 3ff69b9ea3
commit 0baecb5f09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 77 additions and 1 deletion

View File

@ -0,0 +1,49 @@
from __future__ import print_function
import os
import numpy as np
from PIL import Image
from typing import Union
from modules import devices
from annotator.util import load_model
from annotator.annotator_path import models_path
from controlnet_aux import SamDetector
from controlnet_aux.segment_anything import sam_model_registry, SamAutomaticMaskGenerator
class SamDetector_Aux(SamDetector):
    """MobileSAM flavour of ``controlnet_aux.SamDetector``.

    The checkpoint directory is ``<models_path>/mobile_sam`` and the wrapped
    network is placed on ``devices.device``.
    """

    model_dir = os.path.join(models_path, "mobile_sam")

    def __init__(self, mask_generator: SamAutomaticMaskGenerator, sam):
        """Store the mask generator and keep ``sam`` on the device in eval mode.

        Args:
            mask_generator: Generator handed to the parent ``SamDetector``.
            sam: The SAM network; moved to ``devices.device`` and switched to
                eval mode, then kept as ``self.model``.
        """
        super().__init__(mask_generator)
        self.device = devices.device
        placed = sam.to(self.device)
        self.model = placed.eval()

    @classmethod
    def from_pretrained(cls):
        """Build a detector from the ``vit_t`` MobileSAM checkpoint.

        The checkpoint URL defaults to the ``dhkim2810/MobileSAM`` repository
        on Hugging Face and can be overridden with the
        ``CONTROLNET_MOBILE_SAM_MODEL_URL`` environment variable. The file is
        resolved through ``load_model`` into ``cls.model_dir``.

        Returns:
            A new instance of ``cls`` wrapping the loaded network.
        """
        default_url = "https://huggingface.co/dhkim2810/MobileSAM/resolve/main/mobile_sam.pt"
        checkpoint_url = os.environ.get("CONTROLNET_MOBILE_SAM_MODEL_URL", default_url)
        checkpoint_path = load_model(
            "mobile_sam.pt",
            remote_url=checkpoint_url,
            model_dir=cls.model_dir,
        )
        build_network = sam_model_registry["vit_t"]
        network = build_network(checkpoint=checkpoint_path)
        # The prepared network is also stored on the class itself, not only
        # on the instance created below.
        cls.model = network.to(devices.device).eval()
        generator = SamAutomaticMaskGenerator(cls.model)
        return cls(generator, network)

    def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="cv2", **kwargs) -> np.ndarray:
        """Run the parent detector and return its result as a ``uint8`` array.

        Args:
            input_image: Image forwarded to ``SamDetector.__call__``.
            detect_resolution: Forwarded unchanged to the parent call.
            image_resolution: Forwarded unchanged to the parent call.
            output_type: Forwarded unchanged to the parent call.
            **kwargs: Extra keyword arguments forwarded to the parent call.

        Returns:
            The parent's output converted with ``np.array(...).astype(np.uint8)``.
        """
        # Make sure the network is on the target device before inference.
        self.model.to(self.device)
        forwarded = dict(
            input_image=input_image,
            detect_resolution=detect_resolution,
            image_resolution=image_resolution,
            output_type=output_type,
        )
        segmented = super().__call__(**forwarded, **kwargs)
        return np.array(segmented).astype(np.uint8)

View File

@ -9,3 +9,4 @@ matplotlib
facexlib
timm<=0.9.5
pydantic<=1.10.17
controlnet_aux

View File

@ -5,4 +5,5 @@ from .lama_inpaint import *
from .ip_adapter_auto import *
from .normal_dsine import *
from .model_free_preprocessors import *
from .legacy.legacy_preprocessors import *
from .legacy.legacy_preprocessors import *
from .mobile_sam import *

View File

@ -0,0 +1,25 @@
from annotator.mobile_sam import SamDetector_Aux
from scripts.supported_preprocessor import Preprocessor
class PreprocessorMobileSam(Preprocessor):
    """Preprocessor named ``mobile_sam`` that delegates to ``SamDetector_Aux``."""

    def __init__(self):
        """Register the name and tag; the detector itself is not built yet."""
        super().__init__(name="mobile_sam")
        self.tags = ["Segmentation"]
        # Stays None until the first call creates the detector.
        self.model = None

    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        **kwargs
    ):
        """Run the detector on ``input_image`` at ``resolution``.

        ``resolution`` is used for both ``detect_resolution`` and
        ``image_resolution``. The slider arguments and ``**kwargs`` are
        accepted but not used by this method.

        Returns:
            Whatever ``SamDetector_Aux.__call__`` returns for ``output_type="cv2"``.
        """
        detector = self.model
        if detector is None:
            # First use: build the detector once and cache it on the instance.
            detector = SamDetector_Aux.from_pretrained()
            self.model = detector
        return detector(
            input_image,
            detect_resolution=resolution,
            image_resolution=resolution,
            output_type="cv2",
        )
# Runs at import time: create one PreprocessorMobileSam instance and hand it
# to Preprocessor.add_supported_preprocessor.
Preprocessor.add_supported_preprocessor(PreprocessorMobileSam())