mirror of https://github.com/bmaltais/kohya_ss
Implement use_shell as parameter (#2297)
parent
f4658f9e01
commit
e6a8dec98d
|
|
@ -1,6 +1,9 @@
|
|||
# Copy this file and name it config.toml
|
||||
# Edit the values to suit your needs
|
||||
|
||||
[settings]
|
||||
use_shell = false # Use shell furing process run of sd-scripts oython code. Most secure is false but some systems may require it to be true to properly run sd-scripts.
|
||||
|
||||
# Default folders location
|
||||
[model]
|
||||
models_dir = "./models" # Pretrained model name or path
|
||||
|
|
|
|||
40
kohya_gui.py
40
kohya_gui.py
|
|
@ -12,6 +12,7 @@ from kohya_gui.class_lora_tab import LoRATools
|
|||
from kohya_gui.custom_logging import setup_logging
|
||||
from kohya_gui.localization_ext import add_javascript
|
||||
|
||||
|
||||
def UI(**kwargs):
|
||||
add_javascript(kwargs.get("language"))
|
||||
css = ""
|
||||
|
|
@ -35,9 +36,18 @@ def UI(**kwargs):
|
|||
interface = gr.Blocks(
|
||||
css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
|
||||
)
|
||||
|
||||
|
||||
config = KohyaSSGUIConfig(config_file_path=kwargs.get("config"))
|
||||
|
||||
if config.is_config_loaded():
|
||||
log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...")
|
||||
|
||||
use_shell_flag = kwargs.get("use_shell", False)
|
||||
if use_shell_flag == False:
|
||||
use_shell_flag = config.get("settings.use_shell", False)
|
||||
if use_shell_flag:
|
||||
log.info("Using shell=True when running external commands...")
|
||||
|
||||
with interface:
|
||||
with gr.Tab("Dreambooth"):
|
||||
(
|
||||
|
|
@ -45,13 +55,17 @@ def UI(**kwargs):
|
|||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = dreambooth_tab(headless=headless, config=config)
|
||||
) = dreambooth_tab(
|
||||
headless=headless, config=config, use_shell_flag=use_shell_flag
|
||||
)
|
||||
with gr.Tab("LoRA"):
|
||||
lora_tab(headless=headless, config=config)
|
||||
lora_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
|
||||
with gr.Tab("Textual Inversion"):
|
||||
ti_tab(headless=headless, config=config)
|
||||
ti_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
|
||||
with gr.Tab("Finetuning"):
|
||||
finetune_tab(headless=headless, config=config)
|
||||
finetune_tab(
|
||||
headless=headless, config=config, use_shell_flag=use_shell_flag
|
||||
)
|
||||
with gr.Tab("Utilities"):
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
|
|
@ -61,9 +75,10 @@ def UI(**kwargs):
|
|||
enable_copy_info_button=True,
|
||||
headless=headless,
|
||||
config=config,
|
||||
use_shell_flag=use_shell_flag,
|
||||
)
|
||||
with gr.Tab("LoRA"):
|
||||
_ = LoRATools(headless=headless)
|
||||
_ = LoRATools(headless=headless, use_shell_flag=use_shell_flag)
|
||||
with gr.Tab("About"):
|
||||
gr.Markdown(f"kohya_ss GUI release {release}")
|
||||
with gr.Tab("README"):
|
||||
|
|
@ -102,6 +117,7 @@ def UI(**kwargs):
|
|||
launch_kwargs["debug"] = True
|
||||
interface.launch(**launch_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
parser = argparse.ArgumentParser()
|
||||
|
|
@ -141,11 +157,17 @@ if __name__ == "__main__":
|
|||
|
||||
parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
|
||||
parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment")
|
||||
|
||||
parser.add_argument("--do_not_share", action="store_true", help="Do not share the gradio UI")
|
||||
|
||||
parser.add_argument(
|
||||
"--use_shell", action="store_true", help="Use shell environment"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--do_not_share", action="store_true", help="Do not share the gradio UI"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging(debug=args.debug)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ def caption_images(
|
|||
beam_search: bool,
|
||||
prefix: str = "",
|
||||
postfix: str = "",
|
||||
use_shell: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Automatically generates captions for images in the specified directory using the BLIP model.
|
||||
|
|
@ -96,7 +97,7 @@ def caption_images(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command in the sd-scripts folder context
|
||||
subprocess.run(run_cmd, env=env, cwd=f"{scriptdir}/sd-scripts")
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell, cwd=f"{scriptdir}/sd-scripts")
|
||||
|
||||
|
||||
# Add prefix and postfix
|
||||
|
|
@ -115,7 +116,7 @@ def caption_images(
|
|||
###
|
||||
|
||||
|
||||
def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None):
|
||||
def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None, use_shell: bool = False):
|
||||
from .common_gui import create_refresh_button
|
||||
|
||||
default_train_dir = (
|
||||
|
|
@ -205,6 +206,7 @@ def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None):
|
|||
beam_search,
|
||||
prefix,
|
||||
postfix,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import subprocess
|
|||
import psutil
|
||||
import time
|
||||
import gradio as gr
|
||||
import shlex
|
||||
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
# Set up logging
|
||||
|
|
@ -21,7 +21,7 @@ class CommandExecutor:
|
|||
self.process = None
|
||||
self.run_state = gr.Textbox(value="", visible=False)
|
||||
|
||||
def execute_command(self, run_cmd: str, **kwargs):
|
||||
def execute_command(self, run_cmd: str, use_shell: bool = False, **kwargs):
|
||||
"""
|
||||
Execute a command if no other command is currently running.
|
||||
|
||||
|
|
@ -36,11 +36,12 @@ class CommandExecutor:
|
|||
# log.info(f"{i}: {item}")
|
||||
|
||||
# Reconstruct the safe command string for display
|
||||
command_to_run = ' '.join(run_cmd)
|
||||
log.info(f"Executing command: {command_to_run}")
|
||||
command_to_run = " ".join(run_cmd)
|
||||
log.info(f"Executing command: {command_to_run} with shell={use_shell}")
|
||||
|
||||
# Execute the command securely
|
||||
self.process = subprocess.Popen(run_cmd, **kwargs)
|
||||
self.process = subprocess.Popen(run_cmd, **kwargs, shell=use_shell)
|
||||
log.info("Command executed.")
|
||||
|
||||
def kill_command(self):
|
||||
"""
|
||||
|
|
@ -64,9 +65,9 @@ class CommandExecutor:
|
|||
log.info(f"Error when terminating process: {e}")
|
||||
else:
|
||||
log.info("There is no running process to kill.")
|
||||
|
||||
|
||||
return gr.Button(visible=True), gr.Button(visible=False)
|
||||
|
||||
|
||||
def wait_for_training_to_end(self):
|
||||
while self.is_running():
|
||||
time.sleep(1)
|
||||
|
|
@ -81,4 +82,4 @@ class CommandExecutor:
|
|||
Returns:
|
||||
- bool: True if the command is running, False otherwise.
|
||||
"""
|
||||
return self.process and self.process.poll() is None
|
||||
return self.process and self.process.poll() is None
|
||||
|
|
|
|||
|
|
@ -80,3 +80,14 @@ class KohyaSSGUIConfig:
|
|||
# Return the final value
|
||||
log.debug(f"Returned {data}")
|
||||
return data
|
||||
|
||||
def is_config_loaded(self) -> bool:
|
||||
"""
|
||||
Checks if the configuration was loaded from a file.
|
||||
|
||||
Returns:
|
||||
bool: True if the configuration was loaded from a file, False otherwise.
|
||||
"""
|
||||
is_loaded = self.config != {}
|
||||
log.debug(f"Configuration was loaded from file: {is_loaded}")
|
||||
return is_loaded
|
||||
|
|
|
|||
|
|
@ -11,16 +11,18 @@ from .merge_lycoris_gui import gradio_merge_lycoris_tab
|
|||
|
||||
|
||||
class LoRATools:
|
||||
def __init__(self, headless: bool = False):
|
||||
self.headless = headless
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
headless: bool = False,
|
||||
use_shell_flag: bool = False,
|
||||
):
|
||||
gr.Markdown("This section provide various LoRA tools...")
|
||||
gradio_extract_dylora_tab(headless=headless)
|
||||
gradio_convert_lcm_tab(headless=headless)
|
||||
gradio_extract_lora_tab(headless=headless)
|
||||
gradio_extract_lycoris_locon_tab(headless=headless)
|
||||
gradio_merge_lora_tab = GradioMergeLoRaTab()
|
||||
gradio_merge_lycoris_tab(headless=headless)
|
||||
gradio_svd_merge_lora_tab(headless=headless)
|
||||
gradio_resize_lora_tab(headless=headless)
|
||||
gradio_verify_lora_tab(headless=headless)
|
||||
gradio_extract_dylora_tab(headless=headless, use_shell=use_shell_flag)
|
||||
gradio_convert_lcm_tab(headless=headless, use_shell=use_shell_flag)
|
||||
gradio_extract_lora_tab(headless=headless, use_shell=use_shell_flag)
|
||||
gradio_extract_lycoris_locon_tab(headless=headless, use_shell=use_shell_flag)
|
||||
gradio_merge_lora_tab = GradioMergeLoRaTab(use_shell=use_shell_flag)
|
||||
gradio_merge_lycoris_tab(headless=headless, use_shell=use_shell_flag)
|
||||
gradio_svd_merge_lora_tab(headless=headless, use_shell=use_shell_flag)
|
||||
gradio_resize_lora_tab(headless=headless, use_shell=use_shell_flag)
|
||||
gradio_verify_lora_tab(headless=headless, use_shell=use_shell_flag)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from tkinter import filedialog, Tk
|
|||
from easygui import msgbox, ynbox
|
||||
from typing import Optional
|
||||
from .custom_logging import setup_logging
|
||||
from .class_command_executor import CommandExecutor
|
||||
|
||||
import os
|
||||
import re
|
||||
|
|
@ -12,7 +11,6 @@ import shlex
|
|||
import json
|
||||
import math
|
||||
import shutil
|
||||
import time
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
|
@ -23,6 +21,7 @@ save_style_symbol = "\U0001f4be" # 💾
|
|||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
scriptdir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
if os.name == "nt":
|
||||
scriptdir = scriptdir.replace("\\", "/")
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,13 @@ document_symbol = "\U0001F4C4" # 📄
|
|||
PYTHON = sys.executable
|
||||
|
||||
|
||||
def convert_lcm(name, model_path, lora_scale, model_type):
|
||||
def convert_lcm(
|
||||
name,
|
||||
model_path,
|
||||
lora_scale,
|
||||
model_type,
|
||||
use_shell: bool = False,
|
||||
):
|
||||
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/lcm_convert.py"'
|
||||
|
||||
# Check if source model exist
|
||||
|
|
@ -62,7 +68,7 @@ def convert_lcm(name, model_path, lora_scale, model_type):
|
|||
run_cmd.append("--ssd-1b")
|
||||
|
||||
# Log the command
|
||||
log.info(' '.join(run_cmd))
|
||||
log.info(" ".join(run_cmd))
|
||||
|
||||
# Set up the environment
|
||||
env = os.environ.copy()
|
||||
|
|
@ -72,13 +78,13 @@ def convert_lcm(name, model_path, lora_scale, model_type):
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, env=env)
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell)
|
||||
|
||||
# Return a success message
|
||||
log.info("Done extracting...")
|
||||
|
||||
|
||||
def gradio_convert_lcm_tab(headless=False):
|
||||
def gradio_convert_lcm_tab(headless=False, use_shell: bool = False):
|
||||
current_model_dir = os.path.join(scriptdir, "outputs")
|
||||
current_save_dir = os.path.join(scriptdir, "outputs")
|
||||
|
||||
|
|
@ -183,6 +189,12 @@ def gradio_convert_lcm_tab(headless=False):
|
|||
|
||||
extract_button.click(
|
||||
convert_lcm,
|
||||
inputs=[name, model_path, lora_scale, model_type],
|
||||
inputs=[
|
||||
name,
|
||||
model_path,
|
||||
lora_scale,
|
||||
model_type,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ def convert_model(
|
|||
target_model_type,
|
||||
target_save_precision_type,
|
||||
unet_use_linear_projection,
|
||||
use_shell: bool = False,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if source_model_type == "":
|
||||
|
|
@ -107,7 +108,7 @@ def convert_model(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, env=env)
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell)
|
||||
|
||||
|
||||
|
||||
|
|
@ -116,7 +117,7 @@ def convert_model(
|
|||
###
|
||||
|
||||
|
||||
def gradio_convert_model_tab(headless=False):
|
||||
def gradio_convert_model_tab(headless=False, use_shell: bool = False):
|
||||
from .common_gui import create_refresh_button
|
||||
|
||||
default_source_model = os.path.join(scriptdir, "outputs")
|
||||
|
|
@ -276,6 +277,7 @@ def gradio_convert_model_tab(headless=False):
|
|||
target_model_type,
|
||||
target_save_precision_type,
|
||||
unet_use_linear_projection,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ executor = CommandExecutor()
|
|||
|
||||
# Setup huggingface
|
||||
huggingface = None
|
||||
use_shell = False
|
||||
|
||||
PYTHON = sys.executable
|
||||
|
||||
|
|
@ -843,7 +844,7 @@ def train_model(
|
|||
|
||||
# Run the command
|
||||
|
||||
executor.execute_command(run_cmd=run_cmd, env=env)
|
||||
executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)
|
||||
|
||||
return (
|
||||
gr.Button(visible=False),
|
||||
|
|
@ -859,10 +860,14 @@ def dreambooth_tab(
|
|||
# logging_dir=gr.Textbox(),
|
||||
headless=False,
|
||||
config: KohyaSSGUIConfig = {},
|
||||
use_shell_flag: bool = False,
|
||||
):
|
||||
dummy_db_true = gr.Checkbox(value=True, visible=False)
|
||||
dummy_db_false = gr.Checkbox(value=False, visible=False)
|
||||
dummy_headless = gr.Checkbox(value=headless, visible=False)
|
||||
|
||||
global use_shell
|
||||
use_shell = use_shell_flag
|
||||
|
||||
with gr.Tab("Training"), gr.Column(variant="compact"):
|
||||
gr.Markdown("Train a custom model using kohya dreambooth python code...")
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ def extract_dylora(
|
|||
model,
|
||||
save_to,
|
||||
unit,
|
||||
use_shell: bool = False,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if model == "":
|
||||
|
|
@ -71,7 +72,7 @@ def extract_dylora(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, env=env)
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell)
|
||||
|
||||
log.info("Done extracting DyLoRA...")
|
||||
|
||||
|
|
@ -81,7 +82,7 @@ def extract_dylora(
|
|||
###
|
||||
|
||||
|
||||
def gradio_extract_dylora_tab(headless=False):
|
||||
def gradio_extract_dylora_tab(headless=False, use_shell: bool = False):
|
||||
current_model_dir = os.path.join(scriptdir, "outputs")
|
||||
current_save_dir = os.path.join(scriptdir, "outputs")
|
||||
|
||||
|
|
@ -170,6 +171,7 @@ def gradio_extract_dylora_tab(headless=False):
|
|||
model,
|
||||
save_to,
|
||||
unit,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ def extract_lora(
|
|||
load_original_model_to,
|
||||
load_tuned_model_to,
|
||||
load_precision,
|
||||
use_shell: bool = False,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if model_tuned == "":
|
||||
|
|
@ -120,7 +121,7 @@ def extract_lora(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, env=env)
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell)
|
||||
|
||||
|
||||
###
|
||||
|
|
@ -128,7 +129,7 @@ def extract_lora(
|
|||
###
|
||||
|
||||
|
||||
def gradio_extract_lora_tab(headless=False):
|
||||
def gradio_extract_lora_tab(headless=False, use_shell: bool = False):
|
||||
current_model_dir = os.path.join(scriptdir, "outputs")
|
||||
current_model_org_dir = os.path.join(scriptdir, "outputs")
|
||||
current_save_dir = os.path.join(scriptdir, "outputs")
|
||||
|
|
@ -358,6 +359,7 @@ def gradio_extract_lora_tab(headless=False):
|
|||
load_original_model_to,
|
||||
load_tuned_model_to,
|
||||
load_precision,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ def extract_lycoris_locon(
|
|||
use_sparse_bias,
|
||||
sparsity,
|
||||
disable_cp,
|
||||
use_shell: bool = False,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if db_model == "":
|
||||
|
|
@ -135,7 +136,7 @@ def extract_lycoris_locon(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, env=env)
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell)
|
||||
|
||||
log.info("Done extracting...")
|
||||
|
||||
|
|
@ -171,7 +172,7 @@ def update_mode(mode):
|
|||
return tuple(updates)
|
||||
|
||||
|
||||
def gradio_extract_lycoris_locon_tab(headless=False):
|
||||
def gradio_extract_lycoris_locon_tab(headless=False, use_shell: bool = False):
|
||||
|
||||
current_model_dir = os.path.join(scriptdir, "outputs")
|
||||
current_base_model_dir = os.path.join(scriptdir, "outputs")
|
||||
|
|
@ -449,6 +450,7 @@ def gradio_extract_lycoris_locon_tab(headless=False):
|
|||
use_sparse_bias,
|
||||
sparsity,
|
||||
disable_cp,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from .class_tensorboard import TensorboardManager
|
|||
from .class_sample_images import SampleImages, create_prompt_file
|
||||
from .class_huggingface import HuggingFace
|
||||
from .class_metadata import MetaData
|
||||
from .class_gui_config import KohyaSSGUIConfig
|
||||
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
|
|
@ -43,6 +44,7 @@ executor = CommandExecutor()
|
|||
|
||||
# Setup huggingface
|
||||
huggingface = None
|
||||
use_shell = False
|
||||
|
||||
# from easygui import msgbox
|
||||
|
||||
|
|
@ -592,7 +594,6 @@ def train_model(
|
|||
if not print_only:
|
||||
subprocess.run(run_cmd, env=env)
|
||||
|
||||
|
||||
# create images buckets
|
||||
if generate_image_buckets:
|
||||
# Build the command to run the preparation script
|
||||
|
|
@ -639,7 +640,6 @@ def train_model(
|
|||
if not print_only:
|
||||
subprocess.run(run_cmd, env=env)
|
||||
|
||||
|
||||
if image_folder == "":
|
||||
log.error("Image folder dir is empty")
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
|
@ -709,7 +709,7 @@ def train_model(
|
|||
)
|
||||
cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs
|
||||
no_half_vae = sdxl_checkbox and sdxl_no_half_vae
|
||||
|
||||
|
||||
if max_data_loader_n_workers == "" or None:
|
||||
max_data_loader_n_workers = 0
|
||||
else:
|
||||
|
|
@ -719,7 +719,7 @@ def train_model(
|
|||
max_train_steps = 0
|
||||
else:
|
||||
max_train_steps = int(max_train_steps)
|
||||
|
||||
|
||||
config_toml_data = {
|
||||
# Update the values in the TOML data
|
||||
"huggingface_repo_id": huggingface_repo_id,
|
||||
|
|
@ -758,16 +758,22 @@ def train_model(
|
|||
"ip_noise_gamma": ip_noise_gamma,
|
||||
"ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
|
||||
"keep_tokens": int(keep_tokens),
|
||||
"learning_rate": learning_rate, # both for sd1.5 and sdxl
|
||||
"learning_rate_te": learning_rate_te if not sdxl_checkbox else None, # only for sd1.5
|
||||
"learning_rate_te1": learning_rate_te1 if sdxl_checkbox else None, # only for sdxl
|
||||
"learning_rate_te2": learning_rate_te2 if sdxl_checkbox else None, # only for sdxl
|
||||
"learning_rate": learning_rate, # both for sd1.5 and sdxl
|
||||
"learning_rate_te": (
|
||||
learning_rate_te if not sdxl_checkbox else None
|
||||
), # only for sd1.5
|
||||
"learning_rate_te1": (
|
||||
learning_rate_te1 if sdxl_checkbox else None
|
||||
), # only for sdxl
|
||||
"learning_rate_te2": (
|
||||
learning_rate_te2 if sdxl_checkbox else None
|
||||
), # only for sdxl
|
||||
"logging_dir": logging_dir,
|
||||
"log_tracker_name": log_tracker_name,
|
||||
"log_tracker_config": log_tracker_config,
|
||||
"loss_type": loss_type,
|
||||
"lr_scheduler": lr_scheduler,
|
||||
"lr_scheduler_args": str(lr_scheduler_args).replace('"', '').split(),
|
||||
"lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
|
||||
"lr_warmup_steps": lr_warmup_steps,
|
||||
"max_bucket_reso": int(max_bucket_reso),
|
||||
"max_data_loader_n_workers": max_data_loader_n_workers,
|
||||
|
|
@ -792,7 +798,7 @@ def train_model(
|
|||
"noise_offset_random_strength": noise_offset_random_strength,
|
||||
"noise_offset_type": noise_offset_type,
|
||||
"optimizer_type": optimizer,
|
||||
"optimizer_args": str(optimizer_args).replace('"', '').split(),
|
||||
"optimizer_args": str(optimizer_args).replace('"', "").split(),
|
||||
"output_dir": output_dir,
|
||||
"output_name": output_name,
|
||||
"persistent_data_loader_workers": persistent_data_loader_workers,
|
||||
|
|
@ -892,7 +898,7 @@ def train_model(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
executor.execute_command(run_cmd=run_cmd, env=env)
|
||||
executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)
|
||||
|
||||
return (
|
||||
gr.Button(visible=False),
|
||||
|
|
@ -901,10 +907,18 @@ def train_model(
|
|||
)
|
||||
|
||||
|
||||
def finetune_tab(headless=False, config: dict = {}):
|
||||
def finetune_tab(
|
||||
headless=False,
|
||||
config: KohyaSSGUIConfig = {},
|
||||
use_shell_flag: bool = False,
|
||||
):
|
||||
dummy_db_true = gr.Checkbox(value=True, visible=False)
|
||||
dummy_db_false = gr.Checkbox(value=False, visible=False)
|
||||
dummy_headless = gr.Checkbox(value=headless, visible=False)
|
||||
|
||||
global use_shell
|
||||
use_shell = use_shell_flag
|
||||
|
||||
with gr.Tab("Training"), gr.Column(variant="compact"):
|
||||
gr.Markdown("Train a custom model using kohya finetune python code...")
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ def caption_images(
|
|||
model_id,
|
||||
prefix,
|
||||
postfix,
|
||||
use_shell: bool = False,
|
||||
):
|
||||
# Check for images_dir_input
|
||||
if train_data_dir == "":
|
||||
|
|
@ -70,7 +71,7 @@ def caption_images(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, env=env)
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell)
|
||||
|
||||
# Add prefix and postfix
|
||||
add_pre_postfix(
|
||||
|
|
@ -88,7 +89,9 @@ def caption_images(
|
|||
###
|
||||
|
||||
|
||||
def gradio_git_caption_gui_tab(headless=False, default_train_dir=None):
|
||||
def gradio_git_caption_gui_tab(
|
||||
headless=False, default_train_dir=None, use_shell: bool = False
|
||||
):
|
||||
from .common_gui import create_refresh_button
|
||||
|
||||
default_train_dir = (
|
||||
|
|
@ -178,6 +181,7 @@ def gradio_git_caption_gui_tab(headless=False, default_train_dir=None):
|
|||
model_id,
|
||||
prefix,
|
||||
postfix,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ def group_images(
|
|||
do_not_copy_other_files,
|
||||
generate_captions,
|
||||
caption_ext,
|
||||
use_shell: bool = False,
|
||||
):
|
||||
if input_folder == "":
|
||||
msgbox("Input folder is missing...")
|
||||
|
|
@ -63,12 +64,12 @@ def group_images(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, env=env)
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell)
|
||||
|
||||
log.info("...grouping done")
|
||||
|
||||
|
||||
def gradio_group_images_gui_tab(headless=False):
|
||||
def gradio_group_images_gui_tab(headless=False, use_shell: bool = False):
|
||||
from .common_gui import create_refresh_button
|
||||
|
||||
current_input_folder = os.path.join(scriptdir, "data")
|
||||
|
|
@ -200,6 +201,7 @@ def gradio_group_images_gui_tab(headless=False):
|
|||
do_not_copy_other_files,
|
||||
generate_captions,
|
||||
caption_ext,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ from .class_sample_images import SampleImages, create_prompt_file
|
|||
from .class_lora_tab import LoRATools
|
||||
from .class_huggingface import HuggingFace
|
||||
from .class_metadata import MetaData
|
||||
from .class_gui_config import KohyaSSGUIConfig
|
||||
|
||||
from .dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
|
|
@ -50,6 +51,7 @@ executor = CommandExecutor()
|
|||
|
||||
# Setup huggingface
|
||||
huggingface = None
|
||||
use_shell = False
|
||||
|
||||
button_run = gr.Button("Start training", variant="primary")
|
||||
|
||||
|
|
@ -1193,7 +1195,7 @@ def train_model(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
executor.execute_command(run_cmd=run_cmd, env=env)
|
||||
executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)
|
||||
|
||||
return (
|
||||
gr.Button(visible=False),
|
||||
|
|
@ -1208,11 +1210,15 @@ def lora_tab(
|
|||
output_dir_input=gr.Dropdown(),
|
||||
logging_dir_input=gr.Dropdown(),
|
||||
headless=False,
|
||||
config: dict = {},
|
||||
config: KohyaSSGUIConfig = {},
|
||||
use_shell_flag: bool = False,
|
||||
):
|
||||
dummy_db_true = gr.Checkbox(value=True, visible=False)
|
||||
dummy_db_false = gr.Checkbox(value=False, visible=False)
|
||||
dummy_headless = gr.Checkbox(value=headless, visible=False)
|
||||
|
||||
global use_shell
|
||||
use_shell = use_shell_flag
|
||||
|
||||
with gr.Tab("Training"), gr.Column(variant="compact") as tab:
|
||||
gr.Markdown(
|
||||
|
|
|
|||
|
|
@ -48,8 +48,9 @@ def verify_conditions(sd_model, lora_models):
|
|||
|
||||
|
||||
class GradioMergeLoRaTab:
|
||||
def __init__(self, headless=False):
|
||||
def __init__(self, headless=False, use_shell: bool = False):
|
||||
self.headless = headless
|
||||
self.use_shell = use_shell
|
||||
self.build_tab()
|
||||
|
||||
def save_inputs_to_json(self, file_path, inputs):
|
||||
|
|
@ -379,6 +380,7 @@ class GradioMergeLoRaTab:
|
|||
save_to,
|
||||
precision,
|
||||
save_precision,
|
||||
gr.Checkbox(value=self.use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
@ -398,6 +400,7 @@ class GradioMergeLoRaTab:
|
|||
save_to,
|
||||
precision,
|
||||
save_precision,
|
||||
use_shell: bool = False,
|
||||
):
|
||||
|
||||
log.info("Merge model...")
|
||||
|
|
@ -458,6 +461,6 @@ class GradioMergeLoRaTab:
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, env=env)
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell)
|
||||
|
||||
log.info("Done merging...")
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ def merge_lycoris(
|
|||
device,
|
||||
is_sdxl,
|
||||
is_v2,
|
||||
use_shell: bool = False,
|
||||
):
|
||||
log.info("Merge model...")
|
||||
|
||||
|
|
@ -67,7 +68,7 @@ def merge_lycoris(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Execute the command with the modified environment
|
||||
subprocess.run(run_cmd, env=env)
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell)
|
||||
|
||||
log.info("Done merging...")
|
||||
|
||||
|
|
@ -77,7 +78,7 @@ def merge_lycoris(
|
|||
###
|
||||
|
||||
|
||||
def gradio_merge_lycoris_tab(headless=False):
|
||||
def gradio_merge_lycoris_tab(headless=False, use_shell: bool = False):
|
||||
current_model_dir = os.path.join(scriptdir, "outputs")
|
||||
current_lycoris_dir = current_model_dir
|
||||
current_save_dir = current_model_dir
|
||||
|
|
@ -250,6 +251,7 @@ def gradio_merge_lycoris_tab(headless=False):
|
|||
device,
|
||||
is_sdxl,
|
||||
is_v2,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ def resize_lora(
|
|||
dynamic_method,
|
||||
dynamic_param,
|
||||
verbose,
|
||||
use_shell: bool = False,
|
||||
):
|
||||
# Check for caption_text_input
|
||||
if model == "":
|
||||
|
|
@ -100,7 +101,7 @@ def resize_lora(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, env=env)
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell)
|
||||
|
||||
log.info("Done resizing...")
|
||||
|
||||
|
|
@ -110,7 +111,7 @@ def resize_lora(
|
|||
###
|
||||
|
||||
|
||||
def gradio_resize_lora_tab(headless=False):
|
||||
def gradio_resize_lora_tab(headless=False, use_shell: bool = False):
|
||||
current_model_dir = os.path.join(scriptdir, "outputs")
|
||||
current_save_dir = os.path.join(scriptdir, "outputs")
|
||||
|
||||
|
|
@ -246,6 +247,7 @@ def gradio_resize_lora_tab(headless=False):
|
|||
dynamic_method,
|
||||
dynamic_param,
|
||||
verbose,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ def svd_merge_lora(
|
|||
new_rank,
|
||||
new_conv_rank,
|
||||
device,
|
||||
use_shell: bool = False,
|
||||
):
|
||||
# Check if the output file already exists
|
||||
if os.path.isfile(save_to):
|
||||
|
|
@ -53,10 +54,14 @@ def svd_merge_lora(
|
|||
ratio_d /= total_ratio
|
||||
|
||||
run_cmd = [
|
||||
PYTHON, f"{scriptdir}/sd-scripts/networks/svd_merge_lora.py",
|
||||
'--save_precision', save_precision,
|
||||
'--precision', precision,
|
||||
'--save_to', save_to
|
||||
PYTHON,
|
||||
f"{scriptdir}/sd-scripts/networks/svd_merge_lora.py",
|
||||
"--save_precision",
|
||||
save_precision,
|
||||
"--precision",
|
||||
precision,
|
||||
"--save_to",
|
||||
save_to,
|
||||
]
|
||||
|
||||
# Variables for model paths and their ratios
|
||||
|
|
@ -82,17 +87,15 @@ def svd_merge_lora(
|
|||
pass
|
||||
|
||||
if models and ratios: # Ensure we have valid models and ratios before appending
|
||||
run_cmd.extend(['--models'] + models)
|
||||
run_cmd.extend(['--ratios'] + ratios)
|
||||
run_cmd.extend(["--models"] + models)
|
||||
run_cmd.extend(["--ratios"] + ratios)
|
||||
|
||||
run_cmd.extend([
|
||||
'--device', device,
|
||||
'--new_rank', new_rank,
|
||||
'--new_conv_rank', new_conv_rank
|
||||
])
|
||||
run_cmd.extend(
|
||||
["--device", device, "--new_rank", new_rank, "--new_conv_rank", new_conv_rank]
|
||||
)
|
||||
|
||||
# Log the command
|
||||
log.info(' '.join(run_cmd))
|
||||
log.info(" ".join(run_cmd))
|
||||
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = (
|
||||
|
|
@ -102,8 +105,7 @@ def svd_merge_lora(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, env=env)
|
||||
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell)
|
||||
|
||||
|
||||
###
|
||||
|
|
@ -111,7 +113,7 @@ def svd_merge_lora(
|
|||
###
|
||||
|
||||
|
||||
def gradio_svd_merge_lora_tab(headless=False):
|
||||
def gradio_svd_merge_lora_tab(headless=False, use_shell: bool = False):
|
||||
current_save_dir = os.path.join(scriptdir, "outputs")
|
||||
current_a_model_dir = current_save_dir
|
||||
current_b_model_dir = current_save_dir
|
||||
|
|
@ -406,6 +408,7 @@ def gradio_svd_merge_lora_tab(headless=False):
|
|||
new_rank,
|
||||
new_conv_rank,
|
||||
device,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from .dreambooth_folder_creation_gui import (
|
|||
)
|
||||
from .dataset_balancing_gui import gradio_dataset_balancing_tab
|
||||
from .class_sample_images import SampleImages, create_prompt_file
|
||||
from .class_gui_config import KohyaSSGUIConfig
|
||||
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
|
|
@ -48,6 +49,7 @@ executor = CommandExecutor()
|
|||
|
||||
# Setup huggingface
|
||||
huggingface = None
|
||||
use_shell = False
|
||||
|
||||
TRAIN_BUTTON_VISIBLE = [gr.Button(visible=True), gr.Button(visible=False)]
|
||||
|
||||
|
|
@ -624,7 +626,7 @@ def train_model(
|
|||
run_cmd.append(f"{scriptdir}/sd-scripts/sdxl_train_textual_inversion.py")
|
||||
else:
|
||||
run_cmd.append(f"{scriptdir}/sd-scripts/train_textual_inversion.py")
|
||||
|
||||
|
||||
if max_data_loader_n_workers == "" or None:
|
||||
max_data_loader_n_workers = 0
|
||||
else:
|
||||
|
|
@ -634,7 +636,7 @@ def train_model(
|
|||
max_train_steps = 0
|
||||
else:
|
||||
max_train_steps = int(max_train_steps)
|
||||
|
||||
|
||||
# def save_huggingface_to_toml(self, toml_file_path: str):
|
||||
config_toml_data = {
|
||||
# Update the values in the TOML data
|
||||
|
|
@ -675,8 +677,10 @@ def train_model(
|
|||
"log_tracker_config": log_tracker_config,
|
||||
"loss_type": loss_type,
|
||||
"lr_scheduler": lr_scheduler,
|
||||
"lr_scheduler_args": str(lr_scheduler_args).replace('"', '').split(),
|
||||
"lr_scheduler_num_cycles": lr_scheduler_num_cycles if lr_scheduler_num_cycles != "" else int(epoch),
|
||||
"lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
|
||||
"lr_scheduler_num_cycles": (
|
||||
lr_scheduler_num_cycles if lr_scheduler_num_cycles != "" else int(epoch)
|
||||
),
|
||||
"lr_scheduler_power": lr_scheduler_power,
|
||||
"lr_warmup_steps": lr_warmup_steps,
|
||||
"max_bucket_reso": max_bucket_reso,
|
||||
|
|
@ -704,7 +708,7 @@ def train_model(
|
|||
"noise_offset_type": noise_offset_type,
|
||||
"num_vectors_per_token": int(num_vectors_per_token),
|
||||
"optimizer_type": optimizer,
|
||||
"optimizer_args": str(optimizer_args).replace('"', '').split(),
|
||||
"optimizer_args": str(optimizer_args).replace('"', "").split(),
|
||||
"output_dir": output_dir,
|
||||
"output_name": output_name,
|
||||
"persistent_data_loader_workers": persistent_data_loader_workers,
|
||||
|
|
@ -766,7 +770,7 @@ def train_model(
|
|||
|
||||
run_cmd.append(f"--config_file")
|
||||
run_cmd.append(tmpfilename)
|
||||
|
||||
|
||||
# Initialize a dictionary with always-included keyword arguments
|
||||
kwargs_for_training = {
|
||||
"max_data_loader_n_workers": max_data_loader_n_workers,
|
||||
|
|
@ -811,7 +815,7 @@ def train_model(
|
|||
|
||||
# Run the command
|
||||
|
||||
executor.execute_command(run_cmd=run_cmd, env=env)
|
||||
executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)
|
||||
|
||||
return (
|
||||
gr.Button(visible=False),
|
||||
|
|
@ -820,11 +824,19 @@ def train_model(
|
|||
)
|
||||
|
||||
|
||||
def ti_tab(headless=False, default_output_dir=None, config: dict = {}):
|
||||
def ti_tab(
|
||||
headless=False,
|
||||
default_output_dir=None,
|
||||
config: KohyaSSGUIConfig = {},
|
||||
use_shell_flag: bool = False,
|
||||
):
|
||||
dummy_db_true = gr.Checkbox(value=True, visible=False)
|
||||
dummy_db_false = gr.Checkbox(value=False, visible=False)
|
||||
dummy_headless = gr.Checkbox(value=headless, visible=False)
|
||||
|
||||
global use_shell
|
||||
use_shell = use_shell_flag
|
||||
|
||||
current_embedding_dir = (
|
||||
default_output_dir
|
||||
if default_output_dir is not None and default_output_dir != ""
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from .git_caption_gui import gradio_git_caption_gui_tab
|
|||
from .wd14_caption_gui import gradio_wd14_caption_gui_tab
|
||||
from .manual_caption_gui import gradio_manual_caption_gui_tab
|
||||
from .group_images_gui import gradio_group_images_gui_tab
|
||||
from .class_gui_config import KohyaSSGUIConfig
|
||||
|
||||
|
||||
def utilities_tab(
|
||||
|
|
@ -18,17 +19,18 @@ def utilities_tab(
|
|||
enable_copy_info_button=bool(False),
|
||||
enable_dreambooth_tab=True,
|
||||
headless=False,
|
||||
config: dict = {},
|
||||
config: KohyaSSGUIConfig = {},
|
||||
use_shell_flag: bool = False,
|
||||
):
|
||||
with gr.Tab("Captioning"):
|
||||
gradio_basic_caption_gui_tab(headless=headless)
|
||||
gradio_blip_caption_gui_tab(headless=headless)
|
||||
gradio_blip_caption_gui_tab(headless=headless, use_shell=use_shell_flag)
|
||||
gradio_blip2_caption_gui_tab(headless=headless)
|
||||
gradio_git_caption_gui_tab(headless=headless)
|
||||
gradio_wd14_caption_gui_tab(headless=headless, config=config)
|
||||
gradio_git_caption_gui_tab(headless=headless, use_shell=use_shell_flag)
|
||||
gradio_wd14_caption_gui_tab(headless=headless, config=config, use_shell=use_shell_flag)
|
||||
gradio_manual_caption_gui_tab(headless=headless)
|
||||
gradio_convert_model_tab(headless=headless)
|
||||
gradio_group_images_gui_tab(headless=headless)
|
||||
gradio_convert_model_tab(headless=headless, use_shell=use_shell_flag)
|
||||
gradio_group_images_gui_tab(headless=headless, use_shell=use_shell_flag)
|
||||
|
||||
return (
|
||||
train_data_dir_input,
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ PYTHON = sys.executable
|
|||
|
||||
def verify_lora(
|
||||
lora_model,
|
||||
use_shell: bool = False,
|
||||
):
|
||||
# verify for caption_text_input
|
||||
if lora_model == "":
|
||||
|
|
@ -37,11 +38,13 @@ def verify_lora(
|
|||
|
||||
# Build the command to run check_lora_weights.py
|
||||
run_cmd = [
|
||||
PYTHON, f"{scriptdir}/sd-scripts/networks/check_lora_weights.py", lora_model
|
||||
PYTHON,
|
||||
f"{scriptdir}/sd-scripts/networks/check_lora_weights.py",
|
||||
lora_model,
|
||||
]
|
||||
|
||||
# Log the command
|
||||
log.info(' '.join(run_cmd))
|
||||
log.info(" ".join(run_cmd))
|
||||
|
||||
# Set the environment variable for the Python path
|
||||
env = os.environ.copy()
|
||||
|
|
@ -57,6 +60,7 @@ def verify_lora(
|
|||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
env=env,
|
||||
shell=use_shell,
|
||||
)
|
||||
output, error = process.communicate()
|
||||
|
||||
|
|
@ -68,7 +72,7 @@ def verify_lora(
|
|||
###
|
||||
|
||||
|
||||
def gradio_verify_lora_tab(headless=False):
|
||||
def gradio_verify_lora_tab(headless=False, use_shell: bool = False):
|
||||
current_model_dir = os.path.join(scriptdir, "outputs")
|
||||
|
||||
def list_models(path):
|
||||
|
|
@ -139,6 +143,7 @@ def gradio_verify_lora_tab(headless=False):
|
|||
verify_lora,
|
||||
inputs=[
|
||||
lora_model,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
outputs=[lora_model_verif_output, lora_model_verif_error],
|
||||
show_progress=False,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,13 @@
|
|||
import gradio as gr
|
||||
from easygui import msgbox
|
||||
import subprocess
|
||||
from .common_gui import get_folder_path, add_pre_postfix, scriptdir, list_dirs, get_executable_path
|
||||
from .common_gui import (
|
||||
get_folder_path,
|
||||
add_pre_postfix,
|
||||
scriptdir,
|
||||
list_dirs,
|
||||
get_executable_path,
|
||||
)
|
||||
from .class_gui_config import KohyaSSGUIConfig
|
||||
import os
|
||||
|
||||
|
|
@ -34,6 +40,7 @@ def caption_images(
|
|||
use_rating_tags_as_last_tag: bool,
|
||||
remove_underscore: bool,
|
||||
thresh: float,
|
||||
use_shell: bool = False,
|
||||
) -> None:
|
||||
# Check for images_dir_input
|
||||
if train_data_dir == "":
|
||||
|
|
@ -46,7 +53,9 @@ def caption_images(
|
|||
|
||||
log.info(f"Captioning files in {train_data_dir}...")
|
||||
run_cmd = [
|
||||
get_executable_path("accelerate"), "launch", f"{scriptdir}/sd-scripts/finetune/tag_images_by_wd14_tagger.py"
|
||||
get_executable_path("accelerate"),
|
||||
"launch",
|
||||
f"{scriptdir}/sd-scripts/finetune/tag_images_by_wd14_tagger.py",
|
||||
]
|
||||
|
||||
# Uncomment and modify if needed
|
||||
|
|
@ -106,7 +115,7 @@ def caption_images(
|
|||
run_cmd.append(train_data_dir)
|
||||
|
||||
# Log the command
|
||||
log.info(' '.join(run_cmd))
|
||||
log.info(" ".join(run_cmd))
|
||||
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = (
|
||||
|
|
@ -116,9 +125,8 @@ def caption_images(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
subprocess.run(run_cmd, env=env)
|
||||
subprocess.run(run_cmd, env=env, shell=use_shell)
|
||||
|
||||
|
||||
# Add prefix and postfix
|
||||
add_pre_postfix(
|
||||
folder=train_data_dir,
|
||||
|
|
@ -135,7 +143,10 @@ def caption_images(
|
|||
|
||||
|
||||
def gradio_wd14_caption_gui_tab(
|
||||
headless=False, default_train_dir=None, config: KohyaSSGUIConfig = {}
|
||||
headless=False,
|
||||
default_train_dir=None,
|
||||
config: KohyaSSGUIConfig = {},
|
||||
use_shell: bool = False,
|
||||
):
|
||||
from .common_gui import create_refresh_button
|
||||
|
||||
|
|
@ -374,6 +385,7 @@ def gradio_wd14_caption_gui_tab(
|
|||
use_rating_tags_as_last_tag,
|
||||
remove_underscore,
|
||||
thresh,
|
||||
gr.Checkbox(value=use_shell, visible=False),
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue