diff --git a/scripts/dataset_tag_editor/dataset_tag_editor.py b/scripts/dataset_tag_editor/dataset_tag_editor.py index fab22f7..604a3c4 100644 --- a/scripts/dataset_tag_editor/dataset_tag_editor.py +++ b/scripts/dataset_tag_editor/dataset_tag_editor.py @@ -1,6 +1,5 @@ from pathlib import Path import re -import glob from typing import List, Set, Optional from modules import shared from modules.textual_inversion.dataset import re_numbers_at_start @@ -433,7 +432,7 @@ class DatasetTagEditor: print(e) - def load_dataset(self, img_dir: str, caption_ext:str, recursive: bool, load_caption_from_filename: bool, interrogate_method: InterrogateMethod, interrogator_names: List[str], threshold_booru: float, threshold_waifu: float, use_temp_dir: bool): + def load_dataset(self, img_dir: str, caption_ext:str, recursive: bool, load_caption_from_filename: bool, interrogate_method: InterrogateMethod, interrogator_names: List[str], threshold_booru: float, threshold_waifu: float, use_temp_dir: bool, kohya_json_path:Optional[str]): self.clear() img_dir_obj = Path(img_dir) @@ -454,11 +453,12 @@ class DatasetTagEditor: print(f'[tag-editor] Total {len(filepaths)} files under the directory including not image files.') - def load_images(filepaths: List[Path], captionings: List[captioning.Captioning], taggers: List[tagger.Tagger]): + def load_images(filepaths: List[Path]): + imgpaths = [] + images = {} for img_path in filepaths: if img_path.suffix == caption_ext: continue - try: img = Image.open(img_path) except: @@ -466,8 +466,14 @@ class DatasetTagEditor: else: if not use_temp_dir: img.already_saved_as = str(img_path.absolute()) - self.images[img_path] = img + images[img_path] = img + imgpaths.append(img_path) + return imgpaths, images + + def load_captions(imgpaths: List[Path]): + taglists = [] + for img_path in imgpaths: text_path = img_path.with_suffix(caption_ext) caption_text = '' if interrogate_method != InterrogateMethod.OVERWRITE: @@ -481,24 +487,11 @@ class DatasetTagEditor: tokens = self.re_word.findall(caption_text) caption_text = (shared.opts.dataset_filename_join_string or "").join(tokens) - interrogate_tags = [] caption_tags = [t.strip() for t in caption_text.split(',')] caption_tags = [t for t in caption_tags if t] - if interrogate_method != InterrogateMethod.NONE and ((interrogate_method != InterrogateMethod.PREFILL) or (interrogate_method == InterrogateMethod.PREFILL and not caption_tags)): - img = img.convert('RGB') - for cap in captionings: - interrogate_tags += cap.predict(img) - - for tg, threshold in taggers: - interrogate_tags += [t for t in tg.predict(img, threshold).keys()] - - if interrogate_method == InterrogateMethod.OVERWRITE: - tags = interrogate_tags - elif interrogate_method == InterrogateMethod.PREPEND: - tags = interrogate_tags + caption_tags - else: - tags = caption_tags + interrogate_tags - self.set_tags_by_image_path(img_path, tags) + taglists.append(caption_tags) + + return taglists try: captionings = [] @@ -515,8 +508,34 @@ class DatasetTagEditor: elif isinstance(it, captioning.Captioning): captionings.append(it) - - load_images(filepaths=filepaths, captionings=captionings, taggers=taggers) + if kohya_json_path: + imgpaths, self.images, taglists = kohya_metadata.read(img_dir, kohya_json_path, use_temp_dir) + else: + imgpaths, self.images = load_images(filepaths) + taglists = load_captions(imgpaths) + + for img_path, tags in zip(imgpaths, taglists): + interrogate_tags = [] + img = self.images.get(img_path) + if interrogate_method != InterrogateMethod.NONE and ((interrogate_method != InterrogateMethod.PREFILL) or (interrogate_method == InterrogateMethod.PREFILL and not tags)): + if img is None: + print(f'Failed to load image {img_path}. Interrogating is aborted.') + else: + img = img.convert('RGB') + for cap in captionings: + interrogate_tags += cap.predict(img) + + for tg, threshold in taggers: + interrogate_tags += [t for t in tg.predict(img, threshold).keys()] + + if interrogate_method == InterrogateMethod.OVERWRITE: + tags = interrogate_tags + elif interrogate_method == InterrogateMethod.PREPEND: + tags = interrogate_tags + tags + else: + tags = tags + interrogate_tags + + self.set_tags_by_image_path(img_path, tags) finally: if interrogate_method != InterrogateMethod.NONE: diff --git a/scripts/dataset_tag_editor/kohya-ss_finetune_metadata.py b/scripts/dataset_tag_editor/kohya-ss_finetune_metadata.py index 1fbcf26..44ba8b4 100644 --- a/scripts/dataset_tag_editor/kohya-ss_finetune_metadata.py +++ b/scripts/dataset_tag_editor/kohya-ss_finetune_metadata.py @@ -10,7 +10,9 @@ # on commit hash: ae33d724793e14f16b4c68bdad79f836c86b1b8e import json +from glob import glob from pathlib import Path +from PIL import Image def write(dataset, dataset_dir, out_path, in_path=None, overwrite=False, save_as_caption=False, use_full_path=False): dataset_dir = Path(dataset_dir) @@ -38,4 +40,54 @@ def write(dataset, dataset_dir, out_path, in_path=None, overwrite=False, save_as result[img_key][tags_key] = save_caption with open(out_path, 'w', encoding='utf-8', newline='') as f: - json.dump(result, f, indent=2) \ No newline at end of file + json.dump(result, f, indent=2) + + +def read(dataset_dir, json_path, use_temp_dir:bool): + dataset_dir = Path(dataset_dir) + json_path = Path(json_path) + metadata = json.loads(json_path.read_text('utf8')) + imgpaths = [] + images = {} + taglists = [] + + for image_key, img_md in metadata.items(): + img_path = Path(image_key) + if img_path.is_file(): + try: + img = Image.open(img_path) + except: + continue + else: + abs_path = str(img_path.absolute()) + if not use_temp_dir: + img.already_saved_as = abs_path + images[abs_path] = img + else: + try: + for path in glob(str(dataset_dir.absolute() / (image_key + '.*'))): + img_path = Path(path) + try: + img = Image.open(img_path) + except: + continue + else: + abs_path = str(img_path.absolute()) + if not use_temp_dir: + img.already_saved_as = abs_path + images[abs_path] = img + break + except: + continue + + caption = img_md.get('caption') + tags = img_md.get('tags') + if tags is None: + tags = [] + if caption is not None and isinstance(caption, str): + caption = [s.strip() for s in caption.split(',')] + tags = [s for s in caption if s] + tags + imgpaths.append(abs_path) + taglists.append(tags) + + return imgpaths, images, taglists \ No newline at end of file diff --git a/scripts/main.py b/scripts/main.py index 8d49228..4c7561f 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -211,7 +211,8 @@ def load_files_from_dir( use_custom_threshold_booru: bool, custom_threshold_booru: float, use_custom_threshold_waifu: bool, - custom_threshold_waifu: float + custom_threshold_waifu: float, + kohya_json_path: str ): global total_image_num, displayed_image_num, tmp_selection_img_path_set, gallery_selected_image_path, selection_selected_image_path, path_filter @@ -228,7 +229,7 @@ def load_files_from_dir( threshold_booru = custom_threshold_booru if use_custom_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold threshold_waifu = custom_threshold_waifu if use_custom_threshold_waifu else -1 - dataset_tag_editor.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, interrogate_method, use_interrogator_names, threshold_booru, threshold_waifu, opts.dataset_editor_use_temp_files) + dataset_tag_editor.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, interrogate_method, use_interrogator_names, threshold_booru, threshold_waifu, opts.dataset_editor_use_temp_files, kohya_json_path) imgs = dataset_tag_editor.get_filtered_imgs(filters=[]) img_indices = dataset_tag_editor.get_filtered_imgindices(filters=[]) path_filter = filters.PathFilter() @@ -563,11 +564,11 @@ def on_ui_tabs(): cb_backup = gr.Checkbox(value=cfg_general.backup, label='Backup original text file (original file will be renamed like filename.000, .001, .002, ...)', interactive=True) gr.HTML(value='Note: New text file will be created if you are using filename as captions.') with gr.Row(): - cb_save_kohya_metadata = gr.Checkbox(value=cfg_general.save_kohya_metadata, label="Save kohya-ss's finetuning metadata json", interactive=True) + cb_save_kohya_metadata = gr.Checkbox(value=cfg_general.save_kohya_metadata, label="Use kohya-ss's finetuning metadata json", interactive=True) with gr.Row(): with gr.Column(variant='panel', visible=cfg_general.save_kohya_metadata) as kohya_metadata: - tb_metadata_output = gr.Textbox(label='json output path', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_output_path) - tb_metadata_input = gr.Textbox(label='json input path (Optional)', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_input_path) + tb_metadata_output = gr.Textbox(label='json path', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_output_path) + tb_metadata_input = gr.Textbox(label='json input path (Optional, only for append results)', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_input_path) with gr.Row(): cb_metadata_overwrite = gr.Checkbox(value=cfg_general.meta_overwrite, label="Overwrite if output file exists", interactive=True) cb_metadata_as_caption = gr.Checkbox(value=cfg_general.meta_save_as_caption, label="Save metadata as caption", interactive=True) @@ -809,7 +810,7 @@ def on_ui_tabs(): btn_load_datasets.click( fn=load_files_from_dir, - inputs=[tb_img_directory, tb_caption_file_ext, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, dd_intterogator_names, cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu], + inputs=[tb_img_directory, tb_caption_file_ext, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, dd_intterogator_names, cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu, tb_metadata_output], outputs= [gl_dataset_images, gl_filter_images, txt_gallery, txt_selection] + [cbg_hidden_dataset_filter, nb_hidden_dataset_filter_apply] +