# kohya_ss/kohya_gui/wd14_caption_gui.py
#
# 316 lines
# 10 KiB
# Python

import gradio as gr
from easygui import msgbox
import subprocess
from .common_gui import get_folder_path, add_pre_postfix, scriptdir, list_dirs
import os
from .custom_logging import setup_logging
# Module-level logger, obtained from the project's setup_logging() helper.
log = setup_logging()
def caption_images(
    train_data_dir: str,
    caption_extension: str,
    batch_size: int,
    general_threshold: float,
    character_threshold: float,
    replace_underscores: bool,
    repo_id: str,
    recursive: bool,
    max_data_loader_n_workers: int,
    debug: bool,
    undesired_tags: str,
    frequency_tags: bool,
    prefix: str,
    postfix: str,
    onnx: bool,
    append_tags: bool,
    force_download: bool,
    caption_separator: str,
) -> None:
    """
    Captions images in a given directory using the WD14 model.

    Launches ``sd-scripts/finetune/tag_images_by_wd14_tagger.py`` through
    ``accelerate launch`` and, when the tagger exits successfully, applies the
    requested prefix/postfix to the caption files via ``add_pre_postfix``.

    The command is passed to ``subprocess.run`` as an argument list (no shell),
    so values typed into the GUI are handed to the tagger verbatim and cannot
    be interpreted as shell syntax.

    Args:
        train_data_dir (str): The directory containing the images to be captioned.
        caption_extension (str): The extension to be used for the caption files.
        batch_size (int): The batch size; converted with ``int()`` before use.
        general_threshold (float): Value passed as ``--general_threshold``.
        character_threshold (float): Value passed as ``--character_threshold``.
        replace_underscores (bool): Whether to pass ``--remove_underscore``.
        repo_id (str): Value passed as ``--repo_id`` (the WD14 model repository).
        recursive (bool): Whether to pass ``--recursive``.
        max_data_loader_n_workers (int): Value passed as
            ``--max_data_loader_n_workers``; converted with ``int()``.
        debug (bool): Whether to pass ``--debug``.
        undesired_tags (str): Value passed as ``--undesired_tags``; the flag is
            omitted when this is empty.
        frequency_tags (bool): Whether to pass ``--frequency_tags``.
        prefix (str): The prefix forwarded to ``add_pre_postfix``.
        postfix (str): The postfix forwarded to ``add_pre_postfix``.
        onnx (bool): Whether to pass ``--onnx``.
        append_tags (bool): Whether to pass ``--append_tags``.
        force_download (bool): Whether to pass ``--force_download``.
        caption_separator (str): Value passed as ``--caption_separator``.

    Returns:
        None. Problems are reported through a message box or the log; nothing
        is raised for a missing folder/extension or a failed tagger run.
    """
    # Truthiness (rather than == "") also rejects None, in case the UI layer
    # hands over a cleared field that way.
    if not train_data_dir:
        msgbox("Image folder is missing...")
        return

    if not caption_extension:
        msgbox("Please provide an extension for the caption files.")
        return

    log.info(f"Captioning files in {train_data_dir}...")

    # One list element per argv entry: no quoting is needed and none of the
    # user-supplied strings can break out of its argument.
    run_cmd = [
        "accelerate",
        "launch",
        f"{scriptdir}/sd-scripts/finetune/tag_images_by_wd14_tagger.py",
    ]

    if append_tags:
        run_cmd.append("--append_tags")
    run_cmd.append(f"--batch_size={int(batch_size)}")
    run_cmd.append(f"--caption_extension={caption_extension}")
    run_cmd.append(f"--caption_separator={caption_separator}")
    run_cmd.append(f"--character_threshold={character_threshold}")
    if debug:
        run_cmd.append("--debug")
    if force_download:
        run_cmd.append("--force_download")
    if frequency_tags:
        run_cmd.append("--frequency_tags")
    run_cmd.append(f"--general_threshold={general_threshold}")
    run_cmd.append(f"--max_data_loader_n_workers={int(max_data_loader_n_workers)}")
    if onnx:
        run_cmd.append("--onnx")
    if recursive:
        run_cmd.append("--recursive")
    if replace_underscores:
        run_cmd.append("--remove_underscore")
    run_cmd.append(f"--repo_id={repo_id}")
    if undesired_tags:
        run_cmd.append(f"--undesired_tags={undesired_tags}")
    # The image folder is the tagger's positional argument and goes last.
    run_cmd.append(f"{train_data_dir}")

    # Logged for diagnostics only; this joined string is never executed.
    log.info(" ".join(run_cmd))

    # Prepend the project root and sd-scripts to PYTHONPATH for the child
    # process, keeping whatever PYTHONPATH the parent already had.
    env = os.environ.copy()
    env["PYTHONPATH"] = (
        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
    )

    # Run the command
    try:
        result = subprocess.run(run_cmd, env=env)
    except OSError as e:
        # Without a shell, a missing `accelerate` executable surfaces here
        # instead of as a shell error message; report it and stop.
        log.error(f"Could not launch the WD14 tagger: {e}")
        return

    if result.returncode != 0:
        # Do not touch caption files when the tagger itself reported failure.
        log.error(
            f"WD14 tagger exited with code {result.returncode}; "
            "prefix/postfix were not applied."
        )
        return

    # Add prefix and postfix
    add_pre_postfix(
        folder=train_data_dir,
        caption_file_ext=caption_extension,
        prefix=prefix,
        postfix=postfix,
    )

    log.info("...captioning done")
###
# Gradio UI
###
def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None):
    """
    Build the "WD14 Captioning" Gradio tab and wire its controls to
    ``caption_images``.

    Must be called inside an active Gradio Blocks context, since it creates
    components and registers event handlers.

    Args:
        headless (bool): When True, the folder-picker button is created hidden
            (``visible=not headless``).
        default_train_dir (str | None): Folder whose sub-directories populate
            the image-folder dropdown initially. Defaults to
            ``<scriptdir>/data`` when None.

    Returns:
        None.
    """
    from .common_gui import create_refresh_button

    default_train_dir = (
        default_train_dir
        if default_train_dir is not None
        else os.path.join(scriptdir, "data")
    )
    # Last folder that was listed; the refresh button re-lists this one.
    current_train_dir = default_train_dir

    def list_train_dirs(path):
        """Remember *path* as the current folder and return its sub-directories."""
        nonlocal current_train_dir
        current_train_dir = path
        return list(list_dirs(path))

    # NOTE(review): this function was restored from a copy that had lost its
    # indentation; the grouping of components into the rows below is a
    # best-effort reconstruction -- confirm against the rendered layout.
    with gr.Tab("WD14 Captioning"):
        gr.Markdown(
            "This utility will use WD14 to caption files for each images in a folder."
        )

        # Input Settings
        # with gr.Section('Input Settings'):
        with gr.Group(), gr.Row():
            train_data_dir = gr.Dropdown(
                label="Image folder to caption (containing the images to caption)",
                choices=[""] + list_train_dirs(default_train_dir),
                value="",
                interactive=True,
                allow_custom_value=True,
            )
            create_refresh_button(
                train_data_dir,
                lambda: None,
                lambda: {"choices": list_train_dirs(current_train_dir)},
                "open_folder_small",
            )
            button_train_data_dir_input = gr.Button(
                "📂",
                elem_id="open_folder_small",
                elem_classes=["tool"],
                visible=(not headless),
            )
            button_train_data_dir_input.click(
                get_folder_path,
                outputs=train_data_dir,
                show_progress=False,
            )

            caption_extension = gr.Textbox(
                label="Caption file extension",
                placeholder="Extension for caption file (e.g., .caption, .txt)",
                value=".txt",
                interactive=True,
            )

            caption_separator = gr.Textbox(
                label="Caption Separator",
                value=",",
                interactive=True,
            )

        undesired_tags = gr.Textbox(
            label="Undesired tags",
            placeholder="(Optional) Separate `undesired_tags` with comma `(,)` if you want to remove multiple tags, e.g. `1girl,solo,smile`.",
            interactive=True,
        )

        with gr.Row():
            prefix = gr.Textbox(
                label="Prefix to add to WD14 caption",
                placeholder="(Optional)",
                interactive=True,
            )

            postfix = gr.Textbox(
                label="Postfix to add to WD14 caption",
                placeholder="(Optional)",
                interactive=True,
            )

        with gr.Row():
            onnx = gr.Checkbox(
                label="Use onnx",
                value=False,
                interactive=True,
                info="https://github.com/onnx/onnx",
            )
            append_tags = gr.Checkbox(
                label="Append TAGs",
                value=False,
                interactive=True,
                info="This option appends the tags to the existing tags, instead of replacing them.",
            )

        with gr.Row():
            replace_underscores = gr.Checkbox(
                label="Replace underscores in filenames with spaces",
                value=True,
                interactive=True,
            )
            recursive = gr.Checkbox(
                label="Recursive",
                value=False,
                info="Tag subfolders images as well",
            )

            debug = gr.Checkbox(
                label="Verbose logging",
                value=True,
                info="Debug while tagging, it will print your image file with general tags and character tags.",
            )
            frequency_tags = gr.Checkbox(
                label="Show tags frequency",
                value=True,
                info="Show frequency of tags for images.",
            )

        # Model Settings
        with gr.Row():
            repo_id = gr.Dropdown(
                label="Repo ID",
                choices=[
                    "SmilingWolf/wd-v1-4-convnext-tagger-v2",
                    "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
                    "SmilingWolf/wd-v1-4-vit-tagger-v2",
                    "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
                    "SmilingWolf/wd-v1-4-moat-tagger-v2",
                    # 'SmilingWolf/wd-swinv2-tagger-v3',
                    # 'SmilingWolf/wd-vit-tagger-v3',
                    # 'SmilingWolf/wd-convnext-tagger-v3',
                ],
                value="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
                # This sentence was previously passed as `show_label`, which is
                # a boolean visibility switch, so it was never displayed.
                # `info` is the help-text parameter used by the other widgets.
                info="Repo id for wd14 tagger on Hugging Face",
            )

            force_download = gr.Checkbox(
                label="Force model re-download",
                value=False,
                info="Useful to force model re download when switching to onnx",
            )

            general_threshold = gr.Slider(
                value=0.35,
                label="General threshold",
                info="Adjust `general_threshold` for pruning tags (less tags, less flexible)",
                minimum=0,
                maximum=1,
                step=0.05,
            )
            character_threshold = gr.Slider(
                value=0.35,
                label="Character threshold",
                minimum=0,
                maximum=1,
                step=0.05,
            )

        # Advanced Settings
        with gr.Row():
            batch_size = gr.Number(value=8, label="Batch size", interactive=True)

            max_data_loader_n_workers = gr.Number(
                value=2, label="Max dataloader workers", interactive=True
            )

        caption_button = gr.Button("Caption images")

        # The order of `inputs` must match the positional parameters of
        # caption_images exactly.
        caption_button.click(
            caption_images,
            inputs=[
                train_data_dir,
                caption_extension,
                batch_size,
                general_threshold,
                character_threshold,
                replace_underscores,
                repo_id,
                recursive,
                max_data_loader_n_workers,
                debug,
                undesired_tags,
                frequency_tags,
                prefix,
                postfix,
                onnx,
                append_tags,
                force_download,
                caption_separator,
            ],
            show_progress=False,
        )

        # Whenever the selected folder changes, re-populate the dropdown with
        # that folder's sub-directories (plus the empty choice).
        train_data_dir.change(
            fn=lambda path: gr.Dropdown(choices=[""] + list_train_dirs(path)),
            inputs=train_data_dir,
            outputs=train_data_dir,
            show_progress=False,
        )