Accept kohya-ss's metadata json as both input and output
parent
9db35c0a45
commit
a646ebd065
|
|
@ -1,6 +1,5 @@
|
|||
from pathlib import Path
|
||||
import re
|
||||
import glob
|
||||
from typing import List, Set, Optional
|
||||
from modules import shared
|
||||
from modules.textual_inversion.dataset import re_numbers_at_start
|
||||
|
|
@ -433,7 +432,7 @@ class DatasetTagEditor:
|
|||
print(e)
|
||||
|
||||
|
||||
def load_dataset(self, img_dir: str, caption_ext:str, recursive: bool, load_caption_from_filename: bool, interrogate_method: InterrogateMethod, interrogator_names: List[str], threshold_booru: float, threshold_waifu: float, use_temp_dir: bool):
|
||||
def load_dataset(self, img_dir: str, caption_ext:str, recursive: bool, load_caption_from_filename: bool, interrogate_method: InterrogateMethod, interrogator_names: List[str], threshold_booru: float, threshold_waifu: float, use_temp_dir: bool, kohya_json_path:Optional[str]):
|
||||
self.clear()
|
||||
|
||||
img_dir_obj = Path(img_dir)
|
||||
|
|
@ -454,11 +453,12 @@ class DatasetTagEditor:
|
|||
|
||||
print(f'[tag-editor] Total {len(filepaths)} files under the directory including not image files.')
|
||||
|
||||
def load_images(filepaths: List[Path], captionings: List[captioning.Captioning], taggers: List[tagger.Tagger]):
|
||||
def load_images(filepaths: List[Path]):
|
||||
imgpaths = []
|
||||
images = {}
|
||||
for img_path in filepaths:
|
||||
if img_path.suffix == caption_ext:
|
||||
continue
|
||||
|
||||
try:
|
||||
img = Image.open(img_path)
|
||||
except:
|
||||
|
|
@ -466,8 +466,14 @@ class DatasetTagEditor:
|
|||
else:
|
||||
if not use_temp_dir:
|
||||
img.already_saved_as = str(img_path.absolute())
|
||||
self.images[img_path] = img
|
||||
images[img_path] = img
|
||||
|
||||
imgpaths.append(img_path)
|
||||
return imgpaths, images
|
||||
|
||||
def load_captions(imgpaths: List[Path]):
|
||||
taglists = []
|
||||
for img_path in imgpaths:
|
||||
text_path = img_path.with_suffix(caption_ext)
|
||||
caption_text = ''
|
||||
if interrogate_method != InterrogateMethod.OVERWRITE:
|
||||
|
|
@ -481,24 +487,11 @@ class DatasetTagEditor:
|
|||
tokens = self.re_word.findall(caption_text)
|
||||
caption_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
|
||||
|
||||
interrogate_tags = []
|
||||
caption_tags = [t.strip() for t in caption_text.split(',')]
|
||||
caption_tags = [t for t in caption_tags if t]
|
||||
if interrogate_method != InterrogateMethod.NONE and ((interrogate_method != InterrogateMethod.PREFILL) or (interrogate_method == InterrogateMethod.PREFILL and not caption_tags)):
|
||||
img = img.convert('RGB')
|
||||
for cap in captionings:
|
||||
interrogate_tags += cap.predict(img)
|
||||
|
||||
for tg, threshold in taggers:
|
||||
interrogate_tags += [t for t in tg.predict(img, threshold).keys()]
|
||||
|
||||
if interrogate_method == InterrogateMethod.OVERWRITE:
|
||||
tags = interrogate_tags
|
||||
elif interrogate_method == InterrogateMethod.PREPEND:
|
||||
tags = interrogate_tags + caption_tags
|
||||
else:
|
||||
tags = caption_tags + interrogate_tags
|
||||
self.set_tags_by_image_path(img_path, tags)
|
||||
taglists.append(caption_tags)
|
||||
|
||||
return taglists
|
||||
|
||||
try:
|
||||
captionings = []
|
||||
|
|
@ -515,8 +508,34 @@ class DatasetTagEditor:
|
|||
elif isinstance(it, captioning.Captioning):
|
||||
captionings.append(it)
|
||||
|
||||
|
||||
load_images(filepaths=filepaths, captionings=captionings, taggers=taggers)
|
||||
if kohya_json_path:
|
||||
imgpaths, self.images, taglists = kohya_metadata.read(img_dir, kohya_json_path, use_temp_dir)
|
||||
else:
|
||||
imgpaths, self.images = load_images(filepaths)
|
||||
taglists = load_captions(imgpaths)
|
||||
|
||||
for img_path, tags in zip(imgpaths, taglists):
|
||||
interrogate_tags = []
|
||||
img = self.images.get(img_path)
|
||||
if interrogate_method != InterrogateMethod.NONE and ((interrogate_method != InterrogateMethod.PREFILL) or (interrogate_method == InterrogateMethod.PREFILL and not tags)):
|
||||
if img is None:
|
||||
print(f'Failed to load image {img_path}. Interrogating is aborted.')
|
||||
else:
|
||||
img = img.convert('RGB')
|
||||
for cap in captionings:
|
||||
interrogate_tags += cap.predict(img)
|
||||
|
||||
for tg, threshold in taggers:
|
||||
interrogate_tags += [t for t in tg.predict(img, threshold).keys()]
|
||||
|
||||
if interrogate_method == InterrogateMethod.OVERWRITE:
|
||||
tags = interrogate_tags
|
||||
elif interrogate_method == InterrogateMethod.PREPEND:
|
||||
tags = interrogate_tags + tags
|
||||
else:
|
||||
tags = tags + interrogate_tags
|
||||
|
||||
self.set_tags_by_image_path(img_path, tags)
|
||||
|
||||
finally:
|
||||
if interrogate_method != InterrogateMethod.NONE:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@
|
|||
# on commit hash: ae33d724793e14f16b4c68bdad79f836c86b1b8e
|
||||
|
||||
import json
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
def write(dataset, dataset_dir, out_path, in_path=None, overwrite=False, save_as_caption=False, use_full_path=False):
|
||||
dataset_dir = Path(dataset_dir)
|
||||
|
|
@ -38,4 +40,54 @@ def write(dataset, dataset_dir, out_path, in_path=None, overwrite=False, save_as
|
|||
result[img_key][tags_key] = save_caption
|
||||
|
||||
with open(out_path, 'w', encoding='utf-8', newline='') as f:
|
||||
json.dump(result, f, indent=2)
|
||||
json.dump(result, f, indent=2)
|
||||
|
||||
|
||||
def read(dataset_dir, json_path, use_temp_dir:bool):
|
||||
dataset_dir = Path(dataset_dir)
|
||||
json_path = Path(json_path)
|
||||
metadata = json.loads(json_path.read_text('utf8'))
|
||||
imgpaths = []
|
||||
images = {}
|
||||
taglists = []
|
||||
|
||||
for image_key, img_md in metadata.items():
|
||||
img_path = Path(image_key)
|
||||
if img_path.is_file():
|
||||
try:
|
||||
img = Image.open(img_path)
|
||||
except:
|
||||
continue
|
||||
else:
|
||||
abs_path = str(img_path.absolute())
|
||||
if not use_temp_dir:
|
||||
img.already_saved_as = abs_path
|
||||
images[abs_path] = img
|
||||
else:
|
||||
try:
|
||||
for path in glob(str(dataset_dir.absolute() / (image_key + '.*'))):
|
||||
img_path = Path(path)
|
||||
try:
|
||||
img = Image.open(img_path)
|
||||
except:
|
||||
continue
|
||||
else:
|
||||
abs_path = str(img_path.absolute())
|
||||
if not use_temp_dir:
|
||||
img.already_saved_as = abs_path
|
||||
images[abs_path] = img
|
||||
break
|
||||
except:
|
||||
continue
|
||||
|
||||
caption = img_md.get('caption')
|
||||
tags = img_md.get('tags')
|
||||
if tags is None:
|
||||
tags = []
|
||||
if caption is not None and isinstance(caption, str):
|
||||
caption = [s.strip() for s in caption.split(',')]
|
||||
tags = [s for s in caption if s] + tags
|
||||
imgpaths.append(abs_path)
|
||||
taglists.append(tags)
|
||||
|
||||
return imgpaths, images, taglists
|
||||
|
|
@ -211,7 +211,8 @@ def load_files_from_dir(
|
|||
use_custom_threshold_booru: bool,
|
||||
custom_threshold_booru: float,
|
||||
use_custom_threshold_waifu: bool,
|
||||
custom_threshold_waifu: float
|
||||
custom_threshold_waifu: float,
|
||||
kohya_json_path: str
|
||||
):
|
||||
global total_image_num, displayed_image_num, tmp_selection_img_path_set, gallery_selected_image_path, selection_selected_image_path, path_filter
|
||||
|
||||
|
|
@ -228,7 +229,7 @@ def load_files_from_dir(
|
|||
threshold_booru = custom_threshold_booru if use_custom_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold
|
||||
threshold_waifu = custom_threshold_waifu if use_custom_threshold_waifu else -1
|
||||
|
||||
dataset_tag_editor.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, interrogate_method, use_interrogator_names, threshold_booru, threshold_waifu, opts.dataset_editor_use_temp_files)
|
||||
dataset_tag_editor.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, interrogate_method, use_interrogator_names, threshold_booru, threshold_waifu, opts.dataset_editor_use_temp_files, kohya_json_path)
|
||||
imgs = dataset_tag_editor.get_filtered_imgs(filters=[])
|
||||
img_indices = dataset_tag_editor.get_filtered_imgindices(filters=[])
|
||||
path_filter = filters.PathFilter()
|
||||
|
|
@ -563,11 +564,11 @@ def on_ui_tabs():
|
|||
cb_backup = gr.Checkbox(value=cfg_general.backup, label='Backup original text file (original file will be renamed like filename.000, .001, .002, ...)', interactive=True)
|
||||
gr.HTML(value='<b>Note:</b> New text file will be created if you are using filename as captions.')
|
||||
with gr.Row():
|
||||
cb_save_kohya_metadata = gr.Checkbox(value=cfg_general.save_kohya_metadata, label="Save kohya-ss's finetuning metadata json", interactive=True)
|
||||
cb_save_kohya_metadata = gr.Checkbox(value=cfg_general.save_kohya_metadata, label="Use kohya-ss's finetuning metadata json", interactive=True)
|
||||
with gr.Row():
|
||||
with gr.Column(variant='panel', visible=cfg_general.save_kohya_metadata) as kohya_metadata:
|
||||
tb_metadata_output = gr.Textbox(label='json output path', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_output_path)
|
||||
tb_metadata_input = gr.Textbox(label='json input path (Optional)', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_input_path)
|
||||
tb_metadata_output = gr.Textbox(label='json path', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_output_path)
|
||||
tb_metadata_input = gr.Textbox(label='json input path (Optional, only for append results)', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_input_path)
|
||||
with gr.Row():
|
||||
cb_metadata_overwrite = gr.Checkbox(value=cfg_general.meta_overwrite, label="Overwrite if output file exists", interactive=True)
|
||||
cb_metadata_as_caption = gr.Checkbox(value=cfg_general.meta_save_as_caption, label="Save metadata as caption", interactive=True)
|
||||
|
|
@ -809,7 +810,7 @@ def on_ui_tabs():
|
|||
|
||||
btn_load_datasets.click(
|
||||
fn=load_files_from_dir,
|
||||
inputs=[tb_img_directory, tb_caption_file_ext, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, dd_intterogator_names, cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu],
|
||||
inputs=[tb_img_directory, tb_caption_file_ext, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, dd_intterogator_names, cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu, tb_metadata_output],
|
||||
outputs=
|
||||
[gl_dataset_images, gl_filter_images, txt_gallery, txt_selection] +
|
||||
[cbg_hidden_dataset_filter, nb_hidden_dataset_filter_apply] +
|
||||
|
|
|
|||
Loading…
Reference in New Issue