From 4f07966fd61ee2c71db3d448e91fafafef01efcc Mon Sep 17 00:00:00 2001 From: udon-universe <128375799+udon-universe@users.noreply.github.com> Date: Wed, 10 May 2023 02:50:20 +0900 Subject: [PATCH] implement batch dir mode --- scripts/main.py | 47 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/scripts/main.py b/scripts/main.py index 02e4d50..db568b7 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -8,6 +8,8 @@ import gradio as gr import modules.scripts as scripts from modules import script_callbacks +import modules.shared as shared +from modules import images from scripts.td_abg import get_foreground from scripts.convertor import pil2cv @@ -25,13 +27,13 @@ sam_model_dir = os.path.join( model_list = [f for f in os.listdir(sam_model_dir) if os.path.isfile( os.path.join(sam_model_dir, f)) and f.split('.')[-1] != 'txt'] -def processing(single_image, batch_image, input_tab_state, *rem_args): +def processing(single_image, batch_image, input_dir, output_dir, show_result, input_tab_state, *rem_args): # 0: single if (input_tab_state == 0): processed = process_image(single_image, *rem_args) return processed - # 1 (or ohter): batch - else: + # 1: batch + elif (input_tab_state == 1): processed = [] for i in batch_image: image = Image.open(i) @@ -39,6 +41,36 @@ def processing(single_image, batch_image, input_tab_state, *rem_args): processed.append(r[0]) processed.append(r[1]) return processed + # 2: batch dir (or other) + else: + processed = [] + files = shared.listfiles(input_dir) + for f in files: + try: + image = Image.open(f) + except Exception: + continue + base, mask = process_image(image, *rem_args) + processed.append(base) + processed.append(mask) + basename = os.path.splitext(os.path.basename(f))[0] + ext = os.path.splitext(f)[1][1:] + images.save_image( + Image.fromarray(base), + path=output_dir, + basename=basename, + extension=ext, + ) + images.save_image( + Image.fromarray(mask), + path=output_dir, + basename=basename+"_mask", + extension=ext, + ) + if (show_result): + return processed + else: + return None def process_image(target_image, *rem_args): image = pil2cv(target_image) @@ -69,6 +101,10 @@ def on_ui_tabs(): single_image = gr.Image(type="pil") with gr.TabItem(label="Batch") as input_tab_batch: batch_image = gr.File(label="Batch Images", file_count="multiple", interactive=True, type="file") + with gr.TabItem(label="Batch from Dir") as input_tab_dir: + input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) + output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) + show_result = gr.Checkbox(label="Show result images", value=True) with gr.Accordion("Mask Setting", open=True): with gr.Tab("Segment Anything & CLIP"): sa_enabled = gr.Checkbox(label="enabled", show_label=True) @@ -97,12 +133,13 @@ def on_ui_tabs(): with gr.Column(): gallery = gr.Gallery(label="outputs", show_label=True, elem_id="gallery").style(grid=2) - # 0: single 1: batch + # 0: single 1: batch 2: batch dir input_tab_single.select(fn=lambda: 0, inputs=[], outputs=[input_tab_state]) input_tab_batch.select(fn=lambda: 1, inputs=[], outputs=[input_tab_state]) + input_tab_dir.select(fn=lambda: 2, inputs=[], outputs=[input_tab_state]) submit.click( processing, - inputs=[single_image, batch_image, input_tab_state, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold], + inputs=[single_image, batch_image, input_dir, output_dir, show_result, input_tab_state, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold], outputs=gallery )