kohya-ss's finetuning metadata json file

pull/41/head
toshiaki1729 2023-02-05 02:43:25 +09:00
parent 5962f4f0bc
commit 3acaaffb6a
3 changed files with 80 additions and 13 deletions

View File

@ -12,6 +12,7 @@ ds = dynamic_import('scripts/dataset_tag_editor/dataset.py')
tagger = dynamic_import('scripts/dataset_tag_editor/tagger.py')
captioning = dynamic_import('scripts/dataset_tag_editor/captioning.py')
filters = dynamic_import('scripts/dataset_tag_editor/filters.py')
kohya_metadata = dynamic_import('scripts/dataset_tag_editor/kohya-ss_finetune_metadata.py')
re_tags = re.compile(r'^(.+) \[\d+\]$')
@ -421,7 +422,7 @@ class DatasetTagEditor:
try:
load_dir = glob.escape(os.path.abspath(img_dir))
filepath_set = [p for p in glob.glob(os.path.join(load_dir, '**'), recursive=recursive) if os.path.isfile(p)]
filepaths = [p for p in glob.glob(os.path.join(load_dir, '**'), recursive=recursive) if os.path.isfile(p)]
except Exception as e:
print(e)
print('[tag-editor] Loading Aborted.')
@ -429,10 +430,10 @@ class DatasetTagEditor:
self.dataset_dir = img_dir
print(f'[tag-editor] Total {len(filepath_set)} files under the directory including not image files.')
print(f'[tag-editor] Total {len(filepaths)} files under the directory including not image files.')
def load_images(filepath_set: Set[str], captionings: List[captioning.Captioning], taggers: List[tagger.Tagger]):
for img_path in filepath_set:
def load_images(filepaths: List[str], captionings: List[captioning.Captioning], taggers: List[tagger.Tagger]):
for img_path in filepaths:
img_dir = os.path.dirname(img_path)
img_filename, img_ext = os.path.splitext(os.path.basename(img_path))
if img_ext == caption_ext:
@ -502,7 +503,7 @@ class DatasetTagEditor:
captionings.append(it)
load_images(filepath_set=filepath_set, captionings=captionings, taggers=taggers)
load_images(filepaths=filepaths, captionings=captionings, taggers=taggers)
finally:
if interrogate_method != InterrogateMethod.NONE:
@ -518,7 +519,7 @@ class DatasetTagEditor:
print(f'[tag-editor] Loading Completed: {len(self.dataset)} images found')
def save_dataset(self, backup: bool, caption_ext: str):
def save_dataset(self, backup: bool, caption_ext: str, write_kohya_metadata: bool, meta_out_path: str, meta_in_path: Optional[str], meta_overwrite:bool, meta_as_caption: bool, meta_full_path: bool):
if len(self.dataset) == 0:
return (0, 0, '')
@ -557,9 +558,12 @@ class DatasetTagEditor:
print(f"[tag-editor] Warning: {txt_path} cannot be saved.")
else:
saved_num += 1
print(f'[tag-editor] Backup text files: {backup_num}/{len(self.dataset)} under {self.dataset_dir}')
print(f'[tag-editor] Saved text files: {saved_num}/{len(self.dataset)} under {self.dataset_dir}')
if(write_kohya_metadata):
kohya_metadata.write(dataset=self.dataset, dataset_dir=self.dataset_dir, out_path=meta_out_path, in_path=meta_in_path, overwrite=meta_overwrite, save_as_caption=meta_as_caption, use_full_path=meta_full_path)
print(f'[tag-editor] Saved json metadata file in {meta_out_path}')
return (saved_num, len(self.dataset), self.dataset_dir)

View File

@ -0,0 +1,34 @@
import json
from pathlib import Path
# implement metadata output compatible to kohya-ss's finetuning captions
# https://github.com/kohya-ss/sd-scripts/blob/main/finetune/merge_captions_to_metadata.py
# https://github.com/kohya-ss/sd-scripts/blob/main/finetune/merge_dd_tags_to_metadata.py
# on commit hash: ae33d724793e14f16b4c68bdad79f836c86b1b8e
def write(dataset, dataset_dir, out_path, in_path=None, overwrite=False, save_as_caption=False, use_full_path=False):
dataset_dir = Path(dataset_dir)
if in_path is None and Path(out_path).is_file() and not overwrite:
in_path = out_path
result = {}
if in_path is not None:
try:
result = json.loads(Path(in_path).read_text(encoding='utf-8'))
except:
result = {}
tags_key = 'caption' if save_as_caption else 'tags'
for data in dataset.datas.values():
img_path, tags = Path(data.imgpath), data.tags
img_key = str(img_path) if use_full_path else img_path.stem
save_caption = ', '.join(tags) if save_as_caption else tags
if img_key not in result:
result[img_key] = {}
result[img_key][tags_key] = save_caption
with open(out_path, 'w', encoding='utf-8', newline='') as f:
json.dump(result, f, indent=2)

View File

@ -48,14 +48,20 @@ GeneralConfig = namedtuple('GeneralConfig', [
'use_custom_threshold_booru',
'custom_threshold_booru',
'use_custom_threshold_waifu',
'custom_threshold_waifu'
'custom_threshold_waifu',
'save_kohya_metadata',
'meta_output_path',
'meta_input_path',
'meta_overwrite',
'meta_save_as_caption',
'meta_use_full_path'
])
FilterConfig = namedtuple('FilterConfig', ['sort_by', 'sort_order', 'logic'])
BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sory_by', 'sort_order'])
EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'warn_change_not_saved', 'use_interrogator_name'])
MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination'])
CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, 'No', [], False, 0.7, False, 0.5)
CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, 'No', [], False, 0.7, False, 0.5, False, '', '', True, False, False)
CFG_FILTER_P_DEFAULT = FilterConfig('Alphabetical Order', 'Ascending', 'AND')
CFG_FILTER_N_DEFAULT = FilterConfig('Alphabetical Order', 'Ascending', 'OR')
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', 'Alphabetical Order', 'Ascending')
@ -278,8 +284,10 @@ def update_common_tags():
return [tags, tags]
def save_all_changes(backup: bool, caption_ext: str):
saved, total, dir = dataset_tag_editor.save_dataset(backup, caption_ext)
def save_all_changes(backup: bool, caption_ext: str, save_kohya_metadata:bool, metadata_output:str, metadata_input:str, metadata_overwrite:bool, metadata_as_caption:bool, metadata_use_fullpath:bool):
if not metadata_input:
metadata_input = None
saved, total, dir = dataset_tag_editor.save_dataset(backup, caption_ext, save_kohya_metadata, metadata_output, metadata_input, metadata_overwrite, metadata_as_caption, metadata_use_fullpath)
return f'Saved text files : {saved}/{total} under {dir}' if total > 0 else ''
@ -549,6 +557,16 @@ def on_ui_tabs():
with gr.Column(scale=2):
cb_backup = gr.Checkbox(value=cfg_general.backup, label='Backup original text file (original file will be renamed like filename.000, .001, .002, ...)', interactive=True)
gr.HTML(value='<b>Note:</b> New text file will be created if you are using filename as captions.')
with gr.Row():
cb_save_kohya_metadata = gr.Checkbox(value=cfg_general.save_kohya_metadata, label="Save kohya-ss's finetuning metadata json", interactive=True)
with gr.Row():
with gr.Column(variant='panel', visible=cfg_general.save_kohya_metadata) as kohya_metadata:
tb_metadata_output = gr.Textbox(label='json output path', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_output_path)
tb_metadata_input = gr.Textbox(label='json input path (Optional)', placeholder='C:\\path\\to\\metadata.json',value=cfg_general.meta_input_path)
with gr.Row():
cb_metadata_overwrite = gr.Checkbox(value=cfg_general.meta_overwrite, label="Overwrite if output file exists", interactive=True)
cb_metadata_as_caption = gr.Checkbox(value=cfg_general.meta_save_as_caption, label="Save metadata as caption", interactive=True)
cb_metadata_use_fullpath = gr.Checkbox(value=cfg_general.meta_use_full_path, label="Save metadata image key as fullpath", interactive=True)
with gr.Row(visible=False):
txt_result = gr.Textbox(label='Results', interactive=False)
@ -713,7 +731,12 @@ def on_ui_tabs():
#----------------------------------------------------------------
# General
components_general = [cb_backup, tb_img_directory, tb_caption_file_ext, cb_load_recursive, cb_load_caption_from_filename, rb_use_interrogator, dd_intterogator_names, cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu]
components_general = [
cb_backup, tb_img_directory, tb_caption_file_ext, cb_load_recursive,
cb_load_caption_from_filename, rb_use_interrogator, dd_intterogator_names,
cb_use_custom_threshold_booru, sl_custom_threshold_booru, cb_use_custom_threshold_waifu, sl_custom_threshold_waifu,
cb_save_kohya_metadata, tb_metadata_output, tb_metadata_input, cb_metadata_overwrite, cb_metadata_as_caption, cb_metadata_use_fullpath
]
components_filter = [tag_filter_ui.rb_sort_by, tag_filter_ui.rb_sort_order, tag_filter_ui.rb_logic, tag_filter_ui_neg.rb_sort_by, tag_filter_ui_neg.rb_sort_order, tag_filter_ui_neg.rb_logic]
components_batch_edit = [cb_show_only_tags_selected, cb_prepend_tags, cb_use_regex, rb_sr_replace_target, tag_select_ui_remove.rb_sort_by, tag_select_ui_remove.rb_sort_order]
components_edit_selected = [cb_copy_caption_automatically, cb_ask_save_when_caption_changed, dd_intterogator_names_si]
@ -768,7 +791,7 @@ def on_ui_tabs():
btn_save_all_changes.click(
fn=save_all_changes,
inputs=[cb_backup, tb_caption_file_ext],
inputs=[cb_backup, tb_caption_file_ext, cb_save_kohya_metadata, tb_metadata_output, tb_metadata_input, cb_metadata_overwrite, cb_metadata_as_caption, cb_metadata_use_fullpath],
outputs=[txt_result]
)
@ -785,6 +808,12 @@ def on_ui_tabs():
outputs=[tb_common_tags, tb_edit_tags, tb_caption_selected_image, nb_hidden_image_index, nb_hidden_image_index_prev, nb_hidden_image_index_save_or_not]
)
cb_save_kohya_metadata.change(
fn=lambda x:gr.update(visible=x),
inputs=cb_save_kohya_metadata,
outputs=kohya_metadata
)
#----------------------------------------------------------------
# Filter by Tags tab