use caption file extension setting both on load and save

pull/41/head
toshiaki1729 2023-01-15 16:57:36 +09:00
parent 289802fc5c
commit 8931563fa4
2 changed files with 9 additions and 12 deletions

View File

@ -96,7 +96,6 @@ class DatasetTagEditor:
self.img_idx = dict()
self.tag_counts = {}
self.dataset_dir = ''
self.caption_ext = '.txt'
def get_tag_list(self):
if len(self.tag_counts) == 0:
@ -403,13 +402,11 @@ class DatasetTagEditor:
print(e)
def load_dataset(self, img_dir: str, caption_ext:str, recursive: bool, load_caption_from_filename: bool, interrogate_method: InterrogateMethod, use_booru: bool, use_blip: bool, use_git:bool, use_waifu: bool, threshold_booru: float, threshold_waifu: float):
def load_dataset(self, img_dir: str, caption_ext:str, recursive: bool, load_caption_from_filename: bool, interrogate_method: InterrogateMethod, use_booru: bool, use_blip: bool, use_git:bool, use_waifu: bool, threshold_booru: float, threshold_waifu: float):
self.clear()
print(f'Loading dataset from {img_dir}')
if recursive:
print(f'Also loading from subdirectories.')
self.caption_ext = caption_ext
try:
filepath_set = get_filepath_set(dir=img_dir, recursive=recursive)
@ -426,7 +423,7 @@ class DatasetTagEditor:
for img_path in filepath_set:
img_dir = os.path.dirname(img_path)
img_filename, img_ext = os.path.splitext(os.path.basename(img_path))
if img_ext == self.caption_ext:
if img_ext == caption_ext:
continue
try:
@ -436,7 +433,7 @@ class DatasetTagEditor:
else:
img.close()
text_filename = os.path.join(img_dir, img_filename+self.caption_ext)
text_filename = os.path.join(img_dir, img_filename+caption_ext)
caption_text = ''
if interrogate_method != InterrogateMethod.OVERWRITE:
# from modules/textual_inversion/dataset.py, modified
@ -514,7 +511,7 @@ class DatasetTagEditor:
print(f'Loading Completed: {len(self.dataset)} images found')
def save_dataset(self, backup: bool):
def save_dataset(self, backup: bool, caption_ext: str):
if len(self.dataset) == 0:
return (0, 0, '')
@ -525,7 +522,7 @@ class DatasetTagEditor:
img_path, tags = data.imgpath, data.tags
img_dir = os.path.dirname(img_path)
img_path_noext, _ = os.path.splitext(os.path.basename(img_path))
txt_path = os.path.join(img_dir, img_path_noext + self.caption_ext)
txt_path = os.path.join(img_dir, img_path_noext + caption_ext)
# make backup
if backup and os.path.exists(txt_path) and os.path.isfile(txt_path):
for extnum in range(1000):

View File

@ -267,8 +267,8 @@ def update_common_tags():
return [tags, tags]
def save_all_changes(backup: bool):
saved, total, dir = dataset_tag_editor.save_dataset(backup=backup)
def save_all_changes(backup: bool, caption_ext: str):
saved, total, dir = dataset_tag_editor.save_dataset(backup, caption_ext)
return f'Saved text files : {saved}/{total} under {dir}' if total > 0 else ''
@ -570,7 +570,7 @@ def on_ui_tabs():
with gr.Column(scale=3):
tb_img_directory = gr.Textbox(label='Dataset directory', placeholder='C:\\directory\\of\\datasets', value=cfg_general.dataset_dir)
with gr.Column(scale=1, min_width=60):
tb_caption_file_ext = gr.Textbox(label='Caption File Ext', placeholder='txt', value=cfg_general.caption_ext)
tb_caption_file_ext = gr.Textbox(label='Caption File Ext', placeholder='.txt (on Load and Save)', value=cfg_general.caption_ext)
with gr.Column(scale=1, min_width=80):
btn_load_datasets = gr.Button(value='Load')
with gr.Accordion(label='Dataset Load Settings'):
@ -780,7 +780,7 @@ def on_ui_tabs():
btn_save_all_changes.click(
fn=save_all_changes,
inputs=[cb_backup],
inputs=[cb_backup, tb_caption_file_ext],
outputs=[txt_result]
)