diff --git a/dreambooth_gui.py b/dreambooth_gui.py
index 4cb7e77..c22e418 100644
--- a/dreambooth_gui.py
+++ b/dreambooth_gui.py
@@ -59,7 +59,8 @@ def save_configuration(
file_path,
pretrained_model_name_or_path,
v2,
- v_parameterization, sdxl,
+ v_parameterization,
+ sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@@ -176,7 +177,8 @@ def open_configuration(
file_path,
pretrained_model_name_or_path,
v2,
- v_parameterization, sdxl,
+ v_parameterization,
+ sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@@ -281,7 +283,8 @@ def train_model(
print_only,
pretrained_model_name_or_path,
v2,
- v_parameterization, sdxl,
+ v_parameterization,
+ sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@@ -374,7 +377,7 @@ def train_model(
msg='Image folder does not exist', headless=headless_bool
)
return
-
+
if not verify_image_folder_pattern(train_data_dir):
return
@@ -385,7 +388,7 @@ def train_model(
headless=headless_bool,
)
return
-
+
if not verify_image_folder_pattern(reg_data_dir):
return
@@ -399,9 +402,12 @@ def train_model(
output_name, output_dir, save_model_as, headless=headless_bool
):
return
-
+
if sdxl:
- output_message(msg='TI training is not compatible with an SDXL model.', headless=headless_bool)
+ output_message(
+ msg='TI training is not compatible with an SDXL model.',
+ headless=headless_bool,
+ )
return
if optimizer == 'Adafactor' and lr_warmup != '0':
diff --git a/finetune_gui.py b/finetune_gui.py
index 6eab933..2d5d24a 100644
--- a/finetune_gui.py
+++ b/finetune_gui.py
@@ -51,7 +51,8 @@ def save_configuration(
file_path,
pretrained_model_name_or_path,
v2,
- v_parameterization, sdxl,
+ v_parameterization,
+ sdxl,
train_dir,
image_folder,
output_dir,
@@ -174,7 +175,8 @@ def open_configuration(
file_path,
pretrained_model_name_or_path,
v2,
- v_parameterization, sdxl,
+ v_parameterization,
+ sdxl,
train_dir,
image_folder,
output_dir,
@@ -285,7 +287,8 @@ def train_model(
print_only,
pretrained_model_name_or_path,
v2,
- v_parameterization, sdxl,
+ v_parameterization,
+ sdxl,
train_dir,
image_folder,
output_dir,
@@ -478,7 +481,7 @@ def train_model(
run_cmd += f' "./sdxl_train.py"'
else:
run_cmd += f' "./fine_tune.py"'
-
+
if v2:
run_cmd += ' --v2'
if v_parameterization:
diff --git a/kohya_gui.py b/kohya_gui.py
index c2e095e..5c69951 100644
--- a/kohya_gui.py
+++ b/kohya_gui.py
@@ -30,14 +30,14 @@ def UI(**kwargs):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
log.info('Load CSS...')
css += file.read() + '\n'
-
+
if os.path.exists('./.release'):
with open(os.path.join('./.release'), 'r', encoding='utf8') as file:
- release= file.read()
-
+ release = file.read()
+
if os.path.exists('./README.md'):
with open(os.path.join('./README.md'), 'r', encoding='utf8') as file:
- README= file.read()
+ README = file.read()
interface = gr.Blocks(
css=css, title=f'Kohya_ss GUI {release}', theme=gr.themes.Default()
@@ -73,18 +73,18 @@ def UI(**kwargs):
gradio_merge_lora_tab(headless=headless)
gradio_merge_lycoris_tab(headless=headless)
gradio_resize_lora_tab(headless=headless)
- with gr.Tab('About'):
+ with gr.Tab('About'):
gr.Markdown(f'kohya_ss GUI release {release}')
with gr.Tab('README'):
gr.Markdown(README)
-
- htmlStr = f'''
+
+ htmlStr = f"""
{release}
- '''
+ """
gr.HTML(htmlStr)
# Show the interface
launch_kwargs = {}
diff --git a/library/common_gui.py b/library/common_gui.py
index 9e135ec..9dd5c8b 100644
--- a/library/common_gui.py
+++ b/library/common_gui.py
@@ -35,8 +35,13 @@ V1_MODELS = [
'runwayml/stable-diffusion-v1-5',
]
+# define a list of substrings to search for SDXL base models
+SDXL_MODELS = [
+ 'stabilityai/stable-diffusion-SDXL-base',
+]
+
# define a list of substrings to search for
-ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
+ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDXL_MODELS
ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
@@ -479,7 +484,15 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
def set_pretrained_model_name_or_path_input(
model_list, pretrained_model_name_or_path, v2, v_parameterization, sdxl
):
- # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
+ # Check if the given model_list is in the list of SDXL models
+ if str(model_list) in SDXL_MODELS:
+ log.info('SDXL model detected. Setting parameters')
+ v2 = True
+ v_parameterization = True
+ sdxl = True
+ pretrained_model_name_or_path = str(model_list)
+
+ # Check if the given model_list is in the list of V2 base models
if str(model_list) in V2_BASE_MODELS:
log.info('SD v2 model detected. Setting --v2 parameter')
v2 = True
@@ -487,7 +500,7 @@ def set_pretrained_model_name_or_path_input(
sdxl = False
pretrained_model_name_or_path = str(model_list)
- # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
+ # Check if the given model_list is in the list of V parameterization models
if str(model_list) in V_PARAMETERIZATION_MODELS:
log.info(
'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
@@ -497,13 +510,16 @@ def set_pretrained_model_name_or_path_input(
sdxl = False
pretrained_model_name_or_path = str(model_list)
+ # Check if the given model_list is in the list of V1 models
if str(model_list) in V1_MODELS:
v2 = False
v_parameterization = False
sdxl = False
pretrained_model_name_or_path = str(model_list)
+ # Check if the model_list is set to 'custom'
if model_list == 'custom':
+ # Check if the pretrained_model_name_or_path is in any of the model lists
if (
str(pretrained_model_name_or_path) in V1_MODELS
or str(pretrained_model_name_or_path) in V2_BASE_MODELS
@@ -513,25 +529,58 @@ def set_pretrained_model_name_or_path_input(
v2 = False
v_parameterization = False
sdxl = False
- return model_list, pretrained_model_name_or_path, v2, sdxl
+
+ # Return the updated variables
+ return model_list, pretrained_model_name_or_path, v2, v_parameterization, sdxl
-def set_v2_checkbox(model_list, v2, v_parameterization):
- # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
+def set_v2_checkbox(model_list, v2, v_parameterization, sdxl):
+
+ if str(model_list) in SDXL_MODELS:
+ if not v2:
+ log.info(f'v2 can\'t be deselected because this {str(model_list)} is considered a v2 model...')
+ v2 = True
+ if not v_parameterization:
+ log.info(f'v_parameterization can\'t be deselected because {str(model_list)} require v parameterization...')
+ v_parameterization = True
+ if not sdxl:
+ log.info(f'sdxl can\'t be deselected because {str(model_list)} is an sdxl model...')
+ sdxl = True
+
if str(model_list) in V2_BASE_MODELS:
- v2 = True
- v_parameterization = False
+ if not v2:
+ log.info(f'v2 can\'t be deselected because this {str(model_list)} is a v2 model...')
+ v2 = True
+ if v_parameterization:
+ log.info(f'v_parameterization can\'t be selected because {str(model_list)} does not support v parameterization...')
+ v_parameterization = False
+ if sdxl:
+ log.info(f'sdxl can\'t be selected because {str(model_list)} is not an sdxl model...')
+ sdxl = False
- # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
if str(model_list) in V_PARAMETERIZATION_MODELS:
- v2 = True
- v_parameterization = True
+ if not v2:
+ log.info(f'v2 can\'t be deselected because this {str(model_list)} supports v parameterization...')
+ v2 = True
+ if not v_parameterization:
+ log.info(f'v_parameterization can\'t be deselected because {str(model_list)} supports v parameterization...')
+ v_parameterization = True
+ if sdxl:
+ log.info(f'sdxl can\'t be selected because {str(model_list)} is not an sdxl model...')
+ sdxl = False
if str(model_list) in V1_MODELS:
- v2 = False
- v_parameterization = False
+ if v2:
+ log.info(f'v2 can\'t be selected because this {str(model_list)} is a v1 model...')
+ v2 = False
+ if v_parameterization:
+ log.info(f'v_parameterization can\'t be selected because {str(model_list)} is a v1 model...')
+ v_parameterization = False
+ if sdxl:
+ log.info(f'sdxl can\'t be selected because {str(model_list)} is not an sdxl model...')
+ sdxl = False
- return v2, v_parameterization
+ return v2, v_parameterization, sdxl
def set_model_list(
@@ -665,14 +714,20 @@ def gradio_source_model(
)
v2.change(
set_v2_checkbox,
- inputs=[model_list, v2, v_parameterization],
- outputs=[v2, v_parameterization],
+ inputs=[model_list, v2, v_parameterization, sdxl],
+ outputs=[v2, v_parameterization, sdxl],
show_progress=False,
)
v_parameterization.change(
set_v2_checkbox,
- inputs=[model_list, v2, v_parameterization],
- outputs=[v2, v_parameterization],
+ inputs=[model_list, v2, v_parameterization, sdxl],
+ outputs=[v2, v_parameterization, sdxl],
+ show_progress=False,
+ )
+ sdxl.change(
+ set_v2_checkbox,
+ inputs=[model_list, v2, v_parameterization, sdxl],
+ outputs=[v2, v_parameterization, sdxl],
show_progress=False,
)
model_list.change(
diff --git a/lora_gui.py b/lora_gui.py
index 4c80075..3563211 100644
--- a/lora_gui.py
+++ b/lora_gui.py
@@ -69,7 +69,8 @@ def save_configuration(
file_path,
pretrained_model_name_or_path,
v2,
- v_parameterization, sdxl,
+ v_parameterization,
+ sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@@ -215,7 +216,8 @@ def open_configuration(
file_path,
pretrained_model_name_or_path,
v2,
- v_parameterization, sdxl,
+ v_parameterization,
+ sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@@ -313,14 +315,14 @@ def open_configuration(
network_dropout,
rank_dropout,
module_dropout,
- training_preset
+ training_preset,
):
# Get list of function parameters and values
parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
apply_preset = True if apply_preset.get('label') == 'True' else False
-
+
# Check if we are "applying" a preset or a config
if apply_preset:
log.info(f'Applying preset {training_preset}...')
@@ -328,10 +330,12 @@ def open_configuration(
else:
# If not applying a preset, set the `training_preset` field to an empty string
# Find the index of the `training_preset` parameter using the `index()` method
- training_preset_index = parameters.index(("training_preset", training_preset))
+ training_preset_index = parameters.index(
+ ('training_preset', training_preset)
+ )
# Update the value of `training_preset` by directly assigning an empty string value
- parameters[training_preset_index] = ("training_preset", "")
+ parameters[training_preset_index] = ('training_preset', '')
original_file_path = file_path
@@ -376,7 +380,8 @@ def train_model(
print_only,
pretrained_model_name_or_path,
v2,
- v_parameterization, sdxl,
+ v_parameterization,
+ sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@@ -496,7 +501,7 @@ def train_model(
msg='Image folder does not exist', headless=headless_bool
)
return
-
+
if not verify_image_folder_pattern(train_data_dir):
return
@@ -507,7 +512,7 @@ def train_model(
headless=headless_bool,
)
return
-
+
if not verify_image_folder_pattern(reg_data_dir):
return
@@ -1095,15 +1100,18 @@ def lora_tab(
outputs=[logging_dir],
)
with gr.Tab('Training parameters'):
+
def list_presets(path):
json_files = []
for file in os.listdir(path):
- if file.endswith(".json"):
+ if file.endswith('.json'):
json_files.append(os.path.splitext(file)[0])
return json_files
+
training_preset = gr.Dropdown(
label='Presets',
- choices=list_presets('./presets/lora'), elem_id="myDropdown"
+ choices=list_presets('./presets/lora'),
+ elem_id='myDropdown',
)
with gr.Row():
LoRA_type = gr.Dropdown(
@@ -1695,21 +1703,31 @@ def lora_tab(
button_open_config.click(
open_configuration,
- inputs=[dummy_db_true, dummy_db_false, config_file_name] + settings_list + [training_preset],
- outputs=[config_file_name] + settings_list + [training_preset, LoCon_row],
+ inputs=[dummy_db_true, dummy_db_false, config_file_name]
+ + settings_list
+ + [training_preset],
+ outputs=[config_file_name]
+ + settings_list
+ + [training_preset, LoCon_row],
show_progress=False,
)
button_load_config.click(
open_configuration,
- inputs=[dummy_db_false, dummy_db_false, config_file_name] + settings_list + [training_preset],
- outputs=[config_file_name] + settings_list + [training_preset, LoCon_row],
+ inputs=[dummy_db_false, dummy_db_false, config_file_name]
+ + settings_list
+ + [training_preset],
+ outputs=[config_file_name]
+ + settings_list
+ + [training_preset, LoCon_row],
show_progress=False,
)
-
+
training_preset.input(
open_configuration,
- inputs=[dummy_db_false, dummy_db_true, config_file_name] + settings_list + [training_preset],
+ inputs=[dummy_db_false, dummy_db_true, config_file_name]
+ + settings_list
+ + [training_preset],
outputs=[gr.Textbox()] + settings_list + [training_preset, LoCon_row],
show_progress=False,
)
diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py
index 13622df..5e0e89c 100644
--- a/textual_inversion_gui.py
+++ b/textual_inversion_gui.py
@@ -59,7 +59,8 @@ def save_configuration(
file_path,
pretrained_model_name_or_path,
v2,
- v_parameterization, sdxl,
+ v_parameterization,
+ sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@@ -181,7 +182,8 @@ def open_configuration(
file_path,
pretrained_model_name_or_path,
v2,
- v_parameterization, sdxl,
+ v_parameterization,
+ sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@@ -291,7 +293,8 @@ def train_model(
print_only,
pretrained_model_name_or_path,
v2,
- v_parameterization, sdxl,
+ v_parameterization,
+ sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@@ -389,7 +392,7 @@ def train_model(
msg='Image folder does not exist', headless=headless_bool
)
return
-
+
if not verify_image_folder_pattern(train_data_dir):
return
@@ -400,7 +403,7 @@ def train_model(
headless=headless_bool,
)
return
-
+
if not verify_image_folder_pattern(reg_data_dir):
return
@@ -425,9 +428,12 @@ def train_model(
output_name, output_dir, save_model_as, headless_bool
):
return
-
+
if sdxl:
- output_message(msg='TI training is not compatible with an SDXL model.', headless=headless_bool)
+ output_message(
+ msg='TI training is not compatible with an SDXL model.',
+ headless=headless_bool,
+ )
return
# if float(noise_offset) > 0 and (