diff --git a/scripts/main.py b/scripts/main.py index f8b59f3..f7fab0f 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -25,11 +25,20 @@ sam_model_dir = os.path.join( model_list = [f for f in os.listdir(sam_model_dir) if os.path.isfile( os.path.join(sam_model_dir, f)) and f.split('.')[-1] != 'txt'] +def processing(single_image, batch_image, input_tab_state, *rem_args): + # 0: single + if (input_tab_state == 0): + result = process_image(single_image, *rem_args) + return result + # 1 (or ohter): batch + else: + result = process_image(batch_image, *rem_args) + return result -def processing(single_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold): - image = pil2cv(single_image) +def process_image(target_image, *rem_args): + image = pil2cv(target_image) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - mask, image = get_foreground(image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold) + mask, image = get_foreground(image, *rem_args) return image, mask class Script(scripts.Script): @@ -91,7 +100,7 @@ def on_ui_tabs(): input_tab_batch.select(fn=lambda: 1, inputs=[], outputs=[input_tab_state]) submit.click( processing, - inputs=[single_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold], + inputs=[single_image, batch_image, input_tab_state, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold], outputs=gallery )