mirror of https://github.com/bmaltais/kohya_ss
316 lines
10 KiB
Python
316 lines
10 KiB
Python
import gradio as gr
|
|
from easygui import msgbox
|
|
import subprocess
|
|
from .common_gui import get_folder_path, add_pre_postfix, scriptdir, list_dirs
|
|
import os
|
|
|
|
from .custom_logging import setup_logging
|
|
|
|
# Set up logging
|
|
log = setup_logging()
|
|
|
|
|
|
def caption_images(
    train_data_dir: str,
    caption_extension: str,
    batch_size: int,
    general_threshold: float,
    character_threshold: float,
    replace_underscores: bool,
    repo_id: str,
    recursive: bool,
    max_data_loader_n_workers: int,
    debug: bool,
    undesired_tags: str,
    frequency_tags: bool,
    prefix: str,
    postfix: str,
    onnx: bool,
    append_tags: bool,
    force_download: bool,
    caption_separator: str,
) -> None:
    """
    Caption the images in a directory with the WD14 tagger script.

    Launches ``sd-scripts/finetune/tag_images_by_wd14_tagger.py`` through
    ``accelerate launch`` and afterwards applies ``prefix`` / ``postfix`` to
    the caption files via ``add_pre_postfix``. Nothing is run when
    ``train_data_dir`` or ``caption_extension`` is empty; a message box is
    shown instead.

    Args:
        train_data_dir (str): The directory containing the images to be captioned.
        caption_extension (str): The extension to be used for the caption files.
        batch_size (int): The batch size; converted with ``int()`` before use.
        general_threshold (float): Value passed as ``--general_threshold``.
        character_threshold (float): Value passed as ``--character_threshold``.
        replace_underscores (bool): When true, ``--remove_underscore`` is passed
            to the tagger script.
        repo_id (str): The ID of the repository containing the WD14 model.
        recursive (bool): When true, ``--recursive`` is passed to the tagger script.
        max_data_loader_n_workers (int): The maximum number of data loader
            workers; converted with ``int()`` before use.
        debug (bool): When true, ``--debug`` is passed to the tagger script.
        undesired_tags (str): Comma-separated list of tags to be removed from
            the captions. Omitted from the command when empty.
        frequency_tags (bool): When true, ``--frequency_tags`` is passed to the
            tagger script.
        prefix (str): The prefix to be added to the captions.
        postfix (str): The postfix to be added to the captions.
        onnx (bool): When true, ``--onnx`` is passed to the tagger script.
        append_tags (bool): When true, ``--append_tags`` is passed to the
            tagger script.
        force_download (bool): When true, ``--force_download`` is passed to the
            tagger script.
        caption_separator (str): The separator to be used for the captions.
    """
    # Only needed to render the argument list for the log line below.
    import shlex

    # Both values are mandatory; bail out before launching anything.
    if train_data_dir == "":
        msgbox("Image folder is missing...")
        return

    if caption_extension == "":
        msgbox("Please provide an extension for the caption files.")
        return

    log.info(f"Captioning files in {train_data_dir}...")

    # Build the command as an argument list rather than a shell string: the
    # values come straight from UI text fields, and a quote or shell
    # metacharacter in any of them must not be interpreted by a shell.
    run_cmd = [
        "accelerate",
        "launch",
        f"{scriptdir}/sd-scripts/finetune/tag_images_by_wd14_tagger.py",
    ]
    if append_tags:
        run_cmd.append("--append_tags")
    run_cmd.append(f"--batch_size={int(batch_size)}")
    run_cmd.append(f"--caption_extension={caption_extension}")
    run_cmd.append(f"--caption_separator={caption_separator}")
    run_cmd.append(f"--character_threshold={character_threshold}")
    if debug:
        run_cmd.append("--debug")
    if force_download:
        run_cmd.append("--force_download")
    if frequency_tags:
        run_cmd.append("--frequency_tags")
    run_cmd.append(f"--general_threshold={general_threshold}")
    run_cmd.append(f"--max_data_loader_n_workers={int(max_data_loader_n_workers)}")
    if onnx:
        run_cmd.append("--onnx")
    if recursive:
        run_cmd.append("--recursive")
    if replace_underscores:
        run_cmd.append("--remove_underscore")
    run_cmd.append(f"--repo_id={repo_id}")
    if undesired_tags != "":
        run_cmd.append(f"--undesired_tags={undesired_tags}")
    # The image folder is the positional argument and goes last.
    run_cmd.append(train_data_dir)

    log.info(shlex.join(run_cmd))

    # Prepend the project root and sd-scripts to the child's PYTHONPATH.
    env = os.environ.copy()
    env["PYTHONPATH"] = (
        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
    )

    # Run the command. Failures are reported but do not abort: the
    # prefix/postfix step below still runs, as it did before.
    try:
        exit_code = subprocess.run(run_cmd, env=env).returncode
    except OSError as err:
        log.error(f"Could not start the WD14 tagger: {err}")
    else:
        if exit_code != 0:
            log.error(f"WD14 tagger exited with code {exit_code}")

    # Add prefix and postfix
    add_pre_postfix(
        folder=train_data_dir,
        caption_file_ext=caption_extension,
        prefix=prefix,
        postfix=postfix,
    )

    log.info("...captioning done")
|
|
|
|
|
|
###
|
|
# Gradio UI
|
|
###
|
|
|
|
|
|
def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None):
    """
    Build the "WD14 Captioning" Gradio tab and wire its controls to
    ``caption_images``.

    Args:
        headless (bool): When true, the folder-picker button is created hidden.
        default_train_dir (str | None): Directory whose sub-directories seed the
            image-folder dropdown. Falls back to ``<scriptdir>/data`` when None.
    """
    from .common_gui import create_refresh_button

    default_train_dir = (
        default_train_dir
        if default_train_dir is not None
        else os.path.join(scriptdir, "data")
    )
    current_train_dir = default_train_dir

    def list_train_dirs(path):
        # Remember the last listed path so the refresh button re-lists it.
        nonlocal current_train_dir
        current_train_dir = path
        return list(list_dirs(path))

    with gr.Tab("WD14 Captioning"):
        gr.Markdown(
            "This utility will use WD14 to caption files for each images in a folder."
        )

        # Input Settings
        # with gr.Section('Input Settings'):
        # NOTE(review): the nesting of the layout containers below was
        # reconstructed from a copy without indentation - confirm the layout.
        with gr.Group(), gr.Row():
            train_data_dir = gr.Dropdown(
                label="Image folder to caption (containing the images to caption)",
                choices=[""] + list_train_dirs(default_train_dir),
                value="",
                interactive=True,
                allow_custom_value=True,
            )
            create_refresh_button(
                train_data_dir,
                lambda: None,
                lambda: {"choices": list_train_dirs(current_train_dir)},
                "open_folder_small",
            )
            button_train_data_dir_input = gr.Button(
                "📂",
                elem_id="open_folder_small",
                elem_classes=["tool"],
                visible=(not headless),
            )
            button_train_data_dir_input.click(
                get_folder_path,
                outputs=train_data_dir,
                show_progress=False,
            )

            caption_extension = gr.Textbox(
                label='Caption file extension',
                placeholder='Extension for caption file (e.g., .caption, .txt)',
                value='.txt',
                interactive=True,
            )

            caption_separator = gr.Textbox(
                label="Caption Separator",
                value=",",
                interactive=True,
            )

        undesired_tags = gr.Textbox(
            label="Undesired tags",
            placeholder="(Optional) Separate `undesired_tags` with comma `(,)` if you want to remove multiple tags, e.g. `1girl,solo,smile`.",
            interactive=True,
        )

        with gr.Row():
            prefix = gr.Textbox(
                label="Prefix to add to WD14 caption",
                placeholder="(Optional)",
                interactive=True,
            )

            postfix = gr.Textbox(
                label="Postfix to add to WD14 caption",
                placeholder="(Optional)",
                interactive=True,
            )

        with gr.Row():
            onnx = gr.Checkbox(
                label="Use onnx",
                value=False,
                interactive=True,
                info="https://github.com/onnx/onnx",
            )
            append_tags = gr.Checkbox(
                label="Append TAGs",
                value=False,
                interactive=True,
                info="This option appends the tags to the existing tags, instead of replacing them.",
            )

        with gr.Row():
            replace_underscores = gr.Checkbox(
                label="Replace underscores in filenames with spaces",
                value=True,
                interactive=True,
            )
            recursive = gr.Checkbox(
                label="Recursive",
                value=False,
                info="Tag subfolders images as well",
            )

            debug = gr.Checkbox(
                label="Verbose logging",
                value=True,
                info="Debug while tagging, it will print your image file with general tags and character tags.",
            )
            frequency_tags = gr.Checkbox(
                label="Show tags frequency",
                value=True,
                info="Show frequency of tags for images.",
            )

        # Model Settings
        with gr.Row():
            repo_id = gr.Dropdown(
                label="Repo ID",
                choices=[
                    "SmilingWolf/wd-v1-4-convnext-tagger-v2",
                    "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
                    "SmilingWolf/wd-v1-4-vit-tagger-v2",
                    "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
                    "SmilingWolf/wd-v1-4-moat-tagger-v2",
                    # 'SmilingWolf/wd-swinv2-tagger-v3',
                    # 'SmilingWolf/wd-vit-tagger-v3',
                    # 'SmilingWolf/wd-convnext-tagger-v3',
                ],
                value="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
                # Help text belongs in `info`; `show_label` is a boolean
                # switch, so a string passed there was never displayed.
                info="Repo id for wd14 tagger on Hugging Face",
            )

            force_download = gr.Checkbox(
                label="Force model re-download",
                value=False,
                info="Useful to force model re download when switching to onnx",
            )

            general_threshold = gr.Slider(
                value=0.35,
                label="General threshold",
                info="Adjust `general_threshold` for pruning tags (less tags, less flexible)",
                minimum=0,
                maximum=1,
                step=0.05,
            )
            character_threshold = gr.Slider(
                value=0.35,
                label='Character threshold',
                minimum=0,
                maximum=1,
                step=0.05,
            )

        # Advanced Settings
        with gr.Row():
            batch_size = gr.Number(value=8, label="Batch size", interactive=True)

            max_data_loader_n_workers = gr.Number(
                value=2, label="Max dataloader workers", interactive=True
            )

        caption_button = gr.Button("Caption images")

        # The input order must match the positional parameters of
        # caption_images.
        caption_button.click(
            caption_images,
            inputs=[
                train_data_dir,
                caption_extension,
                batch_size,
                general_threshold,
                character_threshold,
                replace_underscores,
                repo_id,
                recursive,
                max_data_loader_n_workers,
                debug,
                undesired_tags,
                frequency_tags,
                prefix,
                postfix,
                onnx,
                append_tags,
                force_download,
                caption_separator,
            ],
            show_progress=False,
        )

        # Re-populate the dropdown choices from whatever path is currently
        # selected or typed in.
        train_data_dir.change(
            fn=lambda path: gr.Dropdown(choices=[""] + list_train_dirs(path)),
            inputs=train_data_dir,
            outputs=train_data_dir,
            show_progress=False,
        )
|