diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 02661b5..6d5664d 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -551,14 +551,15 @@ def train_model( # run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"' run_cmd = "accelerate launch" - + run_cmd += run_cmd_advanced_training( num_processes=num_processes, num_machines=num_machines, multi_gpu=multi_gpu, gpu_ids=gpu_ids, - num_cpu_threads_per_process=num_cpu_threads_per_process) - + num_cpu_threads_per_process=num_cpu_threads_per_process, + ) + if sdxl: run_cmd += f' "./sdxl_train.py"' else: diff --git a/finetune_gui.py b/finetune_gui.py index 9ccd853..73e4b23 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -43,12 +43,12 @@ executor = CommandExecutor() # from easygui import msgbox -folder_symbol = '\U0001f4c2' # 📂 -refresh_symbol = '\U0001f504' # 🔄 -save_style_symbol = '\U0001f4be' # 💾 -document_symbol = '\U0001F4C4' # 📄 +folder_symbol = "\U0001f4c2" # 📂 +refresh_symbol = "\U0001f504" # 🔄 +save_style_symbol = "\U0001f4be" # 💾 +document_symbol = "\U0001F4C4" # 📄 -PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' +PYTHON = "python3" if os.name == "posix" else "./venv/Scripts/python.exe" def save_configuration( @@ -100,7 +100,8 @@ def save_configuration( save_state, resume, gradient_checkpointing, - gradient_accumulation_steps,block_lr, + gradient_accumulation_steps, + block_lr, mem_eff_attn, shuffle_caption, output_name, @@ -153,19 +154,19 @@ def save_configuration( original_file_path = file_path - save_as_bool = True if save_as.get('label') == 'True' else False + save_as_bool = True if save_as.get("label") == "True" else False if save_as_bool: - log.info('Save as...') + log.info("Save as...") file_path = get_saveasfile_path(file_path) else: - log.info('Save...') - if file_path == None or file_path == '': + log.info("Save...") + if file_path == None or file_path == "": file_path = get_saveasfile_path(file_path) # log.info(file_path) - if file_path == None or file_path == '': + if file_path == None or file_path == "": return original_file_path # In case a file_path was provided and the user decide to cancel the open action # Extract the destination directory from the file path @@ -178,7 +179,7 @@ def save_configuration( SaveConfigFile( parameters=parameters, file_path=file_path, - exclusion=['file_path', 'save_as'], + exclusion=["file_path", "save_as"], ) return file_path @@ -234,7 +235,8 @@ def open_configuration( save_state, resume, gradient_checkpointing, - gradient_accumulation_steps,block_lr, + gradient_accumulation_steps, + block_lr, mem_eff_attn, shuffle_caption, output_name, @@ -286,33 +288,31 @@ def open_configuration( # Get list of function parameters and values parameters = list(locals().items()) - ask_for_file = True if ask_for_file.get('label') == 'True' else False - apply_preset = True if apply_preset.get('label') == 'True' else False + ask_for_file = True if ask_for_file.get("label") == "True" else False + apply_preset = True if apply_preset.get("label") == "True" else False # Check if we are "applying" a preset or a config if apply_preset: - log.info(f'Applying preset {training_preset}...') - file_path = f'./presets/finetune/{training_preset}.json' + log.info(f"Applying preset {training_preset}...") + file_path = f"./presets/finetune/{training_preset}.json" else: # If not applying a preset, set the `training_preset` field to an empty string # Find the index of the `training_preset` parameter using the `index()` method - training_preset_index = parameters.index( - ('training_preset', training_preset) - ) + training_preset_index = parameters.index(("training_preset", training_preset)) # Update the value of `training_preset` by directly assigning an empty string value - parameters[training_preset_index] = ('training_preset', '') + parameters[training_preset_index] = ("training_preset", "") original_file_path = file_path if ask_for_file: file_path = get_file_path(file_path) - if not file_path == '' and not file_path == None: + if not file_path == "" and not file_path == None: # load variables from JSON file - with open(file_path, 'r') as f: + with open(file_path, "r") as f: my_data = json.load(f) - log.info('Loading config...') + log.info("Loading config...") # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True my_data = update_my_data(my_data) else: @@ -323,7 +323,7 @@ def open_configuration( for key, value in parameters: json_value = my_data.get(key) # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['ask_for_file', 'apply_preset', 'file_path']: + if not key in ["ask_for_file", "apply_preset", "file_path"]: values.append(json_value if json_value is not None else value) return tuple(values) @@ -377,7 +377,8 @@ def train_model( save_state, resume, gradient_checkpointing, - gradient_accumulation_steps,block_lr, + gradient_accumulation_steps, + block_lr, mem_eff_attn, shuffle_caption, output_name, @@ -428,14 +429,12 @@ def train_model( # Get list of function parameters and values parameters = list(locals().items()) - print_only_bool = True if print_only.get('label') == 'True' else False - log.info(f'Start Finetuning...') + print_only_bool = True if print_only.get("label") == "True" else False + log.info(f"Start Finetuning...") - headless_bool = True if headless.get('label') == 'True' else False + headless_bool = True if headless.get("label") == "True" else False - if check_if_model_exist( - output_name, output_dir, save_model_as, headless_bool - ): + if check_if_model_exist(output_name, output_dir, save_model_as, headless_bool): return # if float(noise_offset) > 0 and ( @@ -461,52 +460,50 @@ def train_model( if not os.path.exists(train_dir): os.mkdir(train_dir) - run_cmd = f'{PYTHON} finetune/merge_captions_to_metadata.py' - if caption_extension == '': + run_cmd = f"{PYTHON} finetune/merge_captions_to_metadata.py" + if caption_extension == "": run_cmd += f' --caption_extension=".caption"' else: - run_cmd += f' --caption_extension={caption_extension}' + run_cmd += f" --caption_extension={caption_extension}" run_cmd += f' "{image_folder}"' run_cmd += f' "{train_dir}/{caption_metadata_filename}"' if full_path: - run_cmd += f' --full_path' + run_cmd += f" --full_path" log.info(run_cmd) if not print_only_bool: # Run the command - if os.name == 'posix': + if os.name == "posix": os.system(run_cmd) else: subprocess.run(run_cmd) # create images buckets if generate_image_buckets: - run_cmd = f'{PYTHON} finetune/prepare_buckets_latents.py' + run_cmd = f"{PYTHON} finetune/prepare_buckets_latents.py" run_cmd += f' "{image_folder}"' run_cmd += f' "{train_dir}/{caption_metadata_filename}"' run_cmd += f' "{train_dir}/{latent_metadata_filename}"' run_cmd += f' "{pretrained_model_name_or_path}"' - run_cmd += f' --batch_size={batch_size}' - run_cmd += f' --max_resolution={max_resolution}' - run_cmd += f' --min_bucket_reso={min_bucket_reso}' - run_cmd += f' --max_bucket_reso={max_bucket_reso}' - run_cmd += f' --mixed_precision={mixed_precision}' + run_cmd += f" --batch_size={batch_size}" + run_cmd += f" --max_resolution={max_resolution}" + run_cmd += f" --min_bucket_reso={min_bucket_reso}" + run_cmd += f" --max_bucket_reso={max_bucket_reso}" + run_cmd += f" --mixed_precision={mixed_precision}" # if flip_aug: # run_cmd += f' --flip_aug' if full_path: - run_cmd += f' --full_path' + run_cmd += f" --full_path" if sdxl_checkbox and sdxl_no_half_vae: - log.info( - 'Using mixed_precision = no because no half vae is selected...' - ) + log.info("Using mixed_precision = no because no half vae is selected...") run_cmd += f' --mixed_precision="no"' log.info(run_cmd) if not print_only_bool: # Run the command - if os.name == 'posix': + if os.name == "posix": os.system(run_cmd) else: subprocess.run(run_cmd) @@ -517,13 +514,13 @@ def train_model( for f, lower_f in ( (file, file.lower()) for file in os.listdir(image_folder) ) - if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) + if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp")) ] ) - log.info(f'image_num = {image_num}') + log.info(f"image_num = {image_num}") repeats = int(image_num) * int(dataset_repeats) - log.info(f'repeats = {str(repeats)}') + log.info(f"repeats = {str(repeats)}") # calculate max_train_steps max_train_steps = int( @@ -539,26 +536,31 @@ def train_model( if flip_aug: max_train_steps = int(math.ceil(float(max_train_steps) / 2)) - log.info(f'max_train_steps = {max_train_steps}') + log.info(f"max_train_steps = {max_train_steps}") lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - log.info(f'lr_warmup_steps = {lr_warmup_steps}') + log.info(f"lr_warmup_steps = {lr_warmup_steps}") run_cmd = "accelerate launch" - + run_cmd += run_cmd_advanced_training( num_processes=num_processes, num_machines=num_machines, multi_gpu=multi_gpu, gpu_ids=gpu_ids, - num_cpu_threads_per_process=num_cpu_threads_per_process) - + num_cpu_threads_per_process=num_cpu_threads_per_process, + ) + if sdxl_checkbox: run_cmd += f' "./sdxl_train.py"' else: run_cmd += f' "./fine_tune.py"' - - in_json = f'{train_dir}/{latent_metadata_filename}' if use_latent_files == 'Yes' else f'{train_dir}/{caption_metadata_filename}' + + in_json = ( + f"{train_dir}/{latent_metadata_filename}" + if use_latent_files == "Yes" + else f"{train_dir}/{caption_metadata_filename}" + ) cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs no_half_vae = sdxl_checkbox and sdxl_no_half_vae @@ -570,7 +572,9 @@ def train_model( bucket_reso_steps=bucket_reso_steps, cache_latents=cache_latents, cache_latents_to_disk=cache_latents_to_disk, - cache_text_encoder_outputs=cache_text_encoder_outputs if sdxl_checkbox else None, + cache_text_encoder_outputs=cache_text_encoder_outputs + if sdxl_checkbox + else None, caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, caption_dropout_rate=caption_dropout_rate, caption_extension=caption_extension, @@ -651,7 +655,7 @@ def train_model( if print_only_bool: log.warning( - 'Here is the trainer command as a reference. It will not be executed:\n' + "Here is the trainer command as a reference. It will not be executed:\n" ) print(run_cmd) @@ -659,17 +663,15 @@ def train_model( else: # Saving config file for model current_datetime = datetime.now() - formatted_datetime = current_datetime.strftime('%Y%m%d-%H%M%S') - file_path = os.path.join( - output_dir, f'{output_name}_{formatted_datetime}.json' - ) + formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") + file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") - log.info(f'Saving training config to {file_path}...') + log.info(f"Saving training config to {file_path}...") SaveConfigFile( parameters=parameters, file_path=file_path, - exclusion=['file_path', 'save_as', 'headless', 'print_only'], + exclusion=["file_path", "save_as", "headless", "print_only"], ) log.info(run_cmd) @@ -678,18 +680,16 @@ def train_model( executor.execute_command(run_cmd=run_cmd) # check if output_dir/last is a folder... therefore it is a diffuser model - last_dir = pathlib.Path(f'{output_dir}/{output_name}') + last_dir = pathlib.Path(f"{output_dir}/{output_name}") if not last_dir.is_dir(): # Copy inference model for v2 if required - save_inference_file( - output_dir, v2, v_parameterization, output_name - ) + save_inference_file(output_dir, v2, v_parameterization, output_name) def remove_doublequote(file_path): if file_path != None: - file_path = file_path.replace('"', '') + file_path = file_path.replace('"', "") return file_path @@ -698,23 +698,23 @@ def finetune_tab(headless=False): dummy_db_true = gr.Label(value=True, visible=False) dummy_db_false = gr.Label(value=False, visible=False) dummy_headless = gr.Label(value=headless, visible=False) - with gr.Tab('Training'): - gr.Markdown('Train a custom model using kohya finetune python code...') + with gr.Tab("Training"): + gr.Markdown("Train a custom model using kohya finetune python code...") # Setup Configuration Files Gradio config = ConfigurationFile(headless) source_model = SourceModel(headless=headless) - with gr.Tab('Folders'): + with gr.Tab("Folders"): with gr.Row(): train_dir = gr.Textbox( - label='Training config folder', - placeholder='folder where the training configuration files will be saved', + label="Training config folder", + placeholder="folder where the training configuration files will be saved", ) train_dir_folder = gr.Button( folder_symbol, - elem_id='open_folder_small', + elem_id="open_folder_small", visible=(not headless), ) train_dir_folder.click( @@ -724,12 +724,12 @@ def finetune_tab(headless=False): ) image_folder = gr.Textbox( - label='Training Image folder', - placeholder='folder where the training images are located', + label="Training Image folder", + placeholder="folder where the training images are located", ) image_folder_input_folder = gr.Button( folder_symbol, - elem_id='open_folder_small', + elem_id="open_folder_small", visible=(not headless), ) image_folder_input_folder.click( @@ -739,12 +739,12 @@ def finetune_tab(headless=False): ) with gr.Row(): output_dir = gr.Textbox( - label='Model output folder', - placeholder='folder where the model will be saved', + label="Model output folder", + placeholder="folder where the model will be saved", ) output_dir_input_folder = gr.Button( folder_symbol, - elem_id='open_folder_small', + elem_id="open_folder_small", visible=(not headless), ) output_dir_input_folder.click( @@ -754,12 +754,12 @@ def finetune_tab(headless=False): ) logging_dir = gr.Textbox( - label='Logging folder', - placeholder='Optional: enable logging and output TensorBoard log to this folder', + label="Logging folder", + placeholder="Optional: enable logging and output TensorBoard log to this folder", ) logging_dir_input_folder = gr.Button( folder_symbol, - elem_id='open_folder_small', + elem_id="open_folder_small", visible=(not headless), ) logging_dir_input_folder.click( @@ -769,9 +769,9 @@ def finetune_tab(headless=False): ) with gr.Row(): output_name = gr.Textbox( - label='Model output name', - placeholder='Name of the model to output', - value='last', + label="Model output name", + placeholder="Name of the model to output", + value="last", interactive=True, ) train_dir.change( @@ -789,102 +789,96 @@ def finetune_tab(headless=False): inputs=[output_dir], outputs=[output_dir], ) - with gr.Tab('Dataset preparation'): + with gr.Tab("Dataset preparation"): with gr.Row(): max_resolution = gr.Textbox( - label='Resolution (width,height)', value='512,512' - ) - min_bucket_reso = gr.Textbox( - label='Min bucket resolution', value='256' + label="Resolution (width,height)", value="512,512" ) + min_bucket_reso = gr.Textbox(label="Min bucket resolution", value="256") max_bucket_reso = gr.Textbox( - label='Max bucket resolution', value='1024' + label="Max bucket resolution", value="1024" ) - batch_size = gr.Textbox(label='Batch size', value='1') + batch_size = gr.Textbox(label="Batch size", value="1") with gr.Row(): create_caption = gr.Checkbox( - label='Generate caption metadata', value=True + label="Generate caption metadata", value=True ) create_buckets = gr.Checkbox( - label='Generate image buckets metadata', value=True + label="Generate image buckets metadata", value=True ) use_latent_files = gr.Dropdown( - label='Use latent files', + label="Use latent files", choices=[ - 'No', - 'Yes', + "No", + "Yes", ], - value='Yes', + value="Yes", ) - with gr.Accordion('Advanced parameters', open=False): + with gr.Accordion("Advanced parameters", open=False): with gr.Row(): caption_metadata_filename = gr.Textbox( - label='Caption metadata filename', - value='meta_cap.json', + label="Caption metadata filename", + value="meta_cap.json", ) latent_metadata_filename = gr.Textbox( - label='Latent metadata filename', value='meta_lat.json' + label="Latent metadata filename", value="meta_lat.json" ) with gr.Row(): - full_path = gr.Checkbox(label='Use full path', value=True) + full_path = gr.Checkbox(label="Use full path", value=True) weighted_captions = gr.Checkbox( - label='Weighted captions', value=False + label="Weighted captions", value=False ) - with gr.Tab('Parameters'): - + with gr.Tab("Parameters"): + def list_presets(path): json_files = [] for file in os.listdir(path): - if file.endswith('.json'): + if file.endswith(".json"): json_files.append(os.path.splitext(file)[0]) - user_presets_path = os.path.join(path, 'user_presets') + user_presets_path = os.path.join(path, "user_presets") if os.path.isdir(user_presets_path): for file in os.listdir(user_presets_path): - if file.endswith('.json'): + if file.endswith(".json"): preset_name = os.path.splitext(file)[0] - json_files.append( - os.path.join('user_presets', preset_name) - ) + json_files.append(os.path.join("user_presets", preset_name)) return json_files training_preset = gr.Dropdown( - label='Presets', - choices=list_presets('./presets/finetune'), - elem_id='myDropdown', + label="Presets", + choices=list_presets("./presets/finetune"), + elem_id="myDropdown", ) - - with gr.Tab('Basic', elem_id='basic_tab'): + + with gr.Tab("Basic", elem_id="basic_tab"): basic_training = BasicTraining( - learning_rate_value='1e-5', finetuning=True, sdxl_checkbox=source_model.sdxl_checkbox, + learning_rate_value="1e-5", + finetuning=True, + sdxl_checkbox=source_model.sdxl_checkbox, ) # Add SDXL Parameters sdxl_params = SDXLParameters(source_model.sdxl_checkbox) with gr.Row(): - dataset_repeats = gr.Textbox( - label='Dataset repeats', value=40 - ) + dataset_repeats = gr.Textbox(label="Dataset repeats", value=40) train_text_encoder = gr.Checkbox( - label='Train text encoder', value=True + label="Train text encoder", value=True ) - with gr.Tab('Advanced', elem_id='advanced_tab'): + with gr.Tab("Advanced", elem_id="advanced_tab"): with gr.Row(): gradient_accumulation_steps = gr.Number( - label='Gradient accumulate steps', value='1' + label="Gradient accumulate steps", value="1" ) block_lr = gr.Textbox( - label='Block LR', - placeholder='(Optional)', - info='Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3', + label="Block LR", + placeholder="(Optional)", + info="Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3", ) - advanced_training = AdvancedTraining( - headless=headless, finetuning=True - ) + advanced_training = AdvancedTraining(headless=headless, finetuning=True) advanced_training.color_aug.change( color_aug_changed, inputs=[advanced_training.color_aug], @@ -893,15 +887,15 @@ def finetune_tab(headless=False): ], # Not applicable to fine_tune.py ) - with gr.Tab('Samples', elem_id='samples_tab'): + with gr.Tab("Samples", elem_id="samples_tab"): sample = SampleImages() with gr.Row(): - button_run = gr.Button('Start training', variant='primary') + button_run = gr.Button("Start training", variant="primary") - button_stop_training = gr.Button('Stop training') + button_stop_training = gr.Button("Stop training") - button_print = gr.Button('Print training command') + button_print = gr.Button("Print training command") # Setup gradio tensorboard buttons ( @@ -1020,9 +1014,7 @@ def finetune_tab(headless=False): inputs=[dummy_db_true, dummy_db_false, config.config_file_name] + settings_list + [training_preset], - outputs=[config.config_file_name] - + settings_list - + [training_preset], + outputs=[config.config_file_name] + settings_list + [training_preset], show_progress=False, ) @@ -1038,9 +1030,7 @@ def finetune_tab(headless=False): inputs=[dummy_db_false, dummy_db_false, config.config_file_name] + settings_list + [training_preset], - outputs=[config.config_file_name] - + settings_list - + [training_preset], + outputs=[config.config_file_name] + settings_list + [training_preset], show_progress=False, ) @@ -1056,9 +1046,7 @@ def finetune_tab(headless=False): inputs=[dummy_db_false, dummy_db_true, config.config_file_name] + settings_list + [training_preset], - outputs=[gr.Textbox()] - + settings_list - + [training_preset], + outputs=[gr.Textbox()] + settings_list + [training_preset], show_progress=False, ) @@ -1090,94 +1078,84 @@ def finetune_tab(headless=False): show_progress=False, ) - with gr.Tab('Guides'): - gr.Markdown( - 'This section provide Various Finetuning guides and information...' - ) - top_level_path = './docs/Finetuning/top_level.md' + with gr.Tab("Guides"): + gr.Markdown("This section provide Various Finetuning guides and information...") + top_level_path = "./docs/Finetuning/top_level.md" if os.path.exists(top_level_path): - with open( - os.path.join(top_level_path), 'r', encoding='utf8' - ) as file: - guides_top_level = file.read() + '\n' + with open(os.path.join(top_level_path), "r", encoding="utf8") as file: + guides_top_level = file.read() + "\n" gr.Markdown(guides_top_level) def UI(**kwargs): - add_javascript(kwargs.get('language')) - css = '' + add_javascript(kwargs.get("language")) + css = "" - headless = kwargs.get('headless', False) - log.info(f'headless: {headless}') + headless = kwargs.get("headless", False) + log.info(f"headless: {headless}") - if os.path.exists('./style.css'): - with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: - log.info('Load CSS...') - css += file.read() + '\n' + if os.path.exists("./style.css"): + with open(os.path.join("./style.css"), "r", encoding="utf8") as file: + log.info("Load CSS...") + css += file.read() + "\n" - interface = gr.Blocks( - css=css, title='Kohya_ss GUI', theme=gr.themes.Default() - ) + interface = gr.Blocks(css=css, title="Kohya_ss GUI", theme=gr.themes.Default()) with interface: - with gr.Tab('Finetune'): + with gr.Tab("Finetune"): finetune_tab(headless=headless) - with gr.Tab('Utilities'): + with gr.Tab("Utilities"): utilities_tab(enable_dreambooth_tab=False, headless=headless) # Show the interface launch_kwargs = {} - username = kwargs.get('username') - password = kwargs.get('password') - server_port = kwargs.get('server_port', 0) - inbrowser = kwargs.get('inbrowser', False) - share = kwargs.get('share', False) - server_name = kwargs.get('listen') + username = kwargs.get("username") + password = kwargs.get("password") + server_port = kwargs.get("server_port", 0) + inbrowser = kwargs.get("inbrowser", False) + share = kwargs.get("share", False) + server_name = kwargs.get("listen") - launch_kwargs['server_name'] = server_name + launch_kwargs["server_name"] = server_name if username and password: - launch_kwargs['auth'] = (username, password) + launch_kwargs["auth"] = (username, password) if server_port > 0: - launch_kwargs['server_port'] = server_port + launch_kwargs["server_port"] = server_port if inbrowser: - launch_kwargs['inbrowser'] = inbrowser + launch_kwargs["inbrowser"] = inbrowser if share: - launch_kwargs['share'] = share + launch_kwargs["share"] = share interface.launch(**launch_kwargs) -if __name__ == '__main__': +if __name__ == "__main__": # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() parser.add_argument( - '--listen', + "--listen", type=str, - default='127.0.0.1', - help='IP to listen on for connections to Gradio', + default="127.0.0.1", + help="IP to listen on for connections to Gradio", ) parser.add_argument( - '--username', type=str, default='', help='Username for authentication' + "--username", type=str, default="", help="Username for authentication" ) parser.add_argument( - '--password', type=str, default='', help='Password for authentication' + "--password", type=str, default="", help="Password for authentication" ) parser.add_argument( - '--server_port', + "--server_port", type=int, default=0, - help='Port to run the server listener on', + help="Port to run the server listener on", + ) + parser.add_argument("--inbrowser", action="store_true", help="Open in browser") + parser.add_argument("--share", action="store_true", help="Share the gradio UI") + parser.add_argument( + "--headless", action="store_true", help="Is the server headless" ) parser.add_argument( - '--inbrowser', action='store_true', help='Open in browser' - ) - parser.add_argument( - '--share', action='store_true', help='Share the gradio UI' - ) - parser.add_argument( - '--headless', action='store_true', help='Is the server headless' - ) - parser.add_argument( - '--language', type=str, default=None, help='Set custom language' + "--language", type=str, default=None, help="Set custom language" ) args = parser.parse_args() diff --git a/library/common_gui.py b/library/common_gui.py index aa8daba..a0a0c57 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -678,11 +678,11 @@ def get_str_or_default(kwargs, key, default_value=""): def run_cmd_advanced_training(**kwargs): run_cmd = "" - + additional_parameters = kwargs.get("additional_parameters") if additional_parameters: run_cmd += f" {additional_parameters}" - + block_lr = kwargs.get("block_lr") if block_lr: run_cmd += f' --block_lr="(block_lr)"' @@ -702,7 +702,7 @@ def run_cmd_advanced_training(**kwargs): cache_latents_to_disk = kwargs.get("cache_latents_to_disk") if cache_latents_to_disk: run_cmd += " --cache_latents_to_disk" - + cache_text_encoder_outputs = kwargs.get("cache_text_encoder_outputs") if cache_text_encoder_outputs: run_cmd += " --cache_text_encoder_outputs" @@ -728,18 +728,18 @@ def run_cmd_advanced_training(**kwargs): color_aug = kwargs.get("color_aug") if color_aug: run_cmd += " --color_aug" - + dataset_repeats = kwargs.get("dataset_repeats") if dataset_repeats: run_cmd += f' --dataset_repeats="{dataset_repeats}"' - + enable_bucket = kwargs.get("enable_bucket") if enable_bucket: min_bucket_reso = kwargs.get("min_bucket_reso") max_bucket_reso = kwargs.get("max_bucket_reso") if min_bucket_reso and max_bucket_reso: run_cmd += f" --enable_bucket --min_bucket_reso={min_bucket_reso} --max_bucket_reso={max_bucket_reso}" - + in_json = kwargs.get("in_json") if in_json: run_cmd += f' --in_json="{in_json}"' @@ -751,7 +751,7 @@ def run_cmd_advanced_training(**kwargs): fp8_base = kwargs.get("fp8_base") if fp8_base: run_cmd += " --fp8_base" - + full_bf16 = kwargs.get("full_bf16") if full_bf16: run_cmd += " --full_bf16" @@ -759,7 +759,7 @@ def run_cmd_advanced_training(**kwargs): full_fp16 = kwargs.get("full_fp16") if full_fp16: run_cmd += " --full_fp16" - + gradient_accumulation_steps = kwargs.get("gradient_accumulation_steps") if gradient_accumulation_steps and int(gradient_accumulation_steps) > 1: run_cmd += f" --gradient_accumulation_steps={int(gradient_accumulation_steps)}" @@ -775,19 +775,19 @@ def run_cmd_advanced_training(**kwargs): learning_rate = kwargs.get("learning_rate") if learning_rate: run_cmd += f' --learning_rate="{learning_rate}"' - + learning_rate_te = kwargs.get("learning_rate_te") if learning_rate_te: run_cmd += f' --learning_rate_te="{learning_rate_te}"' - + learning_rate_te1 = kwargs.get("learning_rate_te1") if learning_rate_te1: run_cmd += f' --learning_rate_te1="{learning_rate_te1}"' - + learning_rate_te2 = kwargs.get("learning_rate_te2") if learning_rate_te2: run_cmd += f' --learning_rate_te2="{learning_rate_te2}"' - + logging_dir = kwargs.get("logging_dir") if logging_dir: run_cmd += f' --logging_dir="{logging_dir}"' @@ -799,7 +799,7 @@ def run_cmd_advanced_training(**kwargs): lr_scheduler_args = kwargs.get("lr_scheduler_args") if lr_scheduler_args and lr_scheduler_args != "": run_cmd += f" --lr_scheduler_args {lr_scheduler_args}" - + lr_scheduler_num_cycles = kwargs.get("lr_scheduler_num_cycles") if lr_scheduler_num_cycles and not lr_scheduler_num_cycles == "": run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"' @@ -807,7 +807,7 @@ def run_cmd_advanced_training(**kwargs): epoch = kwargs.get("epoch") if epoch: run_cmd += f' --lr_scheduler_num_cycles="{epoch}"' - + lr_scheduler_power = kwargs.get("lr_scheduler_power") if lr_scheduler_power and not lr_scheduler_power == "": run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"' @@ -830,7 +830,7 @@ def run_cmd_advanced_training(**kwargs): max_grad_norm = kwargs.get("max_grad_norm") if max_grad_norm and max_grad_norm != "": run_cmd += f' --max_grad_norm="{max_grad_norm}"' - + max_resolution = kwargs.get("max_resolution") if max_resolution: run_cmd += f' --resolution="{max_resolution}"' @@ -844,7 +844,7 @@ def run_cmd_advanced_training(**kwargs): run_cmd += f" --max_token_length={int(max_token_length)}" max_train_epochs = kwargs.get("max_train_epochs") - if max_train_epochs and not max_train_epochs == '': + if max_train_epochs and not max_train_epochs == "": run_cmd += f" --max_train_epochs={max_train_epochs}" max_train_steps = kwargs.get("max_train_steps") @@ -870,15 +870,15 @@ def run_cmd_advanced_training(**kwargs): multi_gpu = kwargs.get("multi_gpu") if multi_gpu: run_cmd += " --multi_gpu" - + no_half_vae = kwargs.get("no_half_vae") if no_half_vae: run_cmd += " --no_half_vae" - + no_token_padding = kwargs.get("no_token_padding") if no_token_padding: run_cmd += " --no_token_padding" - + noise_offset_type = kwargs.get("noise_offset_type") if noise_offset_type and noise_offset_type == "Original": noise_offset = kwargs.get("noise_offset") @@ -886,17 +886,23 @@ def run_cmd_advanced_training(**kwargs): run_cmd += f" --noise_offset={float(noise_offset)}" adaptive_noise_scale = kwargs.get("adaptive_noise_scale") - if adaptive_noise_scale and float(adaptive_noise_scale) != 0 and float(noise_offset) > 0: + if ( + adaptive_noise_scale + and float(adaptive_noise_scale) != 0 + and float(noise_offset) > 0 + ): run_cmd += f" --adaptive_noise_scale={float(adaptive_noise_scale)}" elif noise_offset_type and noise_offset_type == "Multires": multires_noise_iterations = kwargs.get("multires_noise_iterations") if int(multires_noise_iterations) > 0: - run_cmd += f' --multires_noise_iterations="{int(multires_noise_iterations)}"' + run_cmd += ( + f' --multires_noise_iterations="{int(multires_noise_iterations)}"' + ) multires_noise_discount = kwargs.get("multires_noise_discount") if multires_noise_discount and float(multires_noise_discount) > 0: run_cmd += f' --multires_noise_discount="{float(multires_noise_discount)}"' - + num_machines = kwargs.get("num_machines") if num_machines and int(num_machines) > 1: run_cmd += f" --num_machines={int(num_machines)}" @@ -916,12 +922,12 @@ def run_cmd_advanced_training(**kwargs): optimizer_type = kwargs.get("optimizer") if optimizer_type: run_cmd += f' --optimizer_type="{optimizer_type}"' - + output_dir = kwargs.get("output_dir") if output_dir: run_cmd += f' --output_dir="{output_dir}"' - - output_name = kwargs.get("output_name") + + output_name = kwargs.get("output_name") if output_name and not output_name == "": run_cmd += f' --output_name="{output_name}"' @@ -932,7 +938,7 @@ def run_cmd_advanced_training(**kwargs): pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path") if pretrained_model_name_or_path: run_cmd += f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' - + prior_loss_weight = kwargs.get("prior_loss_weight") if prior_loss_weight and not float(prior_loss_weight) == 1.0: run_cmd += f" --prior_loss_weight={prior_loss_weight}" @@ -964,7 +970,7 @@ def run_cmd_advanced_training(**kwargs): save_last_n_steps_state = kwargs.get("save_last_n_steps_state") if save_last_n_steps_state and int(save_last_n_steps_state) > 0: run_cmd += f' --save_last_n_steps_state="{int(save_last_n_steps_state)}"' - + save_model_as = kwargs.get("save_model_as") if save_model_as and not save_model_as == "same as source model": run_cmd += f" --save_model_as={save_model_as}" @@ -988,7 +994,7 @@ def run_cmd_advanced_training(**kwargs): shuffle_caption = kwargs.get("shuffle_caption") if shuffle_caption: run_cmd += " --shuffle_caption" - + stop_text_encoder_training = kwargs.get("stop_text_encoder_training") if stop_text_encoder_training and stop_text_encoder_training > 0: run_cmd += f' --stop_text_encoder_training="{stop_text_encoder_training}"' @@ -996,11 +1002,11 @@ def run_cmd_advanced_training(**kwargs): train_batch_size = kwargs.get("train_batch_size") if train_batch_size: run_cmd += f' --train_batch_size="{train_batch_size}"' - + train_data_dir = kwargs.get("train_data_dir") if train_data_dir: run_cmd += f' --train_data_dir="{train_data_dir}"' - + train_text_encoder = kwargs.get("train_text_encoder") if train_text_encoder: run_cmd += " --train_text_encoder" @@ -1008,7 +1014,7 @@ def run_cmd_advanced_training(**kwargs): use_wandb = kwargs.get("use_wandb") if use_wandb: run_cmd += " --log_with wandb" - + v_parameterization = kwargs.get("v_parameterization") if v_parameterization: run_cmd += " --v_parameterization" @@ -1016,7 +1022,7 @@ def run_cmd_advanced_training(**kwargs): v_pred_like_loss = kwargs.get("v_pred_like_loss") if v_pred_like_loss and float(v_pred_like_loss) > 0: run_cmd += f' --v_pred_like_loss="{float(v_pred_like_loss)}"' - + v2 = kwargs.get("v2") if v2: run_cmd += " --v2" @@ -1032,7 +1038,7 @@ def run_cmd_advanced_training(**kwargs): wandb_api_key = kwargs.get("wandb_api_key") if wandb_api_key: run_cmd += f' --wandb_api_key="{wandb_api_key}"' - + weighted_captions = kwargs.get("weighted_captions") if weighted_captions: run_cmd += " --weighted_captions" diff --git a/lora_gui.py b/lora_gui.py index d5913d4..4e1ef57 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -732,14 +732,15 @@ def train_model( log.info(f"lr_warmup_steps = {lr_warmup_steps}") run_cmd = "accelerate launch" - + run_cmd += run_cmd_advanced_training( num_processes=num_processes, num_machines=num_machines, multi_gpu=multi_gpu, gpu_ids=gpu_ids, - num_cpu_threads_per_process=num_cpu_threads_per_process) - + num_cpu_threads_per_process=num_cpu_threads_per_process, + ) + if sdxl: run_cmd += f' "./sdxl_train_network.py"' else: @@ -1803,7 +1804,9 @@ def lora_tab( placeholder="(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", info="Specify the alpha of each block when expanding LoRA to Conv2d 3x3. Specify 25 numbers. If omitted, the value of conv_alpha is used.", ) - advanced_training = AdvancedTraining(headless=headless, training_type="lora") + advanced_training = AdvancedTraining( + headless=headless, training_type="lora" + ) advanced_training.color_aug.change( color_aug_changed, inputs=[advanced_training.color_aug], diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index d54adb8..b609b91 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -153,19 +153,19 @@ def save_configuration( original_file_path = file_path - save_as_bool = True if save_as.get('label') == 'True' else False + save_as_bool = True if save_as.get("label") == "True" else False if save_as_bool: - log.info('Save as...') + log.info("Save as...") file_path = get_saveasfile_path(file_path) else: - log.info('Save...') - if file_path == None or file_path == '': + log.info("Save...") + if file_path == None or file_path == "": file_path = get_saveasfile_path(file_path) # log.info(file_path) - if file_path == None or file_path == '': + if file_path == None or file_path == "": return original_file_path # In case a file_path was provided and the user decide to cancel the open action # Extract the destination directory from the file path @@ -178,7 +178,7 @@ def save_configuration( SaveConfigFile( parameters=parameters, file_path=file_path, - exclusion=['file_path', 'save_as'], + exclusion=["file_path", "save_as"], ) return file_path @@ -282,18 +282,18 @@ def open_configuration( # Get list of function parameters and values parameters = list(locals().items()) - ask_for_file = True if ask_for_file.get('label') == 'True' else False + ask_for_file = True if ask_for_file.get("label") == "True" else False original_file_path = file_path if ask_for_file: file_path = get_file_path(file_path) - if not file_path == '' and not file_path == None: + if not file_path == "" and not file_path == None: # load variables from JSON file - with open(file_path, 'r') as f: + with open(file_path, "r") as f: my_data = json.load(f) - log.info('Loading config...') + log.info("Loading config...") # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True my_data = update_my_data(my_data) else: @@ -303,7 +303,7 @@ def open_configuration( values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ['ask_for_file', 'file_path']: + if not key in ["ask_for_file", "file_path"]: values.append(my_data.get(key, value)) return tuple(values) @@ -406,36 +406,32 @@ def train_model( # Get list of function parameters and values parameters = list(locals().items()) - print_only_bool = True if print_only.get('label') == 'True' else False - log.info(f'Start training TI...') + print_only_bool = True if print_only.get("label") == "True" else False + log.info(f"Start training TI...") - headless_bool = True if headless.get('label') == 'True' else False + headless_bool = True if headless.get("label") == "True" else False - if pretrained_model_name_or_path == '': + if pretrained_model_name_or_path == "": output_message( - msg='Source model information is missing', headless=headless_bool + msg="Source model information is missing", headless=headless_bool ) return - if train_data_dir == '': - output_message( - msg='Image folder path is missing', headless=headless_bool - ) + if train_data_dir == "": + output_message(msg="Image folder path is missing", headless=headless_bool) return if not os.path.exists(train_data_dir): - output_message( - msg='Image folder does not exist', headless=headless_bool - ) + output_message(msg="Image folder does not exist", headless=headless_bool) return if not verify_image_folder_pattern(train_data_dir): return - if reg_data_dir != '': + if reg_data_dir != "": if not os.path.exists(reg_data_dir): output_message( - msg='Regularisation folder does not exist', + msg="Regularisation folder does not exist", headless=headless_bool, ) return @@ -443,26 +439,22 @@ def train_model( if not verify_image_folder_pattern(reg_data_dir): return - if output_dir == '': - output_message( - msg='Output folder path is missing', headless=headless_bool - ) + if output_dir == "": + output_message(msg="Output folder path is missing", headless=headless_bool) return - if token_string == '': - output_message(msg='Token string is missing', headless=headless_bool) + if token_string == "": + output_message(msg="Token string is missing", headless=headless_bool) return - if init_word == '': - output_message(msg='Init word is missing', headless=headless_bool) + if init_word == "": + output_message(msg="Init word is missing", headless=headless_bool) return if not os.path.exists(output_dir): os.makedirs(output_dir) - if check_if_model_exist( - output_name, output_dir, save_model_as, headless_bool - ): + if check_if_model_exist(output_name, output_dir, save_model_as, headless_bool): return # if float(noise_offset) > 0 and ( @@ -495,7 +487,7 @@ def train_model( # Loop through each subfolder and extract the number of repeats for folder in subfolders: # Extract the number of repeats from the folder name - repeats = int(folder.split('_')[0]) + repeats = int(folder.split("_")[0]) # Count the number of images in the folder num_images = len( @@ -503,11 +495,9 @@ def train_model( f for f, lower_f in ( (file, file.lower()) - for file in os.listdir( - os.path.join(train_data_dir, folder) - ) + for file in os.listdir(os.path.join(train_data_dir, folder)) ) - if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) + if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp")) ] ) @@ -516,21 +506,21 @@ def train_model( total_steps += steps # Print the result - log.info(f'Folder {folder}: {steps} steps') + log.info(f"Folder {folder}: {steps} steps") # Print the result # log.info(f"{total_steps} total steps") - if reg_data_dir == '': + if reg_data_dir == "": reg_factor = 1 else: log.info( - 'Regularisation images are used... Will double the number of steps required...' + "Regularisation images are used... Will double the number of steps required..." ) reg_factor = 2 # calculate max_train_steps - if max_train_steps == '' or max_train_steps == '0': + if max_train_steps == "" or max_train_steps == "0": max_train_steps = int( math.ceil( float(total_steps) @@ -543,7 +533,7 @@ def train_model( else: max_train_steps = int(max_train_steps) - log.info(f'max_train_steps = {max_train_steps}') + log.info(f"max_train_steps = {max_train_steps}") # calculate stop encoder training if stop_text_encoder_training_pct == None: @@ -552,20 +542,21 @@ def train_model( stop_text_encoder_training = math.ceil( float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) ) - log.info(f'stop_text_encoder_training = {stop_text_encoder_training}') + log.info(f"stop_text_encoder_training = {stop_text_encoder_training}") lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - log.info(f'lr_warmup_steps = {lr_warmup_steps}') + log.info(f"lr_warmup_steps = {lr_warmup_steps}") run_cmd = "accelerate launch" - + run_cmd += run_cmd_advanced_training( num_processes=num_processes, num_machines=num_machines, multi_gpu=multi_gpu, gpu_ids=gpu_ids, - num_cpu_threads_per_process=num_cpu_threads_per_process) - + num_cpu_threads_per_process=num_cpu_threads_per_process, + ) + if sdxl: run_cmd += f' "./sdxl_train_textual_inversion.py"' else: @@ -649,13 +640,13 @@ def train_model( ) run_cmd += f' --token_string="{token_string}"' run_cmd += f' --init_word="{init_word}"' - run_cmd += f' --num_vectors_per_token={num_vectors_per_token}' - if not weights == '': + run_cmd += f" --num_vectors_per_token={num_vectors_per_token}" + if not weights == "": run_cmd += f' --weights="{weights}"' - if template == 'object template': - run_cmd += f' --use_object_template' - elif template == 'style template': - run_cmd += f' --use_style_template' + if template == "object template": + run_cmd += f" --use_object_template" + elif template == "style template": + run_cmd += f" --use_style_template" run_cmd += run_cmd_sample( sample_every_n_steps, @@ -667,7 +658,7 @@ def train_model( if print_only_bool: log.warning( - 'Here is the trainer command as a reference. It will not be executed:\n' + "Here is the trainer command as a reference. It will not be executed:\n" ) print(run_cmd) @@ -675,17 +666,15 @@ def train_model( else: # Saving config file for model current_datetime = datetime.now() - formatted_datetime = current_datetime.strftime('%Y%m%d-%H%M%S') - file_path = os.path.join( - output_dir, f'{output_name}_{formatted_datetime}.json' - ) + formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") + file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") - log.info(f'Saving training config to {file_path}...') + log.info(f"Saving training config to {file_path}...") SaveConfigFile( parameters=parameters, file_path=file_path, - exclusion=['file_path', 'save_as', 'headless', 'print_only'], + exclusion=["file_path", "save_as", "headless", "print_only"], ) log.info(run_cmd) @@ -695,13 +684,11 @@ def train_model( executor.execute_command(run_cmd=run_cmd) # check if output_dir/last is a folder... therefore it is a diffuser model - last_dir = pathlib.Path(f'{output_dir}/{output_name}') + last_dir = pathlib.Path(f"{output_dir}/{output_name}") if not last_dir.is_dir(): # Copy inference model for v2 if required - save_inference_file( - output_dir, v2, v_parameterization, output_name - ) + save_inference_file(output_dir, v2, v_parameterization, output_name) def ti_tab( @@ -711,32 +698,32 @@ def ti_tab( dummy_db_false = gr.Label(value=False, visible=False) dummy_headless = gr.Label(value=headless, visible=False) - with gr.Tab('Training'): - gr.Markdown('Train a TI using kohya textual inversion python code...') + with gr.Tab("Training"): + gr.Markdown("Train a TI using kohya textual inversion python code...") # Setup Configuration Files Gradio config = ConfigurationFile(headless) source_model = SourceModel( save_model_as_choices=[ - 'ckpt', - 'safetensors', + "ckpt", + "safetensors", ], headless=headless, ) - with gr.Tab('Folders'): + with gr.Tab("Folders"): folders = Folders(headless=headless) - with gr.Tab('Parameters'): - with gr.Tab('Basic', elem_id='basic_tab'): + with gr.Tab("Parameters"): + with gr.Tab("Basic", elem_id="basic_tab"): with gr.Row(): weights = gr.Textbox( - label='Resume TI training', - placeholder='(Optional) Path to existing TI embeding file to keep training', + label="Resume TI training", + placeholder="(Optional) Path to existing TI embeding file to keep training", ) weights_file_input = gr.Button( - '📂', - elem_id='open_folder_small', + "", + elem_id="open_folder_small", visible=(not headless), ) weights_file_input.click( @@ -746,37 +733,37 @@ def ti_tab( ) with gr.Row(): token_string = gr.Textbox( - label='Token string', - placeholder='eg: cat', + label="Token string", + placeholder="eg: cat", ) init_word = gr.Textbox( - label='Init word', - value='*', + label="Init word", + value="*", ) num_vectors_per_token = gr.Slider( minimum=1, maximum=75, value=1, step=1, - label='Vectors', + label="Vectors", ) # max_train_steps = gr.Textbox( # label='Max train steps', # placeholder='(Optional) Maximum number of steps', # ) template = gr.Dropdown( - label='Template', + label="Template", choices=[ - 'caption', - 'object template', - 'style template', + "caption", + "object template", + "style template", ], - value='caption', + value="caption", ) basic_training = BasicTraining( - learning_rate_value='1e-5', - lr_scheduler_value='cosine', - lr_warmup_value='10', + learning_rate_value="1e-5", + lr_scheduler_value="cosine", + lr_warmup_value="10", sdxl_checkbox=source_model.sdxl_checkbox, ) @@ -786,7 +773,7 @@ def ti_tab( show_sdxl_cache_text_encoder_outputs=False, ) - with gr.Tab('Advanced', elem_id='advanced_tab'): + with gr.Tab("Advanced", elem_id="advanced_tab"): advanced_training = AdvancedTraining(headless=headless) advanced_training.color_aug.change( color_aug_changed, @@ -794,12 +781,12 @@ def ti_tab( outputs=[basic_training.cache_latents], ) - with gr.Tab('Samples', elem_id='samples_tab'): + with gr.Tab("Samples", elem_id="samples_tab"): sample = SampleImages() - with gr.Tab('Dataset Preparation'): + with gr.Tab("Dataset Preparation"): gr.Markdown( - 'This section provide Dreambooth tools to help setup your dataset...' + "This section provide Dreambooth tools to help setup your dataset..." ) gradio_dreambooth_folder_creation_tab( train_data_dir_input=folders.train_data_dir, @@ -811,11 +798,11 @@ def ti_tab( gradio_dataset_balancing_tab(headless=headless) with gr.Row(): - button_run = gr.Button('Start training', variant='primary') + button_run = gr.Button("Start training", variant="primary") - button_stop_training = gr.Button('Stop training') + button_stop_training = gr.Button("Stop training") - button_print = gr.Button('Print training command') + button_print = gr.Button("Print training command") # Setup gradio tensorboard buttons ( @@ -978,30 +965,28 @@ def ti_tab( def UI(**kwargs): - add_javascript(kwargs.get('language')) - css = '' + add_javascript(kwargs.get("language")) + css = "" - headless = kwargs.get('headless', False) - log.info(f'headless: {headless}') + headless = kwargs.get("headless", False) + log.info(f"headless: {headless}") - if os.path.exists('./style.css'): - with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: - log.info('Load CSS...') - css += file.read() + '\n' + if os.path.exists("./style.css"): + with open(os.path.join("./style.css"), "r", encoding="utf8") as file: + log.info("Load CSS...") + css += file.read() + "\n" - interface = gr.Blocks( - css=css, title='Kohya_ss GUI', theme=gr.themes.Default() - ) + interface = gr.Blocks(css=css, title="Kohya_ss GUI", theme=gr.themes.Default()) with interface: - with gr.Tab('Dreambooth TI'): + with gr.Tab("Dreambooth TI"): ( train_data_dir_input, reg_data_dir_input, output_dir_input, logging_dir_input, ) = ti_tab(headless=headless) - with gr.Tab('Utilities'): + with gr.Tab("Utilities"): utilities_tab( train_data_dir_input=train_data_dir_input, reg_data_dir_input=reg_data_dir_input, @@ -1013,57 +998,53 @@ def UI(**kwargs): # Show the interface launch_kwargs = {} - username = kwargs.get('username') - password = kwargs.get('password') - server_port = kwargs.get('server_port', 0) - inbrowser = kwargs.get('inbrowser', False) - share = kwargs.get('share', False) - server_name = kwargs.get('listen') + username = kwargs.get("username") + password = kwargs.get("password") + server_port = kwargs.get("server_port", 0) + inbrowser = kwargs.get("inbrowser", False) + share = kwargs.get("share", False) + server_name = kwargs.get("listen") - launch_kwargs['server_name'] = server_name + launch_kwargs["server_name"] = server_name if username and password: - launch_kwargs['auth'] = (username, password) + launch_kwargs["auth"] = (username, password) if server_port > 0: - launch_kwargs['server_port'] = server_port + launch_kwargs["server_port"] = server_port if inbrowser: - launch_kwargs['inbrowser'] = inbrowser + launch_kwargs["inbrowser"] = inbrowser if share: - launch_kwargs['share'] = share + launch_kwargs["share"] = share interface.launch(**launch_kwargs) -if __name__ == '__main__': +if __name__ == "__main__": # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() parser.add_argument( - '--listen', + "--listen", type=str, - default='127.0.0.1', - help='IP to listen on for connections to Gradio', + default="127.0.0.1", + help="IP to listen on for connections to Gradio", ) parser.add_argument( - '--username', type=str, default='', help='Username for authentication' + "--username", type=str, default="", help="Username for authentication" ) parser.add_argument( - '--password', type=str, default='', help='Password for authentication' + "--password", type=str, default="", help="Password for authentication" ) parser.add_argument( - '--server_port', + "--server_port", type=int, default=0, - help='Port to run the server listener on', + help="Port to run the server listener on", + ) + parser.add_argument("--inbrowser", action="store_true", help="Open in browser") + parser.add_argument("--share", action="store_true", help="Share the gradio UI") + parser.add_argument( + "--headless", action="store_true", help="Is the server headless" ) parser.add_argument( - '--inbrowser', action='store_true', help='Open in browser' - ) - parser.add_argument( - '--share', action='store_true', help='Share the gradio UI' - ) - parser.add_argument( - '--headless', action='store_true', help='Is the server headless' - ) - parser.add_argument( - '--language', type=str, default=None, help='Set custom language' + "--language", type=str, default=None, help="Set custom language" ) args = parser.parse_args()