Round parameters of arguments for SAM 2 class

main
Uminosachi 2024-08-04 17:39:19 +09:00
parent 71bdef2ba7
commit 28fe6a7a0f
1 changed files with 3 additions and 3 deletions

View File

@ -70,7 +70,7 @@ def get_sam_mask_generator(sam_checkpoint, anime_style_chk=False):
Returns:
SamAutomaticMaskGenerator or None: SAM mask generator
"""
# model_type = "vit_h"
points_per_batch = 64
if "_hq_" in os.path.basename(sam_checkpoint):
model_type = os.path.basename(sam_checkpoint)[7:12]
sam_model_registry_local = sam_model_registry_hq
@ -101,8 +101,8 @@ def get_sam_mask_generator(sam_checkpoint, anime_style_chk=False):
stability_score_thresh = 0.95 if not anime_style_chk else 0.9
if "sam2_" in model_type:
pred_iou_thresh = pred_iou_thresh - 0.18
stability_score_thresh = stability_score_thresh - 0.03
pred_iou_thresh = round(pred_iou_thresh - 0.18, 2)
stability_score_thresh = round(stability_score_thresh - 0.03, 2)
sam2_gen_kwargs = dict(
points_per_side=64,
points_per_batch=points_per_batch,