121 lines
4.9 KiB
Python
121 lines
4.9 KiB
Python
import os
|
|
import platform
|
|
|
|
import torch
|
|
from modules import devices
|
|
|
|
from fast_sam import FastSamAutomaticMaskGenerator, fast_sam_model_registry
|
|
from ia_check_versions import ia_check_versions
|
|
from ia_config import get_webui_setting
|
|
from ia_logging import ia_logging
|
|
from ia_threading import torch_default_load_cd
|
|
from mobile_sam import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorMobile
|
|
from mobile_sam import SamPredictor as SamPredictorMobile
|
|
from mobile_sam import sam_model_registry as sam_model_registry_mobile
|
|
from segment_anything_fb import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
|
from segment_anything_hq import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorHQ
|
|
from segment_anything_hq import SamPredictor as SamPredictorHQ
|
|
from segment_anything_hq import sam_model_registry as sam_model_registry_hq
|
|
|
|
|
|
@torch_default_load_cd()
def get_sam_mask_generator(sam_checkpoint, anime_style_chk=False):
    """Build an automatic mask generator for the given SAM checkpoint.

    The checkpoint's file name selects the model family (HQ, FastSAM,
    Mobile SAM, or the default SAM), which in turn selects the model
    registry, the generator class, and the ``points_per_batch`` value.

    Args:
        sam_checkpoint (str): SAM checkpoint path
        anime_style_chk (bool): when true, lower ``pred_iou_thresh`` (0.83
            instead of 0.88) and ``stability_score_thresh`` (0.9 instead of
            0.95) are passed to the generator

    Returns:
        SamAutomaticMaskGenerator or None: SAM mask generator, or None when
            ``sam_checkpoint`` is not an existing file
    """
    checkpoint_name = os.path.basename(sam_checkpoint)

    # Choose registry / generator class / batch size from the file name.
    if "_hq_" in checkpoint_name:
        registry = sam_model_registry_hq
        generator_cls = SamAutomaticMaskGeneratorHQ
        # Model type is taken from a fixed position in the file name.
        model_key = checkpoint_name[7:12]
        batch_points = 32
    elif "FastSAM" in checkpoint_name:
        registry = fast_sam_model_registry
        generator_cls = FastSamAutomaticMaskGenerator
        # FastSAM registry is keyed by the file name without its extension.
        model_key = os.path.splitext(checkpoint_name)[0]
        batch_points = None
    elif "mobile_sam" in checkpoint_name:
        registry = sam_model_registry_mobile
        generator_cls = SamAutomaticMaskGeneratorMobile
        model_key = "vit_t"
        batch_points = 64
    else:
        registry = sam_model_registry
        generator_cls = SamAutomaticMaskGenerator
        # Model type is taken from a fixed position in the file name.
        model_key = checkpoint_name[4:9]
        batch_points = 64

    if anime_style_chk:
        iou_threshold, stability_threshold = 0.83, 0.9
    else:
        iou_threshold, stability_threshold = 0.88, 0.95

    # Nothing to load: report that with None rather than raising.
    if not os.path.isfile(sam_checkpoint):
        return None

    model = registry[model_key](checkpoint=sam_checkpoint)

    if platform.system() == "Darwin":
        # On macOS, FastSAM always goes to the CPU; other models use MPS
        # only when ia_check_versions reports it as available.
        if "FastSAM" in checkpoint_name or not ia_check_versions.torch_mps_is_available:
            target_device = torch.device("cpu")
        else:
            target_device = torch.device("mps")
    elif get_webui_setting("inpaint_anything_sam_oncpu", False):
        ia_logging.info("SAM is running on CPU... (the option has been checked)")
        target_device = devices.cpu
    else:
        target_device = devices.device
    model.to(device=target_device)

    return generator_cls(
        model=model,
        points_per_batch=batch_points,
        pred_iou_thresh=iou_threshold,
        stability_score_thresh=stability_threshold,
    )
|
|
|
|
|
|
@torch_default_load_cd()
def get_sam_predictor(sam_checkpoint):
    """Build a SAM predictor for the given checkpoint.

    The checkpoint's file name selects the model family (HQ, Mobile SAM,
    or the default SAM), which determines the model registry and the
    predictor class that wraps the loaded model.

    Args:
        sam_checkpoint (str): SAM checkpoint path

    Returns:
        SamPredictor or None: SAM predictor, or None when ``sam_checkpoint``
            is not an existing file

    Raises:
        NotImplementedError: if the file name contains "FastSAM" (and not
            "_hq_"); this is raised before the file's existence is checked
    """
    checkpoint_name = os.path.basename(sam_checkpoint)

    # Choose registry / predictor class from the file name.
    if "_hq_" in checkpoint_name:
        registry = sam_model_registry_hq
        predictor_cls = SamPredictorHQ
        # Model type is taken from a fixed position in the file name.
        model_key = checkpoint_name[7:12]
    elif "FastSAM" in checkpoint_name:
        raise NotImplementedError("FastSAM predictor is not implemented yet.")
    elif "mobile_sam" in checkpoint_name:
        registry = sam_model_registry_mobile
        predictor_cls = SamPredictorMobile
        model_key = "vit_t"
    else:
        registry = sam_model_registry
        predictor_cls = SamPredictor
        # Model type is taken from a fixed position in the file name.
        model_key = checkpoint_name[4:9]

    # Nothing to load: report that with None rather than raising.
    if not os.path.isfile(sam_checkpoint):
        return None

    model = registry[model_key](checkpoint=sam_checkpoint)

    if platform.system() == "Darwin":
        # On macOS, use MPS only when ia_check_versions reports it as
        # available (and never for a FastSAM-named checkpoint).
        if "FastSAM" in checkpoint_name or not ia_check_versions.torch_mps_is_available:
            target_device = torch.device("cpu")
        else:
            target_device = torch.device("mps")
    elif get_webui_setting("inpaint_anything_sam_oncpu", False):
        ia_logging.info("SAM is running on CPU... (the option has been checked)")
        target_device = devices.cpu
    else:
        target_device = devices.device
    model.to(device=target_device)

    return predictor_cls(model)
|