diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 4cb7e77..c22e418 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -59,7 +59,8 @@ def save_configuration( file_path, pretrained_model_name_or_path, v2, - v_parameterization, sdxl, + v_parameterization, + sdxl, logging_dir, train_data_dir, reg_data_dir, @@ -176,7 +177,8 @@ def open_configuration( file_path, pretrained_model_name_or_path, v2, - v_parameterization, sdxl, + v_parameterization, + sdxl, logging_dir, train_data_dir, reg_data_dir, @@ -281,7 +283,8 @@ def train_model( print_only, pretrained_model_name_or_path, v2, - v_parameterization, sdxl, + v_parameterization, + sdxl, logging_dir, train_data_dir, reg_data_dir, @@ -374,7 +377,7 @@ def train_model( msg='Image folder does not exist', headless=headless_bool ) return - + if not verify_image_folder_pattern(train_data_dir): return @@ -385,7 +388,7 @@ def train_model( headless=headless_bool, ) return - + if not verify_image_folder_pattern(reg_data_dir): return @@ -399,9 +402,12 @@ def train_model( output_name, output_dir, save_model_as, headless=headless_bool ): return - + if sdxl: - output_message(msg='TI training is not compatible with an SDXL model.', headless=headless_bool) + output_message( + msg='TI training is not compatible with an SDXL model.', + headless=headless_bool, + ) return if optimizer == 'Adafactor' and lr_warmup != '0': diff --git a/finetune_gui.py b/finetune_gui.py index 6eab933..2d5d24a 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -51,7 +51,8 @@ def save_configuration( file_path, pretrained_model_name_or_path, v2, - v_parameterization, sdxl, + v_parameterization, + sdxl, train_dir, image_folder, output_dir, @@ -174,7 +175,8 @@ def open_configuration( file_path, pretrained_model_name_or_path, v2, - v_parameterization, sdxl, + v_parameterization, + sdxl, train_dir, image_folder, output_dir, @@ -285,7 +287,8 @@ def train_model( print_only, pretrained_model_name_or_path, v2, - v_parameterization, sdxl, + v_parameterization, + sdxl, train_dir, image_folder, output_dir, @@ -478,7 +481,7 @@ def train_model( run_cmd += f' "./sdxl_train.py"' else: run_cmd += f' "./fine_tune.py"' - + if v2: run_cmd += ' --v2' if v_parameterization: diff --git a/kohya_gui.py b/kohya_gui.py index c2e095e..5c69951 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -30,14 +30,14 @@ def UI(**kwargs): with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: log.info('Load CSS...') css += file.read() + '\n' - + if os.path.exists('./.release'): with open(os.path.join('./.release'), 'r', encoding='utf8') as file: - release= file.read() - + release = file.read() + if os.path.exists('./README.md'): with open(os.path.join('./README.md'), 'r', encoding='utf8') as file: - README= file.read() + README = file.read() interface = gr.Blocks( css=css, title=f'Kohya_ss GUI {release}', theme=gr.themes.Default() @@ -73,18 +73,18 @@ def UI(**kwargs): gradio_merge_lora_tab(headless=headless) gradio_merge_lycoris_tab(headless=headless) gradio_resize_lora_tab(headless=headless) - with gr.Tab('About'): + with gr.Tab('About'): gr.Markdown(f'kohya_ss GUI release {release}') with gr.Tab('README'): gr.Markdown(README) - - htmlStr = f''' + + htmlStr = f"""
{release}
- ''' + """ gr.HTML(htmlStr) # Show the interface launch_kwargs = {} diff --git a/library/common_gui.py b/library/common_gui.py index 9e135ec..9dd5c8b 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -35,8 +35,13 @@ V1_MODELS = [ 'runwayml/stable-diffusion-v1-5', ] +# define a list of substrings to search for SDXL base models +SDXL_MODELS = [ + 'stabilityai/stable-diffusion-SDXL-base', +] + # define a list of substrings to search for -ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS +ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDXL_MODELS ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID'] @@ -479,7 +484,15 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name): def set_pretrained_model_name_or_path_input( model_list, pretrained_model_name_or_path, v2, v_parameterization, sdxl ): - # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list + # Check if the given model_list is in the list of SDXL models + if str(model_list) in SDXL_MODELS: + log.info('SDXL model detected. Setting parameters') + v2 = True + v_parameterization = True + sdxl = True + pretrained_model_name_or_path = str(model_list) + + # Check if the given model_list is in the list of V2 base models if str(model_list) in V2_BASE_MODELS: log.info('SD v2 model detected. Setting --v2 parameter') v2 = True @@ -487,7 +500,7 @@ def set_pretrained_model_name_or_path_input( sdxl = False pretrained_model_name_or_path = str(model_list) - # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list + # Check if the given model_list is in the list of V parameterization models if str(model_list) in V_PARAMETERIZATION_MODELS: log.info( 'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization' @@ -497,13 +510,16 @@ def set_pretrained_model_name_or_path_input( sdxl = False pretrained_model_name_or_path = str(model_list) + # Check if the given model_list is in the list of V1 models if str(model_list) in V1_MODELS: v2 = False v_parameterization = False sdxl = False pretrained_model_name_or_path = str(model_list) + # Check if the model_list is set to 'custom' if model_list == 'custom': + # Check if the pretrained_model_name_or_path is in any of the model lists if ( str(pretrained_model_name_or_path) in V1_MODELS or str(pretrained_model_name_or_path) in V2_BASE_MODELS @@ -513,25 +529,58 @@ def set_pretrained_model_name_or_path_input( v2 = False v_parameterization = False sdxl = False - return model_list, pretrained_model_name_or_path, v2, sdxl + + # Return the updated variables + return model_list, pretrained_model_name_or_path, v2, v_parameterization, sdxl -def set_v2_checkbox(model_list, v2, v_parameterization): - # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list +def set_v2_checkbox(model_list, v2, v_parameterization, sdxl): + + if str(model_list) in SDXL_MODELS: + if not v2: + log.info(f'v2 can\'t be deselected because this {str(model_list)} is considered a v2 model...') + v2 = True + if not v_parameterization: + log.info(f'v_parameterization can\'t be deselected because {str(model_list)} require v parameterization...') + v_parameterization = True + if not sdxl: + log.info(f'sdxl can\'t be deselected because {str(model_list)} is an sdxl model...') + sdxl = True + if str(model_list) in V2_BASE_MODELS: - v2 = True - v_parameterization = False + if not v2: + log.info(f'v2 can\'t be deselected because this {str(model_list)} is a v2 model...') + v2 = True + if v_parameterization: + log.info(f'v_parameterization can\'t be selected because {str(model_list)} does not support v parameterization...') + v_parameterization = False + if sdxl: + log.info(f'sdxl can\'t be selected because {str(model_list)} is not an sdxl model...') + sdxl = False - # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list if str(model_list) in V_PARAMETERIZATION_MODELS: - v2 = True - v_parameterization = True + if not v2: + log.info(f'v2 can\'t be deselected because this {str(model_list)} supports v parameterization...') + v2 = True + if not v_parameterization: + log.info(f'v_parameterization can\'t be deselected because {str(model_list)} supports v parameterization...') + v_parameterization = True + if sdxl: + log.info(f'sdxl can\'t be selected because {str(model_list)} is not an sdxl model...') + sdxl = False if str(model_list) in V1_MODELS: - v2 = False - v_parameterization = False + if v2: + log.info(f'v2 can\'t be selected because this {str(model_list)} is a v1 model...') + v2 = False + if v_parameterization: + log.info(f'v_parameterization can\'t be selected because {str(model_list)} is a v1 model...') + v_parameterization = False + if sdxl: + log.info(f'sdxl can\'t be selected because {str(model_list)} is not an sdxl model...') + sdxl = False - return v2, v_parameterization + return v2, v_parameterization, sdxl def set_model_list( @@ -665,14 +714,20 @@ def gradio_source_model( ) v2.change( set_v2_checkbox, - inputs=[model_list, v2, v_parameterization], - outputs=[v2, v_parameterization], + inputs=[model_list, v2, v_parameterization, sdxl], + outputs=[v2, v_parameterization, sdxl], show_progress=False, ) v_parameterization.change( set_v2_checkbox, - inputs=[model_list, v2, v_parameterization], - outputs=[v2, v_parameterization], + inputs=[model_list, v2, v_parameterization, sdxl], + outputs=[v2, v_parameterization, sdxl], + show_progress=False, + ) + sdxl.change( + set_v2_checkbox, + inputs=[model_list, v2, v_parameterization, sdxl], + outputs=[v2, v_parameterization, sdxl], show_progress=False, ) model_list.change( diff --git a/lora_gui.py b/lora_gui.py index 4c80075..3563211 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -69,7 +69,8 @@ def save_configuration( file_path, pretrained_model_name_or_path, v2, - v_parameterization, sdxl, + v_parameterization, + sdxl, logging_dir, train_data_dir, reg_data_dir, @@ -215,7 +216,8 @@ def open_configuration( file_path, pretrained_model_name_or_path, v2, - v_parameterization, sdxl, + v_parameterization, + sdxl, logging_dir, train_data_dir, reg_data_dir, @@ -313,14 +315,14 @@ def open_configuration( network_dropout, rank_dropout, module_dropout, - training_preset + training_preset, ): # Get list of function parameters and values parameters = list(locals().items()) ask_for_file = True if ask_for_file.get('label') == 'True' else False apply_preset = True if apply_preset.get('label') == 'True' else False - + # Check if we are "applying" a preset or a config if apply_preset: log.info(f'Applying preset {training_preset}...') @@ -328,10 +330,12 @@ def open_configuration( else: # If not applying a preset, set the `training_preset` field to an empty string # Find the index of the `training_preset` parameter using the `index()` method - training_preset_index = parameters.index(("training_preset", training_preset)) + training_preset_index = parameters.index( + ('training_preset', training_preset) + ) # Update the value of `training_preset` by directly assigning an empty string value - parameters[training_preset_index] = ("training_preset", "") + parameters[training_preset_index] = ('training_preset', '') original_file_path = file_path @@ -376,7 +380,8 @@ def train_model( print_only, pretrained_model_name_or_path, v2, - v_parameterization, sdxl, + v_parameterization, + sdxl, logging_dir, train_data_dir, reg_data_dir, @@ -496,7 +501,7 @@ def train_model( msg='Image folder does not exist', headless=headless_bool ) return - + if not verify_image_folder_pattern(train_data_dir): return @@ -507,7 +512,7 @@ def train_model( headless=headless_bool, ) return - + if not verify_image_folder_pattern(reg_data_dir): return @@ -1095,15 +1100,18 @@ def lora_tab( outputs=[logging_dir], ) with gr.Tab('Training parameters'): + def list_presets(path): json_files = [] for file in os.listdir(path): - if file.endswith(".json"): + if file.endswith('.json'): json_files.append(os.path.splitext(file)[0]) return json_files + training_preset = gr.Dropdown( label='Presets', - choices=list_presets('./presets/lora'), elem_id="myDropdown" + choices=list_presets('./presets/lora'), + elem_id='myDropdown', ) with gr.Row(): LoRA_type = gr.Dropdown( @@ -1695,21 +1703,31 @@ def lora_tab( button_open_config.click( open_configuration, - inputs=[dummy_db_true, dummy_db_false, config_file_name] + settings_list + [training_preset], - outputs=[config_file_name] + settings_list + [training_preset, LoCon_row], + inputs=[dummy_db_true, dummy_db_false, config_file_name] + + settings_list + + [training_preset], + outputs=[config_file_name] + + settings_list + + [training_preset, LoCon_row], show_progress=False, ) button_load_config.click( open_configuration, - inputs=[dummy_db_false, dummy_db_false, config_file_name] + settings_list + [training_preset], - outputs=[config_file_name] + settings_list + [training_preset, LoCon_row], + inputs=[dummy_db_false, dummy_db_false, config_file_name] + + settings_list + + [training_preset], + outputs=[config_file_name] + + settings_list + + [training_preset, LoCon_row], show_progress=False, ) - + training_preset.input( open_configuration, - inputs=[dummy_db_false, dummy_db_true, config_file_name] + settings_list + [training_preset], + inputs=[dummy_db_false, dummy_db_true, config_file_name] + + settings_list + + [training_preset], outputs=[gr.Textbox()] + settings_list + [training_preset, LoCon_row], show_progress=False, ) diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 13622df..5e0e89c 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -59,7 +59,8 @@ def save_configuration( file_path, pretrained_model_name_or_path, v2, - v_parameterization, sdxl, + v_parameterization, + sdxl, logging_dir, train_data_dir, reg_data_dir, @@ -181,7 +182,8 @@ def open_configuration( file_path, pretrained_model_name_or_path, v2, - v_parameterization, sdxl, + v_parameterization, + sdxl, logging_dir, train_data_dir, reg_data_dir, @@ -291,7 +293,8 @@ def train_model( print_only, pretrained_model_name_or_path, v2, - v_parameterization, sdxl, + v_parameterization, + sdxl, logging_dir, train_data_dir, reg_data_dir, @@ -389,7 +392,7 @@ def train_model( msg='Image folder does not exist', headless=headless_bool ) return - + if not verify_image_folder_pattern(train_data_dir): return @@ -400,7 +403,7 @@ def train_model( headless=headless_bool, ) return - + if not verify_image_folder_pattern(reg_data_dir): return @@ -425,9 +428,12 @@ def train_model( output_name, output_dir, save_model_as, headless_bool ): return - + if sdxl: - output_message(msg='TI training is not compatible with an SDXL model.', headless=headless_bool) + output_message( + msg='TI training is not compatible with an SDXL model.', + headless=headless_bool, + ) return # if float(noise_offset) > 0 and (