Accept kohya-ss's metadata json as both input and output

pull/48/head
toshiaki1729 2023-02-26 14:14:15 +09:00
parent 9db35c0a45
commit a646ebd065
3 changed files with 102 additions and 30 deletions

View File

@ -1,6 +1,5 @@
from pathlib import Path
import re
import glob
from typing import List, Set, Optional
from modules import shared
from modules.textual_inversion.dataset import re_numbers_at_start
@ -433,7 +432,7 @@ class DatasetTagEditor:
print(e)
def load_dataset(self, img_dir: str, caption_ext:str, recursive: bool, load_caption_from_filename: bool, interrogate_method: InterrogateMethod, interrogator_names: List[str], threshold_booru: float, threshold_waifu: float, use_temp_dir: bool):
def load_dataset(self, img_dir: str, caption_ext:str, recursive: bool, load_caption_from_filename: bool, interrogate_method: InterrogateMethod, interrogator_names: List[str], threshold_booru: float, threshold_waifu: float, use_temp_dir: bool, kohya_json_path:Optional[str]):
self.clear()
img_dir_obj = Path(img_dir)
@ -454,11 +453,12 @@ class DatasetTagEditor:
print(f'[tag-editor] Total {len(filepaths)} files under the directory including not image files.')
def load_images(filepaths: List[Path], captionings: List[captioning.Captioning], taggers: List[tagger.Tagger]):
def load_images(filepaths: List[Path]):
imgpaths = []
images = {}
for img_path in filepaths:
if img_path.suffix == caption_ext:
continue
try:
img = Image.open(img_path)
except:
@ -466,8 +466,14 @@ class DatasetTagEditor:
else:
if not use_temp_dir:
img.already_saved_as = str(img_path.absolute())
self.images[img_path] = img
images[img_path] = img
imgpaths.append(img_path)
return imgpaths, images
def load_captions(imgpaths: List[Path]):
taglists = []
for img_path in imgpaths:
text_path = img_path.with_suffix(caption_ext)
caption_text = ''
if interrogate_method != InterrogateMethod.OVERWRITE:
@ -481,24 +487,11 @@ class DatasetTagEditor:
tokens = self.re_word.findall(caption_text)
caption_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
interrogate_tags = []
caption_tags = [t.strip() for t in caption_text.split(',')]
caption_tags = [t for t in caption_tags if t]
if interrogate_method != InterrogateMethod.NONE and ((interrogate_method != InterrogateMethod.PREFILL) or (interrogate_method == InterrogateMethod.PREFILL and not caption_tags)):
img = img.convert('RGB')
for cap in captionings:
interrogate_tags += cap.predict(img)
for tg, threshold in taggers:
interrogate_tags += [t for t in tg.predict(img, threshold).keys()]
if interrogate_method == InterrogateMethod.OVERWRITE:
tags = interrogate_tags
elif interrogate_method == InterrogateMethod.PREPEND:
tags = interrogate_tags + caption_tags
else:
tags = caption_tags + interrogate_tags
self.set_tags_by_image_path(img_path, tags)
taglists.append(caption_tags)
return taglists
try:
captionings = []
@ -515,8 +508,34 @@ class DatasetTagEditor:
elif isinstance(it, captioning.Captioning):
captionings.append(it)
load_images(filepaths=filepaths, captionings=captionings, taggers=taggers)
if kohya_json_path:
imgpaths, self.images, taglists = kohya_metadata.read(img_dir, kohya_json_path, use_temp_dir)
else:
imgpaths, self.images = load_images(filepaths)
taglists = load_captions(imgpaths)
for img_path, tags in zip(imgpaths, taglists):
interrogate_tags = []
img = self.images.get(img_path)
if interrogate_method != InterrogateMethod.NONE and ((interrogate_method != InterrogateMethod.PREFILL) or (interrogate_method == InterrogateMethod.PREFILL and not tags)):
if img is None:
print(f'Failed to load image {img_path}. Interrogating is aborted.')
else:
img = img.convert('RGB')
for cap in captionings:
interrogate_tags += cap.predict(img)
for tg, threshold in taggers:
interrogate_tags += [t for t in tg.predict(img, threshold).keys()]
if interrogate_method == InterrogateMethod.OVERWRITE:
tags = interrogate_tags
elif interrogate_method == InterrogateMethod.PREPEND:
tags = interrogate_tags + tags
else:
tags = tags + interrogate_tags
self.set_tags_by_image_path(img_path, tags)
finally:
if interrogate_method != InterrogateMethod.NONE:

View File

@ -10,7 +10,9 @@
# on commit hash: ae33d724793e14f16b4c68bdad79f836c86b1b8e
import json
from glob import glob
from pathlib import Path
from PIL import Image
def write(dataset, dataset_dir, out_path, in_path=None, overwrite=False, save_as_caption=False, use_full_path=False):
dataset_dir = Path(dataset_dir)
@ -38,4 +40,54 @@ def write(dataset, dataset_dir, out_path, in_path=None, overwrite=False, save_as
result[img_key][tags_key] = save_caption
with open(out_path, 'w', encoding='utf-8', newline='') as f:
json.dump(result, f, indent=2)
json.dump(result, f, indent=2)
def read(dataset_dir, json_path, use_temp_dir: bool):
    """Load images and per-image tag lists from a kohya-ss finetuning metadata json.

    Args:
        dataset_dir: Directory searched for images when a metadata key is a
            bare stem rather than a usable file path.
        json_path: Path to the metadata json, shaped like
            ``{image_key: {"tags": [...], "caption": "..."}}``.
        use_temp_dir: When False, tag each loaded image with its on-disk
            absolute path via ``already_saved_as`` so the UI can serve the
            original file instead of a temp copy.

    Returns:
        Tuple ``(imgpaths, images, taglists)``: a list of absolute path
        strings, a dict mapping those paths to opened PIL images, and one
        tag list per entry, index-aligned with ``imgpaths``. Entries whose
        image cannot be located or opened are skipped entirely.
    """
    dataset_dir = Path(dataset_dir)
    json_path = Path(json_path)
    metadata = json.loads(json_path.read_text('utf8'))

    imgpaths = []
    images = {}
    taglists = []

    def _open_image(path):
        # Best-effort open; returns (abs_path, img) or None for unreadable files.
        try:
            img = Image.open(path)
        except Exception:
            return None
        abs_path = str(path.absolute())
        if not use_temp_dir:
            img.already_saved_as = abs_path
        return abs_path, img

    for image_key, img_md in metadata.items():
        loaded = None
        key_path = Path(image_key)
        if key_path.is_file():
            # Key is already a usable path (kohya "full path" metadata).
            loaded = _open_image(key_path)
        else:
            # Key is a bare stem: search the dataset dir for any extension.
            try:
                candidates = glob(str(dataset_dir.absolute() / (image_key + '.*')))
            except Exception:
                candidates = []
            for candidate in candidates:
                loaded = _open_image(Path(candidate))
                if loaded is not None:
                    break

        if loaded is None:
            # BUGFIX: the original fell through here with `abs_path` unbound
            # (first entry) or stale from the previous entry, so a missing
            # image raised NameError or attached this entry's tags to the
            # wrong image path. Skip unresolvable entries instead.
            continue

        abs_path, img = loaded
        images[abs_path] = img

        tags = img_md.get('tags')
        if tags is None:
            tags = []
        caption = img_md.get('caption')
        if isinstance(caption, str):
            # Captions are comma-separated; non-empty pieces are prepended
            # to the explicit tag list.
            caption_tags = [s.strip() for s in caption.split(',')]
            tags = [s for s in caption_tags if s] + tags

        imgpaths.append(abs_path)
        taglists.append(tags)

    return imgpaths, images, taglists

View File

@ -211,7 +211,8 @@ def load_files_from_dir(
use_custom_threshold_booru: bool,
custom_threshold_booru: float,
use_custom_threshold_waifu: bool,
custom_threshold_waifu: float
custom_threshold_waifu: float,
kohya_json_path: str
):
global total_image_num, displayed_image_num, tmp_selection_img_path_set, gallery_selected_image_path, selection_selected_image_path, path_filter
@ -228,7 +229,7 @@ def load_files_from_dir(
threshold_booru = custom_threshold_booru if use_custom_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold
threshold_waifu = custom_threshold_waifu if use_custom_threshold_waifu else -1
dataset_tag_editor.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, interrogate_method, use_interrogator_names, threshold_booru, threshold_waifu, opts.dataset_editor_use_temp_files)
dataset_tag_editor.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, interrogate_method, use_interrogator_names, threshold_booru, threshold_waifu, opts.dataset_editor_use_temp_files, kohya_json_path)
imgs = dataset_tag_editor.get_filtered_imgs(filters=[])
img_indices = dataset_tag_editor.get_filtered_imgindices(filters=[])
path_filter = filters.PathFilter()
@ -563,11 +564,11 @@ def on_ui_tabs():
cb_backup = gr.Checkbox(value=cfg_general.backup, label='Backup original text file (original file will be renamed like filename.000, .001, .002, ...)', interactive=True)
gr.HTML(value='<b>Note:</b> New text file will be created if you are using filename as captions.')
with gr.Row():
cb_save_kohya_metadata = gr.Checkbox(value=cfg_general.save_kohya_metadata, label="Save kohya-ss's finetuning metadata json", interactive=True)
cb_save_kohya_metadata = gr.Checkbox(value=cfg_general.save_kohya_metadata, label="Use kohya-ss's finetuning metadata json", interactive=True)
with gr.Row():
with gr.Column(variant='panel', visible=cfg_general.save_kohya_metadata) as kohya_metadata:
tb_metadata_output = gr.Textbox(label='json output path', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_output_path)
tb_metadata_input = gr.Textbox(label='json input path (Optional)', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_input_path)
tb_metadata_output = gr.Textbox(label='json path', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_output_path)
tb_metadata_input = gr.Textbox(label='json input path (Optional, only for append results)', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_input_path)
with gr.Row():
cb_metadata_overwrite = gr.Checkbox(value=cfg_general.meta_overwrite, label="Overwrite if output file exists", interactive=True)
cb_metadata_as_caption = gr.Checkbox(value=cfg_general.meta_save_as_caption, label="Save metadata as caption", interactive=True)
@ -809,7 +810,7 @@ def on_ui_tabs():
btn_load_datasets.click(
fn=load_files_from_dir,
inputs=[tb_img_directory, tb_caption_file_ext, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, dd_intterogator_names, cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu],
inputs=[tb_img_directory, tb_caption_file_ext, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, dd_intterogator_names, cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu, tb_metadata_output],
outputs=
[gl_dataset_images, gl_filter_images, txt_gallery, txt_selection] +
[cbg_hidden_dataset_filter, nb_hidden_dataset_filter_apply] +