From 9e86a176719ff331280157eba55a02fcc30c2f8a Mon Sep 17 00:00:00 2001 From: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com> Date: Thu, 3 Nov 2022 22:36:46 +0900 Subject: [PATCH] loading from subdirectories --- dataset_tag_editor.py | 37 ++++++++++++++++++++++++-------- scripts/dataset_tag_editor_ui.py | 11 +++++----- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/dataset_tag_editor.py b/dataset_tag_editor.py index b574379..3891046 100644 --- a/dataset_tag_editor.py +++ b/dataset_tag_editor.py @@ -6,6 +6,21 @@ from modules.textual_inversion.dataset import re_numbers_at_start re_tags = re.compile(r'^(.+) \[\d+\]$') +def get_filepath_set(dir: str, recursive: bool) -> set[str]: + basenames = os.listdir(dir) + paths = {os.path.join(dir, basename) for basename in basenames} + if recursive: + result = set() + for path in paths: + if os.path.isdir(path): + result = result | get_filepath_set(path, True) + elif os.path.isfile(path): + result.add(path) + return result + else: + return {path for path in paths if os.path.isfile(path)} + + class DatasetTagEditor: def __init__(self): # from modules.textual_inversion.dataset @@ -14,6 +29,7 @@ class DatasetTagEditor: self.img_tag_set_dict = {} self.tag_counts = {} self.img_filter_img_path_set = set() + self.dataset_dir = '' def get_tag_list(self) -> list[str]: @@ -125,17 +141,19 @@ class DatasetTagEditor: return {k for k in self.img_tag_dict.keys() if k} - def load_dataset(self, img_dir: str): + def load_dataset(self, img_dir: str, recursive: bool = False): self.clear() try: - f_list = os.listdir(img_dir) + filepath_set = get_filepath_set(dir=img_dir, recursive=recursive) except: return - for img_filebasename in f_list: - img_path = os.path.join(img_dir, img_filebasename) + + self.dataset_dir = img_dir + + for img_path in filepath_set: img_dir = os.path.dirname(img_path) - img_filename, img_ext = os.path.splitext(img_filebasename) - if os.path.isfile(img_path) and (img_ext == '.png'): + img_filename, img_ext = os.path.splitext(os.path.basename(img_path)) + if img_ext == '.png': text_filename = os.path.join(img_dir, img_filename+'.txt') # from modules/textual_inversion/dataset.py if os.path.exists(text_filename) and os.path.isfile(text_filename): @@ -193,9 +211,9 @@ class DatasetTagEditor: else: saved_num += 1 - print(f'Backup text files: {backup_num}/{len(self.img_tag_dict)} in {img_dir}') - print(f'Saved text files: {saved_num}/{len(self.img_tag_dict)} in {img_dir}') - return (saved_num, len(self.img_tag_dict), img_dir) + print(f'Backup text files: {backup_num}/{len(self.img_tag_dict)} under {self.dataset_dir}') + print(f'Saved text files: {saved_num}/{len(self.img_tag_dict)} under {self.dataset_dir}') + return (saved_num, len(self.img_tag_dict), self.dataset_dir) def clear(self): @@ -203,6 +221,7 @@ class DatasetTagEditor: self.img_tag_set_dict.clear() self.tag_counts.clear() self.img_filter_img_path_set.clear() + self.dataset_dir = '' def construct_tag_counts(self): diff --git a/scripts/dataset_tag_editor_ui.py b/scripts/dataset_tag_editor_ui.py index 24a4d47..e59e6af 100644 --- a/scripts/dataset_tag_editor_ui.py +++ b/scripts/dataset_tag_editor_ui.py @@ -24,9 +24,9 @@ def arrange_tag_order(tags: List[str], sort_by: str, sort_order: str) -> List[st return tags -def load_files_from_dir(dir: str, sort_by: str, sort_order: str): +def load_files_from_dir(dir: str, sort_by: str, sort_order: str, recursive: bool): global total_image_num, displayed_image_num, current_tag_filter, current_selection, tmp_selection_img_path_set, selected_image_path - dataset_tag_editor.load_dataset(dir) + dataset_tag_editor.load_dataset(img_dir=dir, recursive=recursive) img_paths, tags = dataset_tag_editor.get_filtered_imgpath_and_tags() tags = arrange_tag_order(tags=tags, sort_by=sort_by, sort_order=sort_order) total_image_num = displayed_image_num = current_selection = len(dataset_tag_editor.get_img_path_set()) @@ -101,7 +101,7 @@ def apply_edit_tags(edit_tags: str, filter_tags: List[str], append_to_begin: boo def save_all_changes(backup: bool) -> str: saved, total, dir = dataset_tag_editor.save_dataset(backup=backup) - return f'Saved text files : {saved}/{total} in {dir}' if total > 0 else '' + return f'Saved text files : {saved}/{total} under {dir}' if total > 0 else '' # ================================================================ @@ -249,10 +249,11 @@ def on_ui_tabs(): with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): with gr.Row(): - with gr.Column(scale=4): + with gr.Column(scale=3): tb_img_directory = gr.Textbox(label='Dataset directory', placeholder='C:\\directory\\of\\datasets') with gr.Column(scale=1, min_width=80): btn_load_datasets = gr.Button(value='Load') + cb_load_recursive = gr.Checkbox(value=False, label='Load from subdirectories') gl_dataset_images = gr.Gallery(label='Dataset Images', elem_id="dataset_tag_editor_images_gallery").style(grid=opts.dataset_editor_image_columns) txt_filter = gr.HTML(value=f""" Displayed Images : {displayed_image_num} / {total_image_num} total
@@ -338,7 +339,7 @@ def on_ui_tabs(): btn_load_datasets.click( fn=load_files_from_dir, - inputs=[tb_img_directory, rd_sort_by, rd_sort_order], + inputs=[tb_img_directory, rd_sort_by, rd_sort_order, cb_load_recursive], outputs=[gl_dataset_images, cbg_tags, tb_search_tags, txt_filter] ) btn_load_datasets.click(