diff --git a/scripts/__pycache__/convertor.cpython-310.pyc b/scripts/__pycache__/convertor.cpython-310.pyc deleted file mode 100644 index b0e678d..0000000 Binary files a/scripts/__pycache__/convertor.cpython-310.pyc and /dev/null differ diff --git a/scripts/__pycache__/td_abg.cpython-310.pyc b/scripts/__pycache__/td_abg.cpython-310.pyc deleted file mode 100644 index fbd51b6..0000000 Binary files a/scripts/__pycache__/td_abg.cpython-310.pyc and /dev/null differ diff --git a/scripts/batch_dir.py b/scripts/batch_dir.py new file mode 100644 index 0000000..3970e79 --- /dev/null +++ b/scripts/batch_dir.py @@ -0,0 +1,39 @@ +import os +import re +from PIL import Image + +def save_image_dir(image, path, basename, extension='png'): + # Ensure the directory exists + os.makedirs(path, exist_ok=True) + + # Generate the filename + filename = f"{basename}.{extension}" + full_path = os.path.join(path, filename) + + # Save the image + image.save(full_path) + + return full_path + +def modify_basename(basename): + match = re.search(r'(\d+)(\.\w+)?$', basename) + if match is not None: + # If there is a sequence of digits followed by the file extension, + # capture the prefix, the sequence number, and the extension separately. + prefix = basename[:match.start()] + sequence = match.group(1) + extension = match.group(2) if match.group(2) else '' + + # If there's a hyphen or underscore just before the sequence number, + # include it in the new name. + if prefix and (prefix[-1] == '_' or prefix[-1] == '-'): + separator = prefix[-1] + return f"{prefix[:-1]}{separator}mask{separator}{sequence}{extension}" + else: + return f"{prefix}_mask{sequence}{extension}" + else: + # If there's no sequence number, use the last character of the string to decide the format. + if basename and (basename[-1] == '_' or basename[-1] == '-'): + return f"{basename}mask" + else: + return f"{basename}-mask" \ No newline at end of file diff --git a/scripts/main.py b/scripts/main.py index ece489a..0453bfb 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -3,6 +3,7 @@ import io import json import numpy as np import cv2 +import re import gradio as gr @@ -13,6 +14,8 @@ from modules import images from scripts.td_abg import get_foreground from scripts.convertor import pil2cv +from scripts.batch_dir import save_image_dir, modify_basename + try: from modules.paths_internal import extensions_dir except Exception: @@ -27,7 +30,7 @@ sam_model_dir = os.path.join( model_list = [f for f in os.listdir(sam_model_dir) if os.path.isfile( os.path.join(sam_model_dir, f)) and f.split('.')[-1] != 'txt'] -def processing(single_image, batch_image, input_dir, output_dir, show_result, input_tab_state, *rem_args): +def processing(single_image, batch_image, input_dir, output_dir, output_mask_dir, show_result, input_tab_state, *rem_args): # 0: single if (input_tab_state == 0): processed = process_image(single_image, *rem_args) @@ -53,20 +56,23 @@ def processing(single_image, batch_image, input_dir, output_dir, show_result, in base, mask = process_image(image, *rem_args) processed.append(base) processed.append(mask) - basename = os.path.splitext(os.path.basename(f))[0] - ext = os.path.splitext(f)[1][1:] - images.save_image( - Image.fromarray(base), - path=output_dir, - basename=basename, - extension=ext, - ) - images.save_image( - Image.fromarray(mask), - path=output_dir, - basename=basename+"-mask", - extension=ext, - ) + if output_dir != "": + basename = os.path.splitext(os.path.basename(f))[0] + ext = os.path.splitext(f)[1][1:] + save_image_dir( + Image.fromarray(base), + path=output_dir, + basename=basename, + extension="png", + ) + if output_mask_dir != "": + basename = modify_basename(basename) + save_image_dir( + Image.fromarray(mask), + path=output_mask_dir, + basename=basename, + extension="png", + ) if (show_result): return processed else: @@ -104,6 +110,7 @@ def on_ui_tabs(): with gr.TabItem(label="Batch from Dir") as input_tab_dir: input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) + output_mask_dir = gr.Textbox(label="Output Mask directory", **shared.hide_dirs) show_result = gr.Checkbox(label="Show result images", value=True) with gr.Accordion("Mask Setting", open=True): with gr.Tab("Segment Anything & CLIP"): @@ -139,10 +146,10 @@ def on_ui_tabs(): input_tab_dir.select(fn=lambda: 2, inputs=[], outputs=[input_tab_state]) submit.click( processing, - inputs=[single_image, batch_image, input_dir, output_dir, show_result, input_tab_state, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold], + inputs=[single_image, batch_image, input_dir, output_dir, output_mask_dir, show_result, input_tab_state, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold], outputs=gallery ) return [(PBRemTools, "PBRemTools", "pbremtools")] - + script_callbacks.on_ui_tabs(on_ui_tabs)