diff --git a/scripts/tools.py b/scripts/tools.py index fbce4c3..451435e 100644 --- a/scripts/tools.py +++ b/scripts/tools.py @@ -21,8 +21,13 @@ from collections import OrderedDict from PIL import Image model_cache = OrderedDict() -sam_model_dir = os.path.join( - extensions_dir, "PBRemTools/models/") +models_path = shared.models_path +sams_dir = os.path.join(models_path, "sam") +if os.path.exists(sams_dir): + sam_model_dir = sams_dir +else: + sam_model_dir = os.path.join( + extensions_dir, "PBRemTools/models/") model_list = [f for f in os.listdir(sam_model_dir) if os.path.isfile( os.path.join(sam_model_dir, f)) and f.split('.')[-1] != 'txt']