implement dispatch logic

pull/27/head
udon-universe 2023-05-07 19:01:02 +09:00
parent daaeb3ca1e
commit 7835013f00
1 changed files with 13 additions and 4 deletions

View File

@ -25,11 +25,20 @@ sam_model_dir = os.path.join(
model_list = [f for f in os.listdir(sam_model_dir) if os.path.isfile(
os.path.join(sam_model_dir, f)) and f.split('.')[-1] != 'txt']
def processing(single_image, batch_image, input_tab_state, *rem_args):
# 0: single
if (input_tab_state == 0):
result = process_image(single_image, *rem_args)
return result
# 1 (or ohter): batch
else:
result = process_image(batch_image, *rem_args)
return result
def processing(single_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold):
image = pil2cv(single_image)
def process_image(target_image, *rem_args):
image = pil2cv(target_image)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask, image = get_foreground(image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold)
mask, image = get_foreground(image, *rem_args)
return image, mask
class Script(scripts.Script):
@ -91,7 +100,7 @@ def on_ui_tabs():
input_tab_batch.select(fn=lambda: 1, inputs=[], outputs=[input_tab_state])
submit.click(
processing,
inputs=[single_image, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold],
inputs=[single_image, batch_image, input_tab_state, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold],
outputs=gallery
)