Merge pull request #33 from Hidetoshi-Iizawa/dir_out1

Change directory and file name handling during batch processing
pull/35/head
mattya_monaca 2023-05-19 08:29:56 +09:00 committed by GitHub
commit dcd00660df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 17 deletions

39
scripts/batch_dir.py Normal file
View File

@ -0,0 +1,39 @@
import os
import re
from PIL import Image
def save_image_dir(image, path, basename, extension='png'):
# Ensure the directory exists
os.makedirs(path, exist_ok=True)
# Generate the filename
filename = f"{basename}.{extension}"
full_path = os.path.join(path, filename)
# Save the image
image.save(full_path)
return full_path
def modify_basename(basename):
match = re.search(r'(\d+)(\.\w+)?$', basename)
if match is not None:
# If there is a sequence of digits followed by the file extension,
# capture the prefix, the sequence number, and the extension separately.
prefix = basename[:match.start()]
sequence = match.group(1)
extension = match.group(2) if match.group(2) else ''
# If there's a hyphen or underscore just before the sequence number,
# include it in the new name.
if prefix and (prefix[-1] == '_' or prefix[-1] == '-'):
separator = prefix[-1]
return f"{prefix[:-1]}{separator}mask{separator}{sequence}{extension}"
else:
return f"{prefix}_mask{sequence}{extension}"
else:
# If there's no sequence number, use the last character of the string to decide the format.
if basename and (basename[-1] == '_' or basename[-1] == '-'):
return f"{basename}mask"
else:
return f"{basename}-mask"

View File

@ -3,6 +3,7 @@ import io
import json
import numpy as np
import cv2
import re
import gradio as gr
@ -13,6 +14,8 @@ from modules import images
from scripts.td_abg import get_foreground
from scripts.convertor import pil2cv
from scripts.batch_dir import save_image_dir, modify_basename
try:
from modules.paths_internal import extensions_dir
except Exception:
@ -27,7 +30,7 @@ sam_model_dir = os.path.join(
model_list = [f for f in os.listdir(sam_model_dir) if os.path.isfile(
os.path.join(sam_model_dir, f)) and f.split('.')[-1] != 'txt']
def processing(single_image, batch_image, input_dir, output_dir, show_result, input_tab_state, *rem_args):
def processing(single_image, batch_image, input_dir, output_dir, output_mask_dir, show_result, input_tab_state, *rem_args):
# 0: single
if (input_tab_state == 0):
processed = process_image(single_image, *rem_args)
@ -53,20 +56,23 @@ def processing(single_image, batch_image, input_dir, output_dir, show_result, in
base, mask = process_image(image, *rem_args)
processed.append(base)
processed.append(mask)
basename = os.path.splitext(os.path.basename(f))[0]
ext = os.path.splitext(f)[1][1:]
images.save_image(
Image.fromarray(base),
path=output_dir,
basename=basename,
extension=ext,
)
images.save_image(
Image.fromarray(mask),
path=output_dir,
basename=basename+"-mask",
extension=ext,
)
if output_dir != "":
basename = os.path.splitext(os.path.basename(f))[0]
ext = os.path.splitext(f)[1][1:]
save_image_dir(
Image.fromarray(base),
path=output_dir,
basename=basename,
extension="png",
)
if output_mask_dir != "":
basename = modify_basename(basename)
save_image_dir(
Image.fromarray(mask),
path=output_mask_dir,
basename=basename,
extension="png",
)
if (show_result):
return processed
else:
@ -104,6 +110,7 @@ def on_ui_tabs():
with gr.TabItem(label="Batch from Dir") as input_tab_dir:
input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
output_mask_dir = gr.Textbox(label="Output Mask directory", **shared.hide_dirs)
show_result = gr.Checkbox(label="Show result images", value=True)
with gr.Accordion("Mask Setting", open=True):
with gr.Tab("Segment Anything & CLIP"):
@ -139,10 +146,10 @@ def on_ui_tabs():
input_tab_dir.select(fn=lambda: 2, inputs=[], outputs=[input_tab_state])
submit.click(
processing,
inputs=[single_image, batch_image, input_dir, output_dir, show_result, input_tab_state, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold],
inputs=[single_image, batch_image, input_dir, output_dir, output_mask_dir, show_result, input_tab_state, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_rate, cascadePSP_enabled, fast, psp_L, sa_enabled, seg_query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold],
outputs=gallery
)
return [(PBRemTools, "PBRemTools", "pbremtools")]
script_callbacks.on_ui_tabs(on_ui_tabs)