diff --git a/config example.toml b/config example.toml index 363e303..4b48e3e 100644 --- a/config example.toml +++ b/config example.toml @@ -135,3 +135,27 @@ sample_sampler = "euler_a" # Sampler to use for image sampling [sdxl] sdxl_cache_text_encoder_outputs = false # Cache text encoder outputs sdxl_no_half_vae = true # No half VAE + +[wd14_caption] +always_first_tags = "" # comma-separated list of tags to always put at the beginning, e.g. 1girl,1boy +append_tags = false # Append TAGs +batch_size = 8 # Batch size +caption_extension = ".txt" # Extension for caption file (e.g., .caption, .txt) +caption_separator = ", " # Caption Separator +character_tag_expand = false # Expand tag tail parenthesis to another tag for character tags. `chara_name_(series)` becomes `chara_name, series` +character_threshold = 0.35 # Character threshold +debug = false # Debug mode +frequency_tags = false # Frequency tags +force_download = false # Force model re-download when switching to onnx +general_threshold = 0.35 # General threshold +max_data_loader_n_workers = 2 # Max dataloader workers +onnx = true # ONNX +recursive = false # Recursive +remove_underscore = false # Remove underscore +repo_id = "SmilingWolf/wd-convnext-tagger-v3" # Repo id for wd14 tagger on Hugging Face +tag_replacement = "" # Tag replacement in the format of `source1,target1;source2,target2; ...`. Escape `,` and `;` with `\`. e.g. `tag1,tag2;tag3,tag4` +thresh = 0.36 # Threshold +train_data_dir = "" # Image folder to caption (containing the images to caption) +undesired_tags = "" # comma-separated list of tags to remove, e.g. 1girl,1boy +use_rating_tags = false # Use rating tags +use_rating_tags_as_last_tag = false # Use rating tags as last tagging tags diff --git a/kohya_gui.py b/kohya_gui.py index cb98ef9..1b77b46 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -64,6 +64,7 @@ def UI(**kwargs): logging_dir_input=logging_dir_input, enable_copy_info_button=True, headless=headless, + config=config, ) with gr.Tab("LoRA"): _ = LoRATools(headless=headless) diff --git a/kohya_gui/utilities.py b/kohya_gui/utilities.py index 8592277..408ce15 100644 --- a/kohya_gui/utilities.py +++ b/kohya_gui/utilities.py @@ -18,13 +18,14 @@ def utilities_tab( enable_copy_info_button=bool(False), enable_dreambooth_tab=True, headless=False, + config: dict = {}, ): with gr.Tab("Captioning"): gradio_basic_caption_gui_tab(headless=headless) gradio_blip_caption_gui_tab(headless=headless) gradio_blip2_caption_gui_tab(headless=headless) gradio_git_caption_gui_tab(headless=headless) - gradio_wd14_caption_gui_tab(headless=headless) + gradio_wd14_caption_gui_tab(headless=headless, config=config) gradio_manual_caption_gui_tab(headless=headless) gradio_convert_model_tab(headless=headless) gradio_group_images_gui_tab(headless=headless) diff --git a/kohya_gui/wd14_caption_gui.py b/kohya_gui/wd14_caption_gui.py index 07eae85..7369bee 100644 --- a/kohya_gui/wd14_caption_gui.py +++ b/kohya_gui/wd14_caption_gui.py @@ -2,6 +2,7 @@ import gradio as gr from easygui import msgbox import subprocess from .common_gui import get_folder_path, scriptdir, list_dirs +from .class_gui_config import KohyaSSGUIConfig import os from .custom_logging import setup_logging @@ -30,7 +31,7 @@ def caption_images( tag_replacement: bool, character_tag_expand: str, use_rating_tags: bool, - use_ratuse_rating_tags_as_last_taging_tags: bool, + use_rating_tags_as_last_tag: bool, remove_underscore: bool, thresh: float, ) -> None: @@ -53,15 +54,17 @@ def caption_images( run_cmd += f' --caption_extension="{caption_extension}"' run_cmd += f' --caption_separator="{caption_separator}"' if character_tag_expand: - run_cmd += f' --character_tag_expand="{character_tag_expand}"' - run_cmd += f" --character_threshold={character_threshold}" + run_cmd += f" --character_tag_expand" + if not character_threshold == 0.35: + run_cmd += f" --character_threshold={character_threshold}" if debug: run_cmd += f" --debug" if force_download: run_cmd += f" --force_download" if frequency_tags: run_cmd += f" --frequency_tags" - run_cmd += f" --general_threshold={general_threshold}" + if not general_threshold == 0.35: + run_cmd += f" --general_threshold={general_threshold}" run_cmd += f' --max_data_loader_n_workers="{int(max_data_loader_n_workers)}"' if onnx: run_cmd += f" --onnx" @@ -72,14 +75,14 @@ def caption_images( run_cmd += f' --repo_id="{repo_id}"' if tag_replacement: run_cmd += f" --tag_replacement" - if thresh: + if not thresh == 0.35: run_cmd += f" --thresh={thresh}" if not undesired_tags == "": run_cmd += f' --undesired_tags="{undesired_tags}"' if use_rating_tags: run_cmd += f" --use_rating_tags" - if use_ratuse_rating_tags_as_last_taging_tags: - run_cmd += f" --use_ratuse_rating_tags_as_last_taging_tags" + if use_rating_tags_as_last_tag: + run_cmd += f" --use_rating_tags_as_last_tag" run_cmd += rf' "{train_data_dir}"' log.info(run_cmd) @@ -101,7 +104,9 @@ def caption_images( ### -def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None): +def gradio_wd14_caption_gui_tab( + headless=False, default_train_dir=None, config: KohyaSSGUIConfig = {} +): from .common_gui import create_refresh_button default_train_dir = ( @@ -126,8 +131,9 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None): with gr.Group(), gr.Row(): train_data_dir = gr.Dropdown( label="Image folder to caption (containing the images to caption)", - choices=[""] + list_train_dirs(default_train_dir), - value="", + choices=[config.get("wd14_caption.train_data_dir", "")] + + list_train_dirs(default_train_dir), + value=config.get("wd14_caption.train_data_dir", ""), interactive=True, allow_custom_value=True, ) @@ -148,7 +154,7 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None): outputs=train_data_dir, show_progress=False, ) - + repo_id = gr.Dropdown( label="Repo ID", choices=[ @@ -157,48 +163,50 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None): "SmilingWolf/wd-v1-4-vit-tagger-v2", "SmilingWolf/wd-v1-4-swinv2-tagger-v2", "SmilingWolf/wd-v1-4-moat-tagger-v2", - 'SmilingWolf/wd-swinv2-tagger-v3', - 'SmilingWolf/wd-vit-tagger-v3', - 'SmilingWolf/wd-convnext-tagger-v3', + "SmilingWolf/wd-swinv2-tagger-v3", + "SmilingWolf/wd-vit-tagger-v3", + "SmilingWolf/wd-convnext-tagger-v3", ], - value="SmilingWolf/wd-v1-4-convnextv2-tagger-v2", + value=config.get( + "wd14_caption.repo_id", "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" + ), show_label="Repo id for wd14 tagger on Hugging Face", ) force_download = gr.Checkbox( label="Force model re-download", - value=False, + value=config.get("wd14_caption.force_download", False), info="Useful to force model re download when switching to onnx", ) - + with gr.Row(): caption_extension = gr.Textbox( label="Caption file extension", placeholder="Extension for caption file (e.g., .caption, .txt)", - value=".txt", + value=config.get("wd14_caption.caption_extension", ".txt"), interactive=True, ) caption_separator = gr.Textbox( label="Caption Separator", - value=",", + value=config.get("wd14_caption.caption_separator", ", "), interactive=True, ) - + with gr.Row(): - + tag_replacement = gr.Textbox( label="Tag replacement", info="tag replacement in the format of `source1,target1;source2,target2; ...`. Escape `,` and `;` with `\`. e.g. `tag1,tag2;tag3,tag4`", - value="", + value=config.get("wd14_caption.tag_replacement", ""), interactive=True, ) - - character_tag_expand = gr.Textbox( + + character_tag_expand = gr.Checkbox( label="Character tag expand", info="expand tag tail parenthesis to another tag for character tags. `chara_name_(series)` becomes `chara_name, series`", - value="", + value=config.get("wd14_caption.character_tag_expand", False), interactive=True, ) @@ -206,6 +214,7 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None): label="Undesired tags", placeholder="(Optional) Separate `undesired_tags` with comma `(,)` if you want to remove multiple tags, e.g. `1girl,solo,smile`.", interactive=True, + value=config.get("wd14_caption.undesired_tags", ""), ) with gr.Row(): @@ -214,32 +223,35 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None): info="comma-separated list of tags to always put at the beginning, e.g. 1girl,1boy", placeholder="(Optional)", interactive=True, + value=config.get("wd14_caption.always_first_tags", ""), ) with gr.Row(): onnx = gr.Checkbox( label="Use onnx", - value=True, + value=config.get("wd14_caption.onnx", True), interactive=True, info="https://github.com/onnx/onnx", ) append_tags = gr.Checkbox( label="Append TAGs", - value=False, + value=config.get("wd14_caption.append_tags", False), interactive=True, info="This option appends the tags to the existing tags, instead of replacing them.", ) - + use_rating_tags = gr.Checkbox( label="Use rating tags", - value=False, + value=config.get("wd14_caption.use_rating_tags", False), interactive=True, info="Adds rating tags as the first tag", ) - - use_ratuse_rating_tags_as_last_taging_tags = gr.Checkbox( + + use_rating_tags_as_last_tag = gr.Checkbox( label="Use rating tags as last tag", - value=False, + value=config.get( + "wd14_caption.use_rating_tags_as_last_tag", False + ), interactive=True, info="Adds rating tags as the last tag", ) @@ -247,38 +259,38 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None): with gr.Row(): recursive = gr.Checkbox( label="Recursive", - value=False, + value=config.get("wd14_caption.recursive", False), info="Tag subfolders images as well", ) remove_underscore = gr.Checkbox( label="Remove underscore", - value=True, + value=config.get("wd14_caption.remove_underscore", True), info="replace underscores with spaces in the output tags", ) debug = gr.Checkbox( label="Debug", - value=True, + value=config.get("wd14_caption.debug", True), info="Debug mode", ) frequency_tags = gr.Checkbox( label="Show tags frequency", - value=True, + value=config.get("wd14_caption.frequency_tags", True), info="Show frequency of tags for images.", ) - + with gr.Row(): thresh = gr.Slider( - value=0.35, + value=config.get("wd14_caption.thresh", 0.35), label="Threshold", info="threshold of confidence to add a tag", minimum=0, maximum=1, step=0.05, ) - + general_threshold = gr.Slider( - value=0.35, + value=config.get("wd14_caption.general_threshold", 0.35), label="General threshold", info="Adjust `general_threshold` for pruning tags (less tags, less flexible)", minimum=0, @@ -286,7 +298,7 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None): step=0.05, ) character_threshold = gr.Slider( - value=0.35, + value=config.get("wd14_caption.character_threshold", 0.35), label="Character threshold", minimum=0, maximum=1, @@ -295,10 +307,16 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None): # Advanced Settings with gr.Row(): - batch_size = gr.Number(value=8, label="Batch size", interactive=True) + batch_size = gr.Number( + value=config.get("wd14_caption.batch_size", 8), + label="Batch size", + interactive=True, + ) max_data_loader_n_workers = gr.Number( - value=2, label="Max dataloader workers", interactive=True + value=config.get("wd14_caption.max_data_loader_n_workers", 2), + label="Max dataloader workers", + interactive=True, ) caption_button = gr.Button("Caption images") @@ -325,7 +343,7 @@ def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None): tag_replacement, character_tag_expand, use_rating_tags, - use_ratuse_rating_tags_as_last_taging_tags, + use_rating_tags_as_last_tag, remove_underscore, thresh, ],