diff --git a/scripts/dataset_tag_editor/dataset_tag_editor.py b/scripts/dataset_tag_editor/dataset_tag_editor.py index 718a604..7f661c3 100644 --- a/scripts/dataset_tag_editor/dataset_tag_editor.py +++ b/scripts/dataset_tag_editor/dataset_tag_editor.py @@ -96,7 +96,6 @@ class DatasetTagEditor: self.img_idx = dict() self.tag_counts = {} self.dataset_dir = '' - self.caption_ext = '.txt' def get_tag_list(self): if len(self.tag_counts) == 0: @@ -403,13 +402,11 @@ class DatasetTagEditor: print(e) - def load_dataset(self, img_dir: str, caption_ext:str, recursive: bool, load_caption_from_filename: bool, interrogate_method: InterrogateMethod, use_booru: bool, use_blip: bool, use_git:bool, use_waifu: bool, threshold_booru: float, threshold_waifu: float): + def load_dataset(self, img_dir: str, caption_ext:str, recursive: bool, load_caption_from_filename: bool, interrogate_method: InterrogateMethod, use_booru: bool, use_blip: bool, use_git:bool, use_waifu: bool, threshold_booru: float, threshold_waifu: float): self.clear() print(f'Loading dataset from {img_dir}') if recursive: print(f'Also loading from subdirectories.') - - self.caption_ext = caption_ext try: filepath_set = get_filepath_set(dir=img_dir, recursive=recursive) @@ -426,7 +423,7 @@ class DatasetTagEditor: for img_path in filepath_set: img_dir = os.path.dirname(img_path) img_filename, img_ext = os.path.splitext(os.path.basename(img_path)) - if img_ext == self.caption_ext: + if img_ext == caption_ext: continue try: @@ -436,7 +433,7 @@ class DatasetTagEditor: else: img.close() - text_filename = os.path.join(img_dir, img_filename+self.caption_ext) + text_filename = os.path.join(img_dir, img_filename+caption_ext) caption_text = '' if interrogate_method != InterrogateMethod.OVERWRITE: # from modules/textual_inversion/dataset.py, modified @@ -514,7 +511,7 @@ class DatasetTagEditor: print(f'Loading Completed: {len(self.dataset)} images found') - def save_dataset(self, backup: bool): + def save_dataset(self, backup: bool, caption_ext: str): if len(self.dataset) == 0: return (0, 0, '') @@ -525,7 +522,7 @@ class DatasetTagEditor: img_path, tags = data.imgpath, data.tags img_dir = os.path.dirname(img_path) img_path_noext, _ = os.path.splitext(os.path.basename(img_path)) - txt_path = os.path.join(img_dir, img_path_noext + self.caption_ext) + txt_path = os.path.join(img_dir, img_path_noext + caption_ext) # make backup if backup and os.path.exists(txt_path) and os.path.isfile(txt_path): for extnum in range(1000): diff --git a/scripts/main.py b/scripts/main.py index 0371aaf..859a7c1 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -267,8 +267,8 @@ def update_common_tags(): return [tags, tags] -def save_all_changes(backup: bool): - saved, total, dir = dataset_tag_editor.save_dataset(backup=backup) +def save_all_changes(backup: bool, caption_ext: str): + saved, total, dir = dataset_tag_editor.save_dataset(backup, caption_ext) return f'Saved text files : {saved}/{total} under {dir}' if total > 0 else '' @@ -570,7 +570,7 @@ def on_ui_tabs(): with gr.Column(scale=3): tb_img_directory = gr.Textbox(label='Dataset directory', placeholder='C:\\directory\\of\\datasets', value=cfg_general.dataset_dir) with gr.Column(scale=1, min_width=60): - tb_caption_file_ext = gr.Textbox(label='Caption File Ext', placeholder='txt', value=cfg_general.caption_ext) + tb_caption_file_ext = gr.Textbox(label='Caption File Ext', placeholder='.txt (on Load and Save)', value=cfg_general.caption_ext) with gr.Column(scale=1, min_width=80): btn_load_datasets = gr.Button(value='Load') with gr.Accordion(label='Dataset Load Settings'): @@ -780,7 +780,7 @@ def on_ui_tabs(): btn_save_all_changes.click( fn=save_all_changes, - inputs=[cb_backup], + inputs=[cb_backup, tb_caption_file_ext], outputs=[txt_result] )