implement batch dir mode
parent
aaf474083f
commit
4f07966fd6
|
|
@ -8,6 +8,8 @@ import gradio as gr
|
|||
|
||||
import modules.scripts as scripts
|
||||
from modules import script_callbacks
|
||||
import modules.shared as shared
|
||||
from modules import images
|
||||
|
||||
from scripts.td_abg import get_foreground
|
||||
from scripts.convertor import pil2cv
|
||||
|
|
@ -25,13 +27,13 @@ sam_model_dir = os.path.join(
|
|||
model_list = [f for f in os.listdir(sam_model_dir) if os.path.isfile(
|
||||
os.path.join(sam_model_dir, f)) and f.split('.')[-1] != 'txt']
|
||||
|
||||
def processing(single_image, batch_image, input_tab_state, *rem_args):
|
||||
def processing(single_image, batch_image, input_dir, output_dir, show_result, input_tab_state, *rem_args):
|
||||
# 0: single
|
||||
if (input_tab_state == 0):
|
||||
processed = process_image(single_image, *rem_args)
|
||||
return processed
|
||||
# 1 (or ohter): batch
|
||||
else:
|
||||
# 1: batch
|
||||
elif (input_tab_state == 1):
|
||||
processed = []
|
||||
for i in batch_image:
|
||||
image = Image.open(i)
|
||||
|
|
@ -39,6 +41,36 @@ def processing(single_image, batch_image, input_tab_state, *rem_args):
|
|||
processed.append(r[0])
|
||||
processed.append(r[1])
|
||||
return processed
|
||||
# 2: batch dir (or other)
|
||||
else:
|
||||
processed = []
|
||||
files = shared.listfiles(input_dir)
|
||||
for f in files:
|
||||
try:
|
||||
image = Image.open(f)
|
||||
except Exception:
|
||||
continue
|
||||
base, mask = process_image(image, *rem_args)
|
||||
processed.append(base)
|
||||
processed.append(mask)
|
||||
basename = os.path.splitext(os.path.basename(f))[0]
|
||||
ext = os.path.splitext(f)[1][1:]
|
||||
images.save_image(
|
||||
Image.fromarray(base),
|
||||
path=output_dir,
|
||||
basename=basename,
|
||||
extension=ext,
|
||||
)
|
||||
images.save_image(
|
||||
Image.fromarray(mask),
|
||||
path=output_dir,
|
||||
basename=basename+"_mask",
|
||||
extension=ext,
|
||||
)
|
||||
if (show_result):
|
||||
return processed
|
||||
else:
|
||||
return None
|
||||
|
||||
def process_image(target_image, *rem_args):
|
||||
image = pil2cv(target_image)
|
||||
|
|
@ -69,6 +101,10 @@ def on_ui_tabs():
|
|||
single_image = gr.Image(type="pil")
|
||||
with gr.TabItem(label="Batch") as input_tab_batch:
|
||||
batch_image = gr.File(label="Batch Images", file_count="multiple", interactive=True, type="file")
|
||||
with gr.TabItem(label="Batch from Dir") as input_tab_dir:
|
||||
input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
|
||||
output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
|
||||
show_result = gr.Checkbox(label="Show result images", value=True)
|
||||
with gr.Accordion("Mask Setting", open=True):
|
||||
with gr.Tab("Segment Anything & CLIP"):
|
||||
sa_enabled = gr.Checkbox(label="enabled", show_label=True)
|
||||
|
|
@ -97,12 +133,13 @@ def on_ui_tabs():
|
|||
with gr.Column():
|
||||
gallery = gr.Gallery(label="outputs", show_label=True, elem_id="gallery").style(grid=2)
|
||||
|
||||
# 0: single 1: batch
|
||||
# 0: single 1: batch 2: batch dir
|
||||
input_tab_single.select(fn=lambda: 0, inputs=[], outputs=[input_tab_state])
|
||||
input_tab_batch.select(fn=lambda: 1, inputs=[], outputs=[input_tab_state])
|
||||
input_tab_dir.select(fn=lambda: 2, inputs=[], outputs=[input_tab_state])
|
||||
submit.click(
|
||||
processing,
|
||||
inputs=[single_image, batch_image, input_tab_state, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold],
|
||||
inputs=[single_image, batch_image, input_dir, output_dir, show_result, input_tab_state, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold],
|
||||
outputs=gallery
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue