loading from subdirectories
parent
c533db7ae7
commit
9e86a17671
|
|
@ -6,6 +6,21 @@ from modules.textual_inversion.dataset import re_numbers_at_start
|
|||
|
||||
re_tags = re.compile(r'^(.+) \[\d+\]$')
|
||||
|
||||
def get_filepath_set(dir: str, recursive: bool) -> set[str]:
    """Return the paths of the regular files found in ``dir``.

    Args:
        dir: Directory to list. The name shadows the builtin, but it is part
            of the signature (it can be passed by keyword), so it is kept.
        recursive: When true, also collect files from every subdirectory,
            at any depth.

    Returns:
        A set of paths, each built by joining ``dir`` (or a subdirectory of
        it) with an entry name. Entries that are neither regular files nor
        directories are left out.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``dir`` (or, in
            recursive mode, one of its subdirectories) cannot be listed.
    """
    paths = {os.path.join(dir, basename) for basename in os.listdir(dir)}

    if not recursive:
        return {path for path in paths if os.path.isfile(path)}

    result = set()
    for path in paths:
        if os.path.isdir(path):
            # Merge in place: `result = result | ...` would copy everything
            # collected so far once per subdirectory.
            result.update(get_filepath_set(path, True))
        elif os.path.isfile(path):
            result.add(path)
    return result
|
||||
|
||||
|
||||
class DatasetTagEditor:
|
||||
def __init__(self):
|
||||
# from modules.textual_inversion.dataset
|
||||
|
|
@ -14,6 +29,7 @@ class DatasetTagEditor:
|
|||
self.img_tag_set_dict = {}
|
||||
self.tag_counts = {}
|
||||
self.img_filter_img_path_set = set()
|
||||
self.dataset_dir = ''
|
||||
|
||||
|
||||
def get_tag_list(self) -> list[str]:
|
||||
|
|
@ -125,17 +141,19 @@ class DatasetTagEditor:
|
|||
return {k for k in self.img_tag_dict.keys() if k}
|
||||
|
||||
|
||||
def load_dataset(self, img_dir: str):
|
||||
def load_dataset(self, img_dir: str, recursive: bool = False):
|
||||
self.clear()
|
||||
try:
|
||||
f_list = os.listdir(img_dir)
|
||||
filepath_set = get_filepath_set(dir=img_dir, recursive=recursive)
|
||||
except:
|
||||
return
|
||||
for img_filebasename in f_list:
|
||||
img_path = os.path.join(img_dir, img_filebasename)
|
||||
|
||||
self.dataset_dir = img_dir
|
||||
|
||||
for img_path in filepath_set:
|
||||
img_dir = os.path.dirname(img_path)
|
||||
img_filename, img_ext = os.path.splitext(img_filebasename)
|
||||
if os.path.isfile(img_path) and (img_ext == '.png'):
|
||||
img_filename, img_ext = os.path.splitext(os.path.basename(img_path))
|
||||
if img_ext == '.png':
|
||||
text_filename = os.path.join(img_dir, img_filename+'.txt')
|
||||
# from modules/textual_inversion/dataset.py
|
||||
if os.path.exists(text_filename) and os.path.isfile(text_filename):
|
||||
|
|
@ -193,9 +211,9 @@ class DatasetTagEditor:
|
|||
else:
|
||||
saved_num += 1
|
||||
|
||||
print(f'Backup text files: {backup_num}/{len(self.img_tag_dict)} in {img_dir}')
|
||||
print(f'Saved text files: {saved_num}/{len(self.img_tag_dict)} in {img_dir}')
|
||||
return (saved_num, len(self.img_tag_dict), img_dir)
|
||||
print(f'Backup text files: {backup_num}/{len(self.img_tag_dict)} under {self.dataset_dir}')
|
||||
print(f'Saved text files: {saved_num}/{len(self.img_tag_dict)} under {self.dataset_dir}')
|
||||
return (saved_num, len(self.img_tag_dict), self.dataset_dir)
|
||||
|
||||
|
||||
def clear(self):
|
||||
|
|
@ -203,6 +221,7 @@ class DatasetTagEditor:
|
|||
self.img_tag_set_dict.clear()
|
||||
self.tag_counts.clear()
|
||||
self.img_filter_img_path_set.clear()
|
||||
self.dataset_dir = ''
|
||||
|
||||
|
||||
def construct_tag_counts(self):
|
||||
|
|
|
|||
|
|
@ -24,9 +24,9 @@ def arrange_tag_order(tags: List[str], sort_by: str, sort_order: str) -> List[st
|
|||
return tags
|
||||
|
||||
|
||||
def load_files_from_dir(dir: str, sort_by: str, sort_order: str):
|
||||
def load_files_from_dir(dir: str, sort_by: str, sort_order: str, recursive: bool):
|
||||
global total_image_num, displayed_image_num, current_tag_filter, current_selection, tmp_selection_img_path_set, selected_image_path
|
||||
dataset_tag_editor.load_dataset(dir)
|
||||
dataset_tag_editor.load_dataset(img_dir=dir, recursive=recursive)
|
||||
img_paths, tags = dataset_tag_editor.get_filtered_imgpath_and_tags()
|
||||
tags = arrange_tag_order(tags=tags, sort_by=sort_by, sort_order=sort_order)
|
||||
total_image_num = displayed_image_num = current_selection = len(dataset_tag_editor.get_img_path_set())
|
||||
|
|
@ -101,7 +101,7 @@ def apply_edit_tags(edit_tags: str, filter_tags: List[str], append_to_begin: boo
|
|||
|
||||
def save_all_changes(backup: bool) -> str:
|
||||
saved, total, dir = dataset_tag_editor.save_dataset(backup=backup)
|
||||
return f'Saved text files : {saved}/{total} in {dir}' if total > 0 else ''
|
||||
return f'Saved text files : {saved}/{total} under {dir}' if total > 0 else ''
|
||||
|
||||
|
||||
# ================================================================
|
||||
|
|
@ -249,10 +249,11 @@ def on_ui_tabs():
|
|||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='panel'):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
with gr.Column(scale=3):
|
||||
tb_img_directory = gr.Textbox(label='Dataset directory', placeholder='C:\\directory\\of\\datasets')
|
||||
with gr.Column(scale=1, min_width=80):
|
||||
btn_load_datasets = gr.Button(value='Load')
|
||||
cb_load_recursive = gr.Checkbox(value=False, label='Load from subdirectories')
|
||||
gl_dataset_images = gr.Gallery(label='Dataset Images', elem_id="dataset_tag_editor_images_gallery").style(grid=opts.dataset_editor_image_columns)
|
||||
txt_filter = gr.HTML(value=f"""
|
||||
Displayed Images : {displayed_image_num} / {total_image_num} total<br>
|
||||
|
|
@ -338,7 +339,7 @@ def on_ui_tabs():
|
|||
|
||||
btn_load_datasets.click(
|
||||
fn=load_files_from_dir,
|
||||
inputs=[tb_img_directory, rd_sort_by, rd_sort_order],
|
||||
inputs=[tb_img_directory, rd_sort_by, rd_sort_order, cb_load_recursive],
|
||||
outputs=[gl_dataset_images, cbg_tags, tb_search_tags, txt_filter]
|
||||
)
|
||||
btn_load_datasets.click(
|
||||
|
|
|
|||
Loading…
Reference in New Issue