diff --git a/.gitignore b/.gitignore index 34f666e..fddfbe1 100644 --- a/.gitignore +++ b/.gitignore @@ -51,3 +51,5 @@ outputs dataset/** !dataset/**/ !dataset/**/.gitkeep +# models +# data diff --git a/README.md b/README.md index 87f5400..b31afe6 100644 --- a/README.md +++ b/README.md @@ -373,6 +373,7 @@ The documentation in this section will be moved to a separate document later. - Add support for `wandb_run_name`, `log_tracker_name` and `log_tracker_config` parameters under the advanced section. - Update sd-scripts to v0.8.5 - Improve code +- Add support for custom path defaults. Simply copy the `config example.toml` file found in the root of the repo to `config.toml` and edit the different values to your taste. ### 2024/03/13 (v23.0.11) diff --git a/config example.toml b/config example.toml new file mode 100644 index 0000000..f3f833a --- /dev/null +++ b/config example.toml @@ -0,0 +1,13 @@ +# Copy this file and name it config.toml +# Edit the values to suit your needs + +# Default folders location +models_dir = "./models" # Pretrained model name or path +train_data_dir = "./data" # Image folder (containing training images subfolders) / Image folder (containing training images) +output_dir = "./outputs" # Output directory for trained model +reg_data_dir = "./data/reg" # Regularisation directory +logging_dir = "./logs" # Logging directory +config_dir = "./presets" # Load/Save Config file + +# Example custom folder location +# models_dir = "e:/models" # Pretrained model name or path \ No newline at end of file diff --git a/kohya_gui/class_configuration_file.py b/kohya_gui/class_configuration_file.py index f74a15f..3c62e55 100644 --- a/kohya_gui/class_configuration_file.py +++ b/kohya_gui/class_configuration_file.py @@ -1,7 +1,11 @@ import gradio as gr import os +import toml +from .common_gui import list_files, scriptdir, create_refresh_button, load_kohya_ss_gui_config +from .custom_logging import setup_logging -from .common_gui import list_files, scriptdir, create_refresh_button +# Set up logging +log = setup_logging() class ConfigurationFile: @@ -19,11 +23,11 @@ class ConfigurationFile: """ self.headless = headless + + config = load_kohya_ss_gui_config() # Sets the directory for storing configuration files, defaults to a 'presets' folder within the script directory. - self.current_config_dir = ( - config_dir if config_dir is not None else os.path.join(scriptdir, "presets") - ) + self.current_config_dir = config.get('config_dir', os.path.join(scriptdir, "presets")) # Initialize the GUI components for configuration. self.create_config_gui() @@ -38,7 +42,7 @@ class ConfigurationFile: Returns: - list: A list of directories. """ - self.current_config_dir = path + self.current_config_dir = path if not path == "" else "." # Lists all .json files in the current configuration directory, used for populating dropdown choices. return list(list_files(self.current_config_dir, exts=[".json"], all=True)) diff --git a/kohya_gui/class_folders.py b/kohya_gui/class_folders.py index 6a2a0f6..a21cc3a 100644 --- a/kohya_gui/class_folders.py +++ b/kohya_gui/class_folders.py @@ -1,31 +1,29 @@ import gradio as gr import os -from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button +from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button, load_kohya_ss_gui_config class Folders: """ A class to handle folder operations in the GUI. """ - def __init__(self, finetune: bool = False, data_dir: str = None, output_dir: str = None, logging_dir: str = None, reg_data_dir: str = None, headless: bool = False): + def __init__(self, finetune: bool = False, headless: bool = False): """ Initialize the Folders class. Parameters: - finetune (bool): Whether to finetune the model. - - data_dir (str): The directory for data. - - output_dir (str): The directory for output. - - logging_dir (str): The directory for logging. - - reg_data_dir (str): The directory for regularization data. - headless (bool): Whether to run in headless mode. """ self.headless = headless self.finetune = finetune + + # Load kohya_ss GUI configs from config.toml if it exist + config = load_kohya_ss_gui_config() # Set default directories if not provided - self.current_data_dir = data_dir if data_dir is not None else os.path.join(scriptdir, "data") - self.current_output_dir = output_dir if output_dir is not None else os.path.join(scriptdir, "outputs") - self.current_logging_dir = logging_dir if logging_dir is not None else os.path.join(scriptdir, "logs") - self.current_reg_data_dir = reg_data_dir if reg_data_dir is not None else os.path.join(scriptdir, "reg") + self.current_output_dir = config.get('output_dir', os.path.join(scriptdir, "outputs")) + self.current_logging_dir = config.get('logging_dir', os.path.join(scriptdir, "logs")) + self.current_reg_data_dir = config.get('reg_data_dir', os.path.join(scriptdir, "reg")) # Create directories if they don't exist self.create_directory_if_not_exists(self.current_output_dir) @@ -44,18 +42,6 @@ class Folders: if directory is not None and directory.strip() != "" and not os.path.exists(directory): os.makedirs(directory, exist_ok=True) - def list_data_dirs(self, path: str) -> list: - """ - List directories in the data directory. - - Parameters: - - path (str): The path to list directories from. - - Returns: - - list: A list of directories. - """ - self.current_data_dir = path - return list(list_dirs(path)) def list_output_dirs(self, path: str) -> list: """ @@ -67,7 +53,7 @@ class Folders: Returns: - list: A list of directories. """ - self.current_output_dir = path + self.current_output_dir = path if not path == "" else "." return list(list_dirs(path)) def list_logging_dirs(self, path: str) -> list: @@ -80,7 +66,7 @@ class Folders: Returns: - list: A list of directories. """ - self.current_logging_dir = path + self.current_logging_dir = path if not path == "" else "." return list(list_dirs(path)) def list_reg_data_dirs(self, path: str) -> list: @@ -93,7 +79,7 @@ class Folders: Returns: - list: A list of directories. """ - self.current_reg_data_dir = path + self.current_reg_data_dir = path if not path == "" else "." return list(list_dirs(path)) def create_folders_gui(self) -> None: @@ -131,7 +117,7 @@ class Folders: allow_custom_value=True, ) # Refresh button for regularisation directory - create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_data_dirs(self.current_data_dir)}, "open_folder_small") + create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_reg_data_dirs(self.current_reg_data_dir)}, "open_folder_small") # Regularisation directory button self.reg_data_dir_folder = gr.Button( '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless) @@ -173,7 +159,7 @@ class Folders: ) # Change event for regularisation directory dropdown self.reg_data_dir.change( - fn=lambda path: gr.Dropdown(choices=[""] + self.list_data_dirs(path)), + fn=lambda path: gr.Dropdown(choices=[""] + self.list_reg_data_dirs(path)), inputs=self.reg_data_dir, outputs=self.reg_data_dir, show_progress=False, diff --git a/kohya_gui/class_source_model.py b/kohya_gui/class_source_model.py index 6e98207..f55de89 100644 --- a/kohya_gui/class_source_model.py +++ b/kohya_gui/class_source_model.py @@ -8,35 +8,38 @@ from .common_gui import ( scriptdir, list_dirs, list_files, + create_refresh_button, + load_kohya_ss_gui_config, ) -folder_symbol = '\U0001f4c2' # 📂 -refresh_symbol = '\U0001f504' # 🔄 -save_style_symbol = '\U0001f4be' # 💾 -document_symbol = '\U0001F4C4' # 📄 +folder_symbol = "\U0001f4c2" # 📂 +refresh_symbol = "\U0001f504" # 🔄 +save_style_symbol = "\U0001f4be" # 💾 +document_symbol = "\U0001F4C4" # 📄 default_models = [ - 'stabilityai/stable-diffusion-xl-base-1.0', - 'stabilityai/stable-diffusion-xl-refiner-1.0', - 'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned', - 'stabilityai/stable-diffusion-2-1-base', - 'stabilityai/stable-diffusion-2-base', - 'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned', - 'stabilityai/stable-diffusion-2-1', - 'stabilityai/stable-diffusion-2', - 'runwayml/stable-diffusion-v1-5', - 'CompVis/stable-diffusion-v1-4', + "stabilityai/stable-diffusion-xl-base-1.0", + "stabilityai/stable-diffusion-xl-refiner-1.0", + "stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned", + "stabilityai/stable-diffusion-2-1-base", + "stabilityai/stable-diffusion-2-base", + "stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned", + "stabilityai/stable-diffusion-2-1", + "stabilityai/stable-diffusion-2", + "runwayml/stable-diffusion-v1-5", + "CompVis/stable-diffusion-v1-4", ] + class SourceModel: def __init__( self, save_model_as_choices=[ - 'same as source model', - 'ckpt', - 'diffusers', - 'diffusers_safetensors', - 'safetensors', + "same as source model", + "ckpt", + "diffusers", + "diffusers_safetensors", + "safetensors", ], save_precision_choices=[ "float", @@ -44,113 +47,146 @@ class SourceModel: "bf16", ], headless=False, - default_data_dir=None, finetuning=False, ): self.headless = headless self.save_model_as_choices = save_model_as_choices self.finetuning = finetuning - from .common_gui import create_refresh_button + config = load_kohya_ss_gui_config() - default_data_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "outputs") - default_train_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "data") - model_checkpoints = list(list_files(default_data_dir, exts=[".ckpt", ".safetensors"], all=True)) - self.current_data_dir = default_data_dir - self.current_train_dir = default_train_dir + # Set default directories if not provided + self.current_models_dir = config.get( + "models_dir", os.path.join(scriptdir, "models") + ) + self.current_train_data_dir = config.get( + "train_data_dir", os.path.join(scriptdir, "data") + ) + + model_checkpoints = list( + list_files( + self.current_models_dir, exts=[".ckpt", ".safetensors"], all=True + ) + ) def list_models(path): - self.current_data_dir = path if os.path.isdir(path) else os.path.dirname(path) - return default_models + list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + self.current_models_dir = ( + path if os.path.isdir(path) else os.path.dirname(path) + ) + return default_models + list( + list_files(path, exts=[".ckpt", ".safetensors"], all=True) + ) - def list_train_dirs(path): - self.current_train_dir = path if os.path.isdir(path) else os.path.dirname(path) + def list_train_data_dirs(path): + self.current_train_data_dir = path if not path == "" else "." return list(list_dirs(path)) - if default_data_dir is not None and default_data_dir.strip() != "" and not os.path.exists(default_data_dir): - os.makedirs(default_data_dir, exist_ok=True) - with gr.Column(), gr.Group(): # Define the input elements with gr.Row(): - with gr.Column(), gr.Row(): - self.model_list = gr.Textbox(visible=False, value="") - self.pretrained_model_name_or_path = gr.Dropdown( - label='Pretrained model name or path', - choices=default_models + model_checkpoints, - value='runwayml/stable-diffusion-v1-5', - allow_custom_value=True, - visible=True, - min_width=100, - ) - create_refresh_button(self.pretrained_model_name_or_path, lambda: None, lambda: {"choices": list_models(self.current_data_dir)},"open_folder_small") + with gr.Column(), gr.Row(): + self.model_list = gr.Textbox(visible=False, value="") + self.pretrained_model_name_or_path = gr.Dropdown( + label="Pretrained model name or path", + choices=default_models + model_checkpoints, + value="runwayml/stable-diffusion-v1-5", + allow_custom_value=True, + visible=True, + min_width=100, + ) + create_refresh_button( + self.pretrained_model_name_or_path, + lambda: None, + lambda: {"choices": list_models(self.current_models_dir)}, + "open_folder_small", + ) - self.pretrained_model_name_or_path_file = gr.Button( - document_symbol, - elem_id='open_folder_small', - elem_classes=['tool'], - visible=(not headless), - ) - self.pretrained_model_name_or_path_file.click( - get_any_file_path, - inputs=self.pretrained_model_name_or_path, - outputs=self.pretrained_model_name_or_path, - show_progress=False, - ) - self.pretrained_model_name_or_path_folder = gr.Button( - folder_symbol, - elem_id='open_folder_small', - elem_classes=['tool'], - visible=(not headless), - ) - self.pretrained_model_name_or_path_folder.click( - get_folder_path, - inputs=self.pretrained_model_name_or_path, - outputs=self.pretrained_model_name_or_path, - show_progress=False, - ) + self.pretrained_model_name_or_path_file = gr.Button( + document_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), + ) + self.pretrained_model_name_or_path_file.click( + get_any_file_path, + inputs=self.pretrained_model_name_or_path, + outputs=self.pretrained_model_name_or_path, + show_progress=False, + ) + self.pretrained_model_name_or_path_folder = gr.Button( + folder_symbol, + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), + ) + self.pretrained_model_name_or_path_folder.click( + get_folder_path, + inputs=self.pretrained_model_name_or_path, + outputs=self.pretrained_model_name_or_path, + show_progress=False, + ) - with gr.Column(), gr.Row(): - self.train_data_dir = gr.Dropdown( - label='Image folder (containing training images subfolders)' if not finetuning else 'Image folder (containing training images)', - choices=[""] + list_train_dirs(default_train_dir), - value="", - interactive=True, - allow_custom_value=True, - ) - create_refresh_button(self.train_data_dir, lambda: None, lambda: {"choices": list_train_dirs(self.current_train_dir)}, "open_folder_small") - self.train_data_dir_folder = gr.Button( - '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless) - ) - self.train_data_dir_folder.click( - get_folder_path, - outputs=self.train_data_dir, - show_progress=False, - ) + with gr.Column(), gr.Row(): + self.train_data_dir = gr.Dropdown( + label=( + "Image folder (containing training images subfolders)" + if not finetuning + else "Image folder (containing training images)" + ), + choices=[""] + list_train_data_dirs(self.current_train_data_dir), + value="", + interactive=True, + allow_custom_value=True, + ) + create_refresh_button( + self.train_data_dir, + lambda: None, + lambda: {"choices": [""] + list_train_data_dirs(self.current_train_data_dir)}, + "open_folder_small", + ) + self.train_data_dir_folder = gr.Button( + "📂", + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not self.headless), + ) + self.train_data_dir_folder.click( + get_folder_path, + outputs=self.train_data_dir, + show_progress=False, + ) with gr.Row(): with gr.Column(): with gr.Row(): - self.v2 = gr.Checkbox(label='v2', value=False, visible=False, min_width=60) + self.v2 = gr.Checkbox( + label="v2", value=False, visible=False, min_width=60 + ) self.v_parameterization = gr.Checkbox( - label='v_parameterization', value=False, visible=False, min_width=130, + label="v_parameterization", + value=False, + visible=False, + min_width=130, ) self.sdxl_checkbox = gr.Checkbox( - label='SDXL', value=False, visible=False, min_width=60, + label="SDXL", + value=False, + visible=False, + min_width=60, ) with gr.Column(): gr.Box(visible=False) with gr.Row(): self.output_name = gr.Textbox( - label='Trained Model output name', - placeholder='(Name of the model to output)', - value='last', + label="Trained Model output name", + placeholder="(Name of the model to output)", + value="last", interactive=True, ) self.training_comment = gr.Textbox( - label='Training comment', - placeholder='(Optional) Add training comment to be included in metadata', + label="Training comment", + placeholder="(Optional) Add training comment to be included in metadata", interactive=True, ) @@ -167,7 +203,9 @@ class SourceModel: ) self.pretrained_model_name_or_path.change( - fn=lambda path: set_pretrained_model_name_or_path_input(path, refresh_method=list_models), + fn=lambda path: set_pretrained_model_name_or_path_input( + path, refresh_method=list_models + ), inputs=[ self.pretrained_model_name_or_path, ], @@ -181,7 +219,7 @@ class SourceModel: ) self.train_data_dir.change( - fn=lambda path: gr.Dropdown(choices=[""] + list_train_dirs(path)), + fn=lambda path: gr.Dropdown(choices=[""] + list_train_data_dirs(path)), inputs=self.train_data_dir, outputs=self.train_data_dir, show_progress=False, diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index 86421a2..0f157ff 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -9,6 +9,7 @@ import gradio as gr import shutil import sys import json +import toml # Set up logging log = setup_logging() @@ -55,6 +56,22 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDX ENV_EXCLUSION = ["COLAB_GPU", "RUNPOD_POD_ID"] +def load_kohya_ss_gui_config() -> dict: + """ + Loads the Kohya SS GUI configuration from a TOML file. + + Returns: + dict: The configuration data loaded from the TOML file. + """ + try: + # Attempt to load the TOML configuration file from the specified directory. + config = toml.load(fr"{scriptdir}/config.toml") + except FileNotFoundError: + # If the config file is not found, initialize `config` as an empty dictionary to handle missing configurations gracefully. + config = {} + + return config + def check_if_model_exist( output_name: str, output_dir: str, save_model_as: str, headless: bool = False ) -> bool: diff --git a/models/.keep b/models/.keep new file mode 100644 index 0000000..e69de29