loading from subdirectories

pull/3/head
toshiaki1729 2022-11-03 22:36:46 +09:00
parent c533db7ae7
commit 9e86a17671
2 changed files with 34 additions and 14 deletions

View File

@ -6,6 +6,21 @@ from modules.textual_inversion.dataset import re_numbers_at_start
re_tags = re.compile(r'^(.+) \[\d+\]$')
def get_filepath_set(dir: str, recursive: bool) -> set[str]:
basenames = os.listdir(dir)
paths = {os.path.join(dir, basename) for basename in basenames}
if recursive:
result = set()
for path in paths:
if os.path.isdir(path):
result = result | get_filepath_set(path, True)
elif os.path.isfile(path):
result.add(path)
return result
else:
return {path for path in paths if os.path.isfile(path)}
class DatasetTagEditor:
def __init__(self):
# from modules.textual_inversion.dataset
@ -14,6 +29,7 @@ class DatasetTagEditor:
self.img_tag_set_dict = {}
self.tag_counts = {}
self.img_filter_img_path_set = set()
self.dataset_dir = ''
def get_tag_list(self) -> list[str]:
@ -125,17 +141,19 @@ class DatasetTagEditor:
return {k for k in self.img_tag_dict.keys() if k}
def load_dataset(self, img_dir: str):
def load_dataset(self, img_dir: str, recursive: bool = False):
self.clear()
try:
f_list = os.listdir(img_dir)
filepath_set = get_filepath_set(dir=img_dir, recursive=recursive)
except:
return
for img_filebasename in f_list:
img_path = os.path.join(img_dir, img_filebasename)
self.dataset_dir = img_dir
for img_path in filepath_set:
img_dir = os.path.dirname(img_path)
img_filename, img_ext = os.path.splitext(img_filebasename)
if os.path.isfile(img_path) and (img_ext == '.png'):
img_filename, img_ext = os.path.splitext(os.path.basename(img_path))
if img_ext == '.png':
text_filename = os.path.join(img_dir, img_filename+'.txt')
# from modules/textual_inversion/dataset.py
if os.path.exists(text_filename) and os.path.isfile(text_filename):
@ -193,9 +211,9 @@ class DatasetTagEditor:
else:
saved_num += 1
print(f'Backup text files: {backup_num}/{len(self.img_tag_dict)} in {img_dir}')
print(f'Saved text files: {saved_num}/{len(self.img_tag_dict)} in {img_dir}')
return (saved_num, len(self.img_tag_dict), img_dir)
print(f'Backup text files: {backup_num}/{len(self.img_tag_dict)} under {self.dataset_dir}')
print(f'Saved text files: {saved_num}/{len(self.img_tag_dict)} under {self.dataset_dir}')
return (saved_num, len(self.img_tag_dict), self.dataset_dir)
def clear(self):
@ -203,6 +221,7 @@ class DatasetTagEditor:
self.img_tag_set_dict.clear()
self.tag_counts.clear()
self.img_filter_img_path_set.clear()
self.dataset_dir = ''
def construct_tag_counts(self):

View File

@ -24,9 +24,9 @@ def arrange_tag_order(tags: List[str], sort_by: str, sort_order: str) -> List[st
return tags
def load_files_from_dir(dir: str, sort_by: str, sort_order: str):
def load_files_from_dir(dir: str, sort_by: str, sort_order: str, recursive: bool):
global total_image_num, displayed_image_num, current_tag_filter, current_selection, tmp_selection_img_path_set, selected_image_path
dataset_tag_editor.load_dataset(dir)
dataset_tag_editor.load_dataset(img_dir=dir, recursive=recursive)
img_paths, tags = dataset_tag_editor.get_filtered_imgpath_and_tags()
tags = arrange_tag_order(tags=tags, sort_by=sort_by, sort_order=sort_order)
total_image_num = displayed_image_num = current_selection = len(dataset_tag_editor.get_img_path_set())
@ -101,7 +101,7 @@ def apply_edit_tags(edit_tags: str, filter_tags: List[str], append_to_begin: boo
def save_all_changes(backup: bool) -> str:
saved, total, dir = dataset_tag_editor.save_dataset(backup=backup)
return f'Saved text files : {saved}/{total} in {dir}' if total > 0 else ''
return f'Saved text files : {saved}/{total} under {dir}' if total > 0 else ''
# ================================================================
@ -249,10 +249,11 @@ def on_ui_tabs():
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=3):
tb_img_directory = gr.Textbox(label='Dataset directory', placeholder='C:\\directory\\of\\datasets')
with gr.Column(scale=1, min_width=80):
btn_load_datasets = gr.Button(value='Load')
cb_load_recursive = gr.Checkbox(value=False, label='Load from subdirectories')
gl_dataset_images = gr.Gallery(label='Dataset Images', elem_id="dataset_tag_editor_images_gallery").style(grid=opts.dataset_editor_image_columns)
txt_filter = gr.HTML(value=f"""
Displayed Images : {displayed_image_num} / {total_image_num} total<br>
@ -338,7 +339,7 @@ def on_ui_tabs():
btn_load_datasets.click(
fn=load_files_from_dir,
inputs=[tb_img_directory, rd_sort_by, rd_sort_order],
inputs=[tb_img_directory, rd_sort_by, rd_sort_order, cb_load_recursive],
outputs=[gl_dataset_images, cbg_tags, tb_search_tags, txt_filter]
)
btn_load_datasets.click(