Update SDXL tickbox features

pull/1133/head
bmaltais 2023-06-30 20:28:59 -04:00
parent 1703e40f85
commit d96a278006
6 changed files with 149 additions and 61 deletions

View File

@ -59,7 +59,8 @@ def save_configuration(
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization, sdxl,
v_parameterization,
sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@ -176,7 +177,8 @@ def open_configuration(
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization, sdxl,
v_parameterization,
sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@ -281,7 +283,8 @@ def train_model(
print_only,
pretrained_model_name_or_path,
v2,
v_parameterization, sdxl,
v_parameterization,
sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@ -374,7 +377,7 @@ def train_model(
msg='Image folder does not exist', headless=headless_bool
)
return
if not verify_image_folder_pattern(train_data_dir):
return
@ -385,7 +388,7 @@ def train_model(
headless=headless_bool,
)
return
if not verify_image_folder_pattern(reg_data_dir):
return
@ -399,9 +402,12 @@ def train_model(
output_name, output_dir, save_model_as, headless=headless_bool
):
return
if sdxl:
output_message(msg='TI training is not compatible with an SDXL model.', headless=headless_bool)
output_message(
msg='TI training is not compatible with an SDXL model.',
headless=headless_bool,
)
return
if optimizer == 'Adafactor' and lr_warmup != '0':

View File

@ -51,7 +51,8 @@ def save_configuration(
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization, sdxl,
v_parameterization,
sdxl,
train_dir,
image_folder,
output_dir,
@ -174,7 +175,8 @@ def open_configuration(
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization, sdxl,
v_parameterization,
sdxl,
train_dir,
image_folder,
output_dir,
@ -285,7 +287,8 @@ def train_model(
print_only,
pretrained_model_name_or_path,
v2,
v_parameterization, sdxl,
v_parameterization,
sdxl,
train_dir,
image_folder,
output_dir,
@ -478,7 +481,7 @@ def train_model(
run_cmd += f' "./sdxl_train.py"'
else:
run_cmd += f' "./fine_tune.py"'
if v2:
run_cmd += ' --v2'
if v_parameterization:

View File

@ -30,14 +30,14 @@ def UI(**kwargs):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
log.info('Load CSS...')
css += file.read() + '\n'
if os.path.exists('./.release'):
with open(os.path.join('./.release'), 'r', encoding='utf8') as file:
release= file.read()
release = file.read()
if os.path.exists('./README.md'):
with open(os.path.join('./README.md'), 'r', encoding='utf8') as file:
README= file.read()
README = file.read()
interface = gr.Blocks(
css=css, title=f'Kohya_ss GUI {release}', theme=gr.themes.Default()
@ -73,18 +73,18 @@ def UI(**kwargs):
gradio_merge_lora_tab(headless=headless)
gradio_merge_lycoris_tab(headless=headless)
gradio_resize_lora_tab(headless=headless)
with gr.Tab('About'):
with gr.Tab('About'):
gr.Markdown(f'kohya_ss GUI release {release}')
with gr.Tab('README'):
gr.Markdown(README)
htmlStr = f'''
htmlStr = f"""
<html>
<body>
<div class="ver-class">{release}</div>
</body>
</html>
'''
"""
gr.HTML(htmlStr)
# Show the interface
launch_kwargs = {}

View File

@ -35,8 +35,13 @@ V1_MODELS = [
'runwayml/stable-diffusion-v1-5',
]
# define a list of substrings to search for SDXL base models
SDXL_MODELS = [
'stabilityai/stable-diffusion-SDXL-base',
]
# define a list of substrings to search for
ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDXL_MODELS
ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
@ -479,7 +484,15 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
def set_pretrained_model_name_or_path_input(
model_list, pretrained_model_name_or_path, v2, v_parameterization, sdxl
):
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
# Check if the given model_list is in the list of SDXL models
if str(model_list) in SDXL_MODELS:
log.info('SDXL model detected. Setting parameters')
v2 = True
v_parameterization = True
sdxl = True
pretrained_model_name_or_path = str(model_list)
# Check if the given model_list is in the list of V2 base models
if str(model_list) in V2_BASE_MODELS:
log.info('SD v2 model detected. Setting --v2 parameter')
v2 = True
@ -487,7 +500,7 @@ def set_pretrained_model_name_or_path_input(
sdxl = False
pretrained_model_name_or_path = str(model_list)
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
# Check if the given model_list is in the list of V parameterization models
if str(model_list) in V_PARAMETERIZATION_MODELS:
log.info(
'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
@ -497,13 +510,16 @@ def set_pretrained_model_name_or_path_input(
sdxl = False
pretrained_model_name_or_path = str(model_list)
# Check if the given model_list is in the list of V1 models
if str(model_list) in V1_MODELS:
v2 = False
v_parameterization = False
sdxl = False
pretrained_model_name_or_path = str(model_list)
# Check if the model_list is set to 'custom'
if model_list == 'custom':
# Check if the pretrained_model_name_or_path is in any of the model lists
if (
str(pretrained_model_name_or_path) in V1_MODELS
or str(pretrained_model_name_or_path) in V2_BASE_MODELS
@ -513,25 +529,58 @@ def set_pretrained_model_name_or_path_input(
v2 = False
v_parameterization = False
sdxl = False
return model_list, pretrained_model_name_or_path, v2, sdxl
# Return the updated variables
return model_list, pretrained_model_name_or_path, v2, v_parameterization, sdxl
def set_v2_checkbox(model_list, v2, v_parameterization):
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
def set_v2_checkbox(model_list, v2, v_parameterization, sdxl):
if str(model_list) in SDXL_MODELS:
if not v2:
log.info(f'v2 can\'t be deselected because this {str(model_list)} is considered a v2 model...')
v2 = True
if not v_parameterization:
log.info(f'v_parameterization can\'t be deselected because {str(model_list)} require v parameterization...')
v_parameterization = True
if not sdxl:
log.info(f'sdxl can\'t be deselected because {str(model_list)} is an sdxl model...')
sdxl = True
if str(model_list) in V2_BASE_MODELS:
v2 = True
v_parameterization = False
if not v2:
log.info(f'v2 can\'t be deselected because this {str(model_list)} is a v2 model...')
v2 = True
if v_parameterization:
log.info(f'v_parameterization can\'t be selected because {str(model_list)} does not support v parameterization...')
v_parameterization = False
if sdxl:
log.info(f'sdxl can\'t be selected because {str(model_list)} is not an sdxl model...')
sdxl = False
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
if str(model_list) in V_PARAMETERIZATION_MODELS:
v2 = True
v_parameterization = True
if not v2:
log.info(f'v2 can\'t be deselected because this {str(model_list)} supports v parameterization...')
v2 = True
if not v_parameterization:
log.info(f'v_parameterization can\'t be deselected because {str(model_list)} supports v parameterization...')
v_parameterization = True
if sdxl:
log.info(f'sdxl can\'t be selected because {str(model_list)} is not an sdxl model...')
sdxl = False
if str(model_list) in V1_MODELS:
v2 = False
v_parameterization = False
if v2:
log.info(f'v2 can\'t be selected because this {str(model_list)} is a v1 model...')
v2 = False
if v_parameterization:
log.info(f'v_parameterization can\'t be selected because {str(model_list)} is a v1 model...')
v_parameterization = False
if sdxl:
log.info(f'sdxl can\'t be selected because {str(model_list)} is not an sdxl model...')
sdxl = False
return v2, v_parameterization
return v2, v_parameterization, sdxl
def set_model_list(
@ -665,14 +714,20 @@ def gradio_source_model(
)
v2.change(
set_v2_checkbox,
inputs=[model_list, v2, v_parameterization],
outputs=[v2, v_parameterization],
inputs=[model_list, v2, v_parameterization, sdxl],
outputs=[v2, v_parameterization, sdxl],
show_progress=False,
)
v_parameterization.change(
set_v2_checkbox,
inputs=[model_list, v2, v_parameterization],
outputs=[v2, v_parameterization],
inputs=[model_list, v2, v_parameterization, sdxl],
outputs=[v2, v_parameterization, sdxl],
show_progress=False,
)
sdxl.change(
set_v2_checkbox,
inputs=[model_list, v2, v_parameterization, sdxl],
outputs=[v2, v_parameterization, sdxl],
show_progress=False,
)
model_list.change(

View File

@ -69,7 +69,8 @@ def save_configuration(
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization, sdxl,
v_parameterization,
sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@ -215,7 +216,8 @@ def open_configuration(
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization, sdxl,
v_parameterization,
sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@ -313,14 +315,14 @@ def open_configuration(
network_dropout,
rank_dropout,
module_dropout,
training_preset
training_preset,
):
# Get list of function parameters and values
parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
apply_preset = True if apply_preset.get('label') == 'True' else False
# Check if we are "applying" a preset or a config
if apply_preset:
log.info(f'Applying preset {training_preset}...')
@ -328,10 +330,12 @@ def open_configuration(
else:
# If not applying a preset, set the `training_preset` field to an empty string
# Find the index of the `training_preset` parameter using the `index()` method
training_preset_index = parameters.index(("training_preset", training_preset))
training_preset_index = parameters.index(
('training_preset', training_preset)
)
# Update the value of `training_preset` by directly assigning an empty string value
parameters[training_preset_index] = ("training_preset", "")
parameters[training_preset_index] = ('training_preset', '')
original_file_path = file_path
@ -376,7 +380,8 @@ def train_model(
print_only,
pretrained_model_name_or_path,
v2,
v_parameterization, sdxl,
v_parameterization,
sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@ -496,7 +501,7 @@ def train_model(
msg='Image folder does not exist', headless=headless_bool
)
return
if not verify_image_folder_pattern(train_data_dir):
return
@ -507,7 +512,7 @@ def train_model(
headless=headless_bool,
)
return
if not verify_image_folder_pattern(reg_data_dir):
return
@ -1095,15 +1100,18 @@ def lora_tab(
outputs=[logging_dir],
)
with gr.Tab('Training parameters'):
def list_presets(path):
json_files = []
for file in os.listdir(path):
if file.endswith(".json"):
if file.endswith('.json'):
json_files.append(os.path.splitext(file)[0])
return json_files
training_preset = gr.Dropdown(
label='Presets',
choices=list_presets('./presets/lora'), elem_id="myDropdown"
choices=list_presets('./presets/lora'),
elem_id='myDropdown',
)
with gr.Row():
LoRA_type = gr.Dropdown(
@ -1695,21 +1703,31 @@ def lora_tab(
button_open_config.click(
open_configuration,
inputs=[dummy_db_true, dummy_db_false, config_file_name] + settings_list + [training_preset],
outputs=[config_file_name] + settings_list + [training_preset, LoCon_row],
inputs=[dummy_db_true, dummy_db_false, config_file_name]
+ settings_list
+ [training_preset],
outputs=[config_file_name]
+ settings_list
+ [training_preset, LoCon_row],
show_progress=False,
)
button_load_config.click(
open_configuration,
inputs=[dummy_db_false, dummy_db_false, config_file_name] + settings_list + [training_preset],
outputs=[config_file_name] + settings_list + [training_preset, LoCon_row],
inputs=[dummy_db_false, dummy_db_false, config_file_name]
+ settings_list
+ [training_preset],
outputs=[config_file_name]
+ settings_list
+ [training_preset, LoCon_row],
show_progress=False,
)
training_preset.input(
open_configuration,
inputs=[dummy_db_false, dummy_db_true, config_file_name] + settings_list + [training_preset],
inputs=[dummy_db_false, dummy_db_true, config_file_name]
+ settings_list
+ [training_preset],
outputs=[gr.Textbox()] + settings_list + [training_preset, LoCon_row],
show_progress=False,
)

View File

@ -59,7 +59,8 @@ def save_configuration(
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization, sdxl,
v_parameterization,
sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@ -181,7 +182,8 @@ def open_configuration(
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization, sdxl,
v_parameterization,
sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@ -291,7 +293,8 @@ def train_model(
print_only,
pretrained_model_name_or_path,
v2,
v_parameterization, sdxl,
v_parameterization,
sdxl,
logging_dir,
train_data_dir,
reg_data_dir,
@ -389,7 +392,7 @@ def train_model(
msg='Image folder does not exist', headless=headless_bool
)
return
if not verify_image_folder_pattern(train_data_dir):
return
@ -400,7 +403,7 @@ def train_model(
headless=headless_bool,
)
return
if not verify_image_folder_pattern(reg_data_dir):
return
@ -425,9 +428,12 @@ def train_model(
output_name, output_dir, save_model_as, headless_bool
):
return
if sdxl:
output_message(msg='TI training is not compatible with an SDXL model.', headless=headless_bool)
output_message(
msg='TI training is not compatible with an SDXL model.',
headless=headless_bool,
)
return
# if float(noise_offset) > 0 and (