mirror of https://github.com/bmaltais/kohya_ss
Update SDXL tickbox features
parent
1703e40f85
commit
d96a278006
|
|
@ -59,7 +59,8 @@ def save_configuration(
|
|||
file_path,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization, sdxl,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
|
|
@ -176,7 +177,8 @@ def open_configuration(
|
|||
file_path,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization, sdxl,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
|
|
@ -281,7 +283,8 @@ def train_model(
|
|||
print_only,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization, sdxl,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
|
|
@ -374,7 +377,7 @@ def train_model(
|
|||
msg='Image folder does not exist', headless=headless_bool
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
if not verify_image_folder_pattern(train_data_dir):
|
||||
return
|
||||
|
||||
|
|
@ -385,7 +388,7 @@ def train_model(
|
|||
headless=headless_bool,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
if not verify_image_folder_pattern(reg_data_dir):
|
||||
return
|
||||
|
||||
|
|
@ -399,9 +402,12 @@ def train_model(
|
|||
output_name, output_dir, save_model_as, headless=headless_bool
|
||||
):
|
||||
return
|
||||
|
||||
|
||||
if sdxl:
|
||||
output_message(msg='TI training is not compatible with an SDXL model.', headless=headless_bool)
|
||||
output_message(
|
||||
msg='TI training is not compatible with an SDXL model.',
|
||||
headless=headless_bool,
|
||||
)
|
||||
return
|
||||
|
||||
if optimizer == 'Adafactor' and lr_warmup != '0':
|
||||
|
|
|
|||
|
|
@ -51,7 +51,8 @@ def save_configuration(
|
|||
file_path,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization, sdxl,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
train_dir,
|
||||
image_folder,
|
||||
output_dir,
|
||||
|
|
@ -174,7 +175,8 @@ def open_configuration(
|
|||
file_path,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization, sdxl,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
train_dir,
|
||||
image_folder,
|
||||
output_dir,
|
||||
|
|
@ -285,7 +287,8 @@ def train_model(
|
|||
print_only,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization, sdxl,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
train_dir,
|
||||
image_folder,
|
||||
output_dir,
|
||||
|
|
@ -478,7 +481,7 @@ def train_model(
|
|||
run_cmd += f' "./sdxl_train.py"'
|
||||
else:
|
||||
run_cmd += f' "./fine_tune.py"'
|
||||
|
||||
|
||||
if v2:
|
||||
run_cmd += ' --v2'
|
||||
if v_parameterization:
|
||||
|
|
|
|||
16
kohya_gui.py
16
kohya_gui.py
|
|
@ -30,14 +30,14 @@ def UI(**kwargs):
|
|||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
log.info('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
|
||||
if os.path.exists('./.release'):
|
||||
with open(os.path.join('./.release'), 'r', encoding='utf8') as file:
|
||||
release= file.read()
|
||||
|
||||
release = file.read()
|
||||
|
||||
if os.path.exists('./README.md'):
|
||||
with open(os.path.join('./README.md'), 'r', encoding='utf8') as file:
|
||||
README= file.read()
|
||||
README = file.read()
|
||||
|
||||
interface = gr.Blocks(
|
||||
css=css, title=f'Kohya_ss GUI {release}', theme=gr.themes.Default()
|
||||
|
|
@ -73,18 +73,18 @@ def UI(**kwargs):
|
|||
gradio_merge_lora_tab(headless=headless)
|
||||
gradio_merge_lycoris_tab(headless=headless)
|
||||
gradio_resize_lora_tab(headless=headless)
|
||||
with gr.Tab('About'):
|
||||
with gr.Tab('About'):
|
||||
gr.Markdown(f'kohya_ss GUI release {release}')
|
||||
with gr.Tab('README'):
|
||||
gr.Markdown(README)
|
||||
|
||||
htmlStr = f'''
|
||||
|
||||
htmlStr = f"""
|
||||
<html>
|
||||
<body>
|
||||
<div class="ver-class">{release}</div>
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
"""
|
||||
gr.HTML(htmlStr)
|
||||
# Show the interface
|
||||
launch_kwargs = {}
|
||||
|
|
|
|||
|
|
@ -35,8 +35,13 @@ V1_MODELS = [
|
|||
'runwayml/stable-diffusion-v1-5',
|
||||
]
|
||||
|
||||
# define a list of substrings to search for SDXL base models
|
||||
SDXL_MODELS = [
|
||||
'stabilityai/stable-diffusion-SDXL-base',
|
||||
]
|
||||
|
||||
# define a list of substrings to search for
|
||||
ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
|
||||
ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDXL_MODELS
|
||||
|
||||
ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
|
||||
|
||||
|
|
@ -479,7 +484,15 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
|
|||
def set_pretrained_model_name_or_path_input(
|
||||
model_list, pretrained_model_name_or_path, v2, v_parameterization, sdxl
|
||||
):
|
||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
||||
# Check if the given model_list is in the list of SDXL models
|
||||
if str(model_list) in SDXL_MODELS:
|
||||
log.info('SDXL model detected. Setting parameters')
|
||||
v2 = True
|
||||
v_parameterization = True
|
||||
sdxl = True
|
||||
pretrained_model_name_or_path = str(model_list)
|
||||
|
||||
# Check if the given model_list is in the list of V2 base models
|
||||
if str(model_list) in V2_BASE_MODELS:
|
||||
log.info('SD v2 model detected. Setting --v2 parameter')
|
||||
v2 = True
|
||||
|
|
@ -487,7 +500,7 @@ def set_pretrained_model_name_or_path_input(
|
|||
sdxl = False
|
||||
pretrained_model_name_or_path = str(model_list)
|
||||
|
||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
|
||||
# Check if the given model_list is in the list of V parameterization models
|
||||
if str(model_list) in V_PARAMETERIZATION_MODELS:
|
||||
log.info(
|
||||
'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
|
||||
|
|
@ -497,13 +510,16 @@ def set_pretrained_model_name_or_path_input(
|
|||
sdxl = False
|
||||
pretrained_model_name_or_path = str(model_list)
|
||||
|
||||
# Check if the given model_list is in the list of V1 models
|
||||
if str(model_list) in V1_MODELS:
|
||||
v2 = False
|
||||
v_parameterization = False
|
||||
sdxl = False
|
||||
pretrained_model_name_or_path = str(model_list)
|
||||
|
||||
# Check if the model_list is set to 'custom'
|
||||
if model_list == 'custom':
|
||||
# Check if the pretrained_model_name_or_path is in any of the model lists
|
||||
if (
|
||||
str(pretrained_model_name_or_path) in V1_MODELS
|
||||
or str(pretrained_model_name_or_path) in V2_BASE_MODELS
|
||||
|
|
@ -513,25 +529,58 @@ def set_pretrained_model_name_or_path_input(
|
|||
v2 = False
|
||||
v_parameterization = False
|
||||
sdxl = False
|
||||
return model_list, pretrained_model_name_or_path, v2, sdxl
|
||||
|
||||
# Return the updated variables
|
||||
return model_list, pretrained_model_name_or_path, v2, v_parameterization, sdxl
|
||||
|
||||
|
||||
def set_v2_checkbox(model_list, v2, v_parameterization):
|
||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
||||
def set_v2_checkbox(model_list, v2, v_parameterization, sdxl):
|
||||
|
||||
if str(model_list) in SDXL_MODELS:
|
||||
if not v2:
|
||||
log.info(f'v2 can\'t be deselected because this {str(model_list)} is considered a v2 model...')
|
||||
v2 = True
|
||||
if not v_parameterization:
|
||||
log.info(f'v_parameterization can\'t be deselected because {str(model_list)} require v parameterization...')
|
||||
v_parameterization = True
|
||||
if not sdxl:
|
||||
log.info(f'sdxl can\'t be deselected because {str(model_list)} is an sdxl model...')
|
||||
sdxl = True
|
||||
|
||||
if str(model_list) in V2_BASE_MODELS:
|
||||
v2 = True
|
||||
v_parameterization = False
|
||||
if not v2:
|
||||
log.info(f'v2 can\'t be deselected because this {str(model_list)} is a v2 model...')
|
||||
v2 = True
|
||||
if v_parameterization:
|
||||
log.info(f'v_parameterization can\'t be selected because {str(model_list)} does not support v parameterization...')
|
||||
v_parameterization = False
|
||||
if sdxl:
|
||||
log.info(f'sdxl can\'t be selected because {str(model_list)} is not an sdxl model...')
|
||||
sdxl = False
|
||||
|
||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
|
||||
if str(model_list) in V_PARAMETERIZATION_MODELS:
|
||||
v2 = True
|
||||
v_parameterization = True
|
||||
if not v2:
|
||||
log.info(f'v2 can\'t be deselected because this {str(model_list)} supports v parameterization...')
|
||||
v2 = True
|
||||
if not v_parameterization:
|
||||
log.info(f'v_parameterization can\'t be deselected because {str(model_list)} supports v parameterization...')
|
||||
v_parameterization = True
|
||||
if sdxl:
|
||||
log.info(f'sdxl can\'t be selected because {str(model_list)} is not an sdxl model...')
|
||||
sdxl = False
|
||||
|
||||
if str(model_list) in V1_MODELS:
|
||||
v2 = False
|
||||
v_parameterization = False
|
||||
if v2:
|
||||
log.info(f'v2 can\'t be selected because this {str(model_list)} is a v1 model...')
|
||||
v2 = False
|
||||
if v_parameterization:
|
||||
log.info(f'v_parameterization can\'t be selected because {str(model_list)} is a v1 model...')
|
||||
v_parameterization = False
|
||||
if sdxl:
|
||||
log.info(f'sdxl can\'t be selected because {str(model_list)} is not an sdxl model...')
|
||||
sdxl = False
|
||||
|
||||
return v2, v_parameterization
|
||||
return v2, v_parameterization, sdxl
|
||||
|
||||
|
||||
def set_model_list(
|
||||
|
|
@ -665,14 +714,20 @@ def gradio_source_model(
|
|||
)
|
||||
v2.change(
|
||||
set_v2_checkbox,
|
||||
inputs=[model_list, v2, v_parameterization],
|
||||
outputs=[v2, v_parameterization],
|
||||
inputs=[model_list, v2, v_parameterization, sdxl],
|
||||
outputs=[v2, v_parameterization, sdxl],
|
||||
show_progress=False,
|
||||
)
|
||||
v_parameterization.change(
|
||||
set_v2_checkbox,
|
||||
inputs=[model_list, v2, v_parameterization],
|
||||
outputs=[v2, v_parameterization],
|
||||
inputs=[model_list, v2, v_parameterization, sdxl],
|
||||
outputs=[v2, v_parameterization, sdxl],
|
||||
show_progress=False,
|
||||
)
|
||||
sdxl.change(
|
||||
set_v2_checkbox,
|
||||
inputs=[model_list, v2, v_parameterization, sdxl],
|
||||
outputs=[v2, v_parameterization, sdxl],
|
||||
show_progress=False,
|
||||
)
|
||||
model_list.change(
|
||||
|
|
|
|||
52
lora_gui.py
52
lora_gui.py
|
|
@ -69,7 +69,8 @@ def save_configuration(
|
|||
file_path,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization, sdxl,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
|
|
@ -215,7 +216,8 @@ def open_configuration(
|
|||
file_path,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization, sdxl,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
|
|
@ -313,14 +315,14 @@ def open_configuration(
|
|||
network_dropout,
|
||||
rank_dropout,
|
||||
module_dropout,
|
||||
training_preset
|
||||
training_preset,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
||||
ask_for_file = True if ask_for_file.get('label') == 'True' else False
|
||||
apply_preset = True if apply_preset.get('label') == 'True' else False
|
||||
|
||||
|
||||
# Check if we are "applying" a preset or a config
|
||||
if apply_preset:
|
||||
log.info(f'Applying preset {training_preset}...')
|
||||
|
|
@ -328,10 +330,12 @@ def open_configuration(
|
|||
else:
|
||||
# If not applying a preset, set the `training_preset` field to an empty string
|
||||
# Find the index of the `training_preset` parameter using the `index()` method
|
||||
training_preset_index = parameters.index(("training_preset", training_preset))
|
||||
training_preset_index = parameters.index(
|
||||
('training_preset', training_preset)
|
||||
)
|
||||
|
||||
# Update the value of `training_preset` by directly assigning an empty string value
|
||||
parameters[training_preset_index] = ("training_preset", "")
|
||||
parameters[training_preset_index] = ('training_preset', '')
|
||||
|
||||
original_file_path = file_path
|
||||
|
||||
|
|
@ -376,7 +380,8 @@ def train_model(
|
|||
print_only,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization, sdxl,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
|
|
@ -496,7 +501,7 @@ def train_model(
|
|||
msg='Image folder does not exist', headless=headless_bool
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
if not verify_image_folder_pattern(train_data_dir):
|
||||
return
|
||||
|
||||
|
|
@ -507,7 +512,7 @@ def train_model(
|
|||
headless=headless_bool,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
if not verify_image_folder_pattern(reg_data_dir):
|
||||
return
|
||||
|
||||
|
|
@ -1095,15 +1100,18 @@ def lora_tab(
|
|||
outputs=[logging_dir],
|
||||
)
|
||||
with gr.Tab('Training parameters'):
|
||||
|
||||
def list_presets(path):
|
||||
json_files = []
|
||||
for file in os.listdir(path):
|
||||
if file.endswith(".json"):
|
||||
if file.endswith('.json'):
|
||||
json_files.append(os.path.splitext(file)[0])
|
||||
return json_files
|
||||
|
||||
training_preset = gr.Dropdown(
|
||||
label='Presets',
|
||||
choices=list_presets('./presets/lora'), elem_id="myDropdown"
|
||||
choices=list_presets('./presets/lora'),
|
||||
elem_id='myDropdown',
|
||||
)
|
||||
with gr.Row():
|
||||
LoRA_type = gr.Dropdown(
|
||||
|
|
@ -1695,21 +1703,31 @@ def lora_tab(
|
|||
|
||||
button_open_config.click(
|
||||
open_configuration,
|
||||
inputs=[dummy_db_true, dummy_db_false, config_file_name] + settings_list + [training_preset],
|
||||
outputs=[config_file_name] + settings_list + [training_preset, LoCon_row],
|
||||
inputs=[dummy_db_true, dummy_db_false, config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
outputs=[config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset, LoCon_row],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
button_load_config.click(
|
||||
open_configuration,
|
||||
inputs=[dummy_db_false, dummy_db_false, config_file_name] + settings_list + [training_preset],
|
||||
outputs=[config_file_name] + settings_list + [training_preset, LoCon_row],
|
||||
inputs=[dummy_db_false, dummy_db_false, config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
outputs=[config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset, LoCon_row],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
||||
training_preset.input(
|
||||
open_configuration,
|
||||
inputs=[dummy_db_false, dummy_db_true, config_file_name] + settings_list + [training_preset],
|
||||
inputs=[dummy_db_false, dummy_db_true, config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
outputs=[gr.Textbox()] + settings_list + [training_preset, LoCon_row],
|
||||
show_progress=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -59,7 +59,8 @@ def save_configuration(
|
|||
file_path,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization, sdxl,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
|
|
@ -181,7 +182,8 @@ def open_configuration(
|
|||
file_path,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization, sdxl,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
|
|
@ -291,7 +293,8 @@ def train_model(
|
|||
print_only,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization, sdxl,
|
||||
v_parameterization,
|
||||
sdxl,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
|
|
@ -389,7 +392,7 @@ def train_model(
|
|||
msg='Image folder does not exist', headless=headless_bool
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
if not verify_image_folder_pattern(train_data_dir):
|
||||
return
|
||||
|
||||
|
|
@ -400,7 +403,7 @@ def train_model(
|
|||
headless=headless_bool,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
if not verify_image_folder_pattern(reg_data_dir):
|
||||
return
|
||||
|
||||
|
|
@ -425,9 +428,12 @@ def train_model(
|
|||
output_name, output_dir, save_model_as, headless_bool
|
||||
):
|
||||
return
|
||||
|
||||
|
||||
if sdxl:
|
||||
output_message(msg='TI training is not compatible with an SDXL model.', headless=headless_bool)
|
||||
output_message(
|
||||
msg='TI training is not compatible with an SDXL model.',
|
||||
headless=headless_bool,
|
||||
)
|
||||
return
|
||||
|
||||
# if float(noise_offset) > 0 and (
|
||||
|
|
|
|||
Loading…
Reference in New Issue