Add support for user managed path config

pull/2111/head^2
bmaltais 2024-03-16 20:58:22 -04:00
parent 0e4582cd52
commit 6386a720b8
8 changed files with 188 additions and 127 deletions

2
.gitignore vendored
View File

@ -51,3 +51,5 @@ outputs
dataset/** dataset/**
!dataset/**/ !dataset/**/
!dataset/**/.gitkeep !dataset/**/.gitkeep
# models
# data

View File

@ -373,6 +373,7 @@ The documentation in this section will be moved to a separate document later.
- Add support for `wandb_run_name`, `log_tracker_name` and `log_tracker_config` parameters under the advanced section. - Add support for `wandb_run_name`, `log_tracker_name` and `log_tracker_config` parameters under the advanced section.
- Update sd-scripts to v0.8.5 - Update sd-scripts to v0.8.5
- Improve code - Improve code
- Add support for custom path defaults. Simply copy the `config example.toml` file found in the root of the repo to `config.toml` and edit the different values to your taste.
### 2024/03/13 (v23.0.11) ### 2024/03/13 (v23.0.11)

13
config example.toml Normal file
View File

@ -0,0 +1,13 @@
# Copy this file and name it config.toml
# Edit the values to suit your needs
# Default folders location
models_dir = "./models" # Pretrained model name or path
train_data_dir = "./data" # Image folder (containing training images subfolders) / Image folder (containing training images)
output_dir = "./outputs" # Output directory for trained model
reg_data_dir = "./data/reg" # Regularisation directory
logging_dir = "./logs" # Logging directory
config_dir = "./presets" # Load/Save Config file
# Example custom folder location
# models_dir = "e:/models" # Pretrained model name or path

View File

@ -1,7 +1,11 @@
import gradio as gr import gradio as gr
import os import os
import toml
from .common_gui import list_files, scriptdir, create_refresh_button, load_kohya_ss_gui_config
from .custom_logging import setup_logging
from .common_gui import list_files, scriptdir, create_refresh_button # Set up logging
log = setup_logging()
class ConfigurationFile: class ConfigurationFile:
@ -19,11 +23,11 @@ class ConfigurationFile:
""" """
self.headless = headless self.headless = headless
config = load_kohya_ss_gui_config()
# Sets the directory for storing configuration files, defaults to a 'presets' folder within the script directory. # Sets the directory for storing configuration files, defaults to a 'presets' folder within the script directory.
self.current_config_dir = ( self.current_config_dir = config.get('config_dir', os.path.join(scriptdir, "presets"))
config_dir if config_dir is not None else os.path.join(scriptdir, "presets")
)
# Initialize the GUI components for configuration. # Initialize the GUI components for configuration.
self.create_config_gui() self.create_config_gui()
@ -38,7 +42,7 @@ class ConfigurationFile:
Returns: Returns:
- list: A list of directories. - list: A list of directories.
""" """
self.current_config_dir = path self.current_config_dir = path if not path == "" else "."
# Lists all .json files in the current configuration directory, used for populating dropdown choices. # Lists all .json files in the current configuration directory, used for populating dropdown choices.
return list(list_files(self.current_config_dir, exts=[".json"], all=True)) return list(list_files(self.current_config_dir, exts=[".json"], all=True))

View File

@ -1,31 +1,29 @@
import gradio as gr import gradio as gr
import os import os
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button, load_kohya_ss_gui_config
class Folders: class Folders:
""" """
A class to handle folder operations in the GUI. A class to handle folder operations in the GUI.
""" """
def __init__(self, finetune: bool = False, data_dir: str = None, output_dir: str = None, logging_dir: str = None, reg_data_dir: str = None, headless: bool = False): def __init__(self, finetune: bool = False, headless: bool = False):
""" """
Initialize the Folders class. Initialize the Folders class.
Parameters: Parameters:
- finetune (bool): Whether to finetune the model. - finetune (bool): Whether to finetune the model.
- data_dir (str): The directory for data.
- output_dir (str): The directory for output.
- logging_dir (str): The directory for logging.
- reg_data_dir (str): The directory for regularization data.
- headless (bool): Whether to run in headless mode. - headless (bool): Whether to run in headless mode.
""" """
self.headless = headless self.headless = headless
self.finetune = finetune self.finetune = finetune
# Load kohya_ss GUI configs from config.toml if it exist
config = load_kohya_ss_gui_config()
# Set default directories if not provided # Set default directories if not provided
self.current_data_dir = data_dir if data_dir is not None else os.path.join(scriptdir, "data") self.current_output_dir = config.get('output_dir', os.path.join(scriptdir, "outputs"))
self.current_output_dir = output_dir if output_dir is not None else os.path.join(scriptdir, "outputs") self.current_logging_dir = config.get('logging_dir', os.path.join(scriptdir, "logs"))
self.current_logging_dir = logging_dir if logging_dir is not None else os.path.join(scriptdir, "logs") self.current_reg_data_dir = config.get('reg_data_dir', os.path.join(scriptdir, "reg"))
self.current_reg_data_dir = reg_data_dir if reg_data_dir is not None else os.path.join(scriptdir, "reg")
# Create directories if they don't exist # Create directories if they don't exist
self.create_directory_if_not_exists(self.current_output_dir) self.create_directory_if_not_exists(self.current_output_dir)
@ -44,18 +42,6 @@ class Folders:
if directory is not None and directory.strip() != "" and not os.path.exists(directory): if directory is not None and directory.strip() != "" and not os.path.exists(directory):
os.makedirs(directory, exist_ok=True) os.makedirs(directory, exist_ok=True)
def list_data_dirs(self, path: str) -> list:
"""
List directories in the data directory.
Parameters:
- path (str): The path to list directories from.
Returns:
- list: A list of directories.
"""
self.current_data_dir = path
return list(list_dirs(path))
def list_output_dirs(self, path: str) -> list: def list_output_dirs(self, path: str) -> list:
""" """
@ -67,7 +53,7 @@ class Folders:
Returns: Returns:
- list: A list of directories. - list: A list of directories.
""" """
self.current_output_dir = path self.current_output_dir = path if not path == "" else "."
return list(list_dirs(path)) return list(list_dirs(path))
def list_logging_dirs(self, path: str) -> list: def list_logging_dirs(self, path: str) -> list:
@ -80,7 +66,7 @@ class Folders:
Returns: Returns:
- list: A list of directories. - list: A list of directories.
""" """
self.current_logging_dir = path self.current_logging_dir = path if not path == "" else "."
return list(list_dirs(path)) return list(list_dirs(path))
def list_reg_data_dirs(self, path: str) -> list: def list_reg_data_dirs(self, path: str) -> list:
@ -93,7 +79,7 @@ class Folders:
Returns: Returns:
- list: A list of directories. - list: A list of directories.
""" """
self.current_reg_data_dir = path self.current_reg_data_dir = path if not path == "" else "."
return list(list_dirs(path)) return list(list_dirs(path))
def create_folders_gui(self) -> None: def create_folders_gui(self) -> None:
@ -131,7 +117,7 @@ class Folders:
allow_custom_value=True, allow_custom_value=True,
) )
# Refresh button for regularisation directory # Refresh button for regularisation directory
create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_data_dirs(self.current_data_dir)}, "open_folder_small") create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_reg_data_dirs(self.current_reg_data_dir)}, "open_folder_small")
# Regularisation directory button # Regularisation directory button
self.reg_data_dir_folder = gr.Button( self.reg_data_dir_folder = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless) '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
@ -173,7 +159,7 @@ class Folders:
) )
# Change event for regularisation directory dropdown # Change event for regularisation directory dropdown
self.reg_data_dir.change( self.reg_data_dir.change(
fn=lambda path: gr.Dropdown(choices=[""] + self.list_data_dirs(path)), fn=lambda path: gr.Dropdown(choices=[""] + self.list_reg_data_dirs(path)),
inputs=self.reg_data_dir, inputs=self.reg_data_dir,
outputs=self.reg_data_dir, outputs=self.reg_data_dir,
show_progress=False, show_progress=False,

View File

@ -8,35 +8,38 @@ from .common_gui import (
scriptdir, scriptdir,
list_dirs, list_dirs,
list_files, list_files,
create_refresh_button,
load_kohya_ss_gui_config,
) )
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = '\U0001f4be' # 💾 save_style_symbol = "\U0001f4be" # 💾
document_symbol = '\U0001F4C4' # 📄 document_symbol = "\U0001F4C4" # 📄
default_models = [ default_models = [
'stabilityai/stable-diffusion-xl-base-1.0', "stabilityai/stable-diffusion-xl-base-1.0",
'stabilityai/stable-diffusion-xl-refiner-1.0', "stabilityai/stable-diffusion-xl-refiner-1.0",
'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned', "stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned",
'stabilityai/stable-diffusion-2-1-base', "stabilityai/stable-diffusion-2-1-base",
'stabilityai/stable-diffusion-2-base', "stabilityai/stable-diffusion-2-base",
'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned', "stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned",
'stabilityai/stable-diffusion-2-1', "stabilityai/stable-diffusion-2-1",
'stabilityai/stable-diffusion-2', "stabilityai/stable-diffusion-2",
'runwayml/stable-diffusion-v1-5', "runwayml/stable-diffusion-v1-5",
'CompVis/stable-diffusion-v1-4', "CompVis/stable-diffusion-v1-4",
] ]
class SourceModel: class SourceModel:
def __init__( def __init__(
self, self,
save_model_as_choices=[ save_model_as_choices=[
'same as source model', "same as source model",
'ckpt', "ckpt",
'diffusers', "diffusers",
'diffusers_safetensors', "diffusers_safetensors",
'safetensors', "safetensors",
], ],
save_precision_choices=[ save_precision_choices=[
"float", "float",
@ -44,113 +47,146 @@ class SourceModel:
"bf16", "bf16",
], ],
headless=False, headless=False,
default_data_dir=None,
finetuning=False, finetuning=False,
): ):
self.headless = headless self.headless = headless
self.save_model_as_choices = save_model_as_choices self.save_model_as_choices = save_model_as_choices
self.finetuning = finetuning self.finetuning = finetuning
from .common_gui import create_refresh_button config = load_kohya_ss_gui_config()
default_data_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "outputs") # Set default directories if not provided
default_train_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "data") self.current_models_dir = config.get(
model_checkpoints = list(list_files(default_data_dir, exts=[".ckpt", ".safetensors"], all=True)) "models_dir", os.path.join(scriptdir, "models")
self.current_data_dir = default_data_dir )
self.current_train_dir = default_train_dir self.current_train_data_dir = config.get(
"train_data_dir", os.path.join(scriptdir, "data")
)
model_checkpoints = list(
list_files(
self.current_models_dir, exts=[".ckpt", ".safetensors"], all=True
)
)
def list_models(path): def list_models(path):
self.current_data_dir = path if os.path.isdir(path) else os.path.dirname(path) self.current_models_dir = (
return default_models + list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) path if os.path.isdir(path) else os.path.dirname(path)
)
return default_models + list(
list_files(path, exts=[".ckpt", ".safetensors"], all=True)
)
def list_train_dirs(path): def list_train_data_dirs(path):
self.current_train_dir = path if os.path.isdir(path) else os.path.dirname(path) self.current_train_data_dir = path if not path == "" else "."
return list(list_dirs(path)) return list(list_dirs(path))
if default_data_dir is not None and default_data_dir.strip() != "" and not os.path.exists(default_data_dir):
os.makedirs(default_data_dir, exist_ok=True)
with gr.Column(), gr.Group(): with gr.Column(), gr.Group():
# Define the input elements # Define the input elements
with gr.Row(): with gr.Row():
with gr.Column(), gr.Row(): with gr.Column(), gr.Row():
self.model_list = gr.Textbox(visible=False, value="") self.model_list = gr.Textbox(visible=False, value="")
self.pretrained_model_name_or_path = gr.Dropdown( self.pretrained_model_name_or_path = gr.Dropdown(
label='Pretrained model name or path', label="Pretrained model name or path",
choices=default_models + model_checkpoints, choices=default_models + model_checkpoints,
value='runwayml/stable-diffusion-v1-5', value="runwayml/stable-diffusion-v1-5",
allow_custom_value=True, allow_custom_value=True,
visible=True, visible=True,
min_width=100, min_width=100,
) )
create_refresh_button(self.pretrained_model_name_or_path, lambda: None, lambda: {"choices": list_models(self.current_data_dir)},"open_folder_small") create_refresh_button(
self.pretrained_model_name_or_path,
lambda: None,
lambda: {"choices": list_models(self.current_models_dir)},
"open_folder_small",
)
self.pretrained_model_name_or_path_file = gr.Button( self.pretrained_model_name_or_path_file = gr.Button(
document_symbol, document_symbol,
elem_id='open_folder_small', elem_id="open_folder_small",
elem_classes=['tool'], elem_classes=["tool"],
visible=(not headless), visible=(not headless),
) )
self.pretrained_model_name_or_path_file.click( self.pretrained_model_name_or_path_file.click(
get_any_file_path, get_any_file_path,
inputs=self.pretrained_model_name_or_path, inputs=self.pretrained_model_name_or_path,
outputs=self.pretrained_model_name_or_path, outputs=self.pretrained_model_name_or_path,
show_progress=False, show_progress=False,
) )
self.pretrained_model_name_or_path_folder = gr.Button( self.pretrained_model_name_or_path_folder = gr.Button(
folder_symbol, folder_symbol,
elem_id='open_folder_small', elem_id="open_folder_small",
elem_classes=['tool'], elem_classes=["tool"],
visible=(not headless), visible=(not headless),
) )
self.pretrained_model_name_or_path_folder.click( self.pretrained_model_name_or_path_folder.click(
get_folder_path, get_folder_path,
inputs=self.pretrained_model_name_or_path, inputs=self.pretrained_model_name_or_path,
outputs=self.pretrained_model_name_or_path, outputs=self.pretrained_model_name_or_path,
show_progress=False, show_progress=False,
) )
with gr.Column(), gr.Row(): with gr.Column(), gr.Row():
self.train_data_dir = gr.Dropdown( self.train_data_dir = gr.Dropdown(
label='Image folder (containing training images subfolders)' if not finetuning else 'Image folder (containing training images)', label=(
choices=[""] + list_train_dirs(default_train_dir), "Image folder (containing training images subfolders)"
value="", if not finetuning
interactive=True, else "Image folder (containing training images)"
allow_custom_value=True, ),
) choices=[""] + list_train_data_dirs(self.current_train_data_dir),
create_refresh_button(self.train_data_dir, lambda: None, lambda: {"choices": list_train_dirs(self.current_train_dir)}, "open_folder_small") value="",
self.train_data_dir_folder = gr.Button( interactive=True,
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless) allow_custom_value=True,
) )
self.train_data_dir_folder.click( create_refresh_button(
get_folder_path, self.train_data_dir,
outputs=self.train_data_dir, lambda: None,
show_progress=False, lambda: {"choices": [""] + list_train_data_dirs(self.current_train_data_dir)},
) "open_folder_small",
)
self.train_data_dir_folder = gr.Button(
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
self.train_data_dir_folder.click(
get_folder_path,
outputs=self.train_data_dir,
show_progress=False,
)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
self.v2 = gr.Checkbox(label='v2', value=False, visible=False, min_width=60) self.v2 = gr.Checkbox(
label="v2", value=False, visible=False, min_width=60
)
self.v_parameterization = gr.Checkbox( self.v_parameterization = gr.Checkbox(
label='v_parameterization', value=False, visible=False, min_width=130, label="v_parameterization",
value=False,
visible=False,
min_width=130,
) )
self.sdxl_checkbox = gr.Checkbox( self.sdxl_checkbox = gr.Checkbox(
label='SDXL', value=False, visible=False, min_width=60, label="SDXL",
value=False,
visible=False,
min_width=60,
) )
with gr.Column(): with gr.Column():
gr.Box(visible=False) gr.Box(visible=False)
with gr.Row(): with gr.Row():
self.output_name = gr.Textbox( self.output_name = gr.Textbox(
label='Trained Model output name', label="Trained Model output name",
placeholder='(Name of the model to output)', placeholder="(Name of the model to output)",
value='last', value="last",
interactive=True, interactive=True,
) )
self.training_comment = gr.Textbox( self.training_comment = gr.Textbox(
label='Training comment', label="Training comment",
placeholder='(Optional) Add training comment to be included in metadata', placeholder="(Optional) Add training comment to be included in metadata",
interactive=True, interactive=True,
) )
@ -167,7 +203,9 @@ class SourceModel:
) )
self.pretrained_model_name_or_path.change( self.pretrained_model_name_or_path.change(
fn=lambda path: set_pretrained_model_name_or_path_input(path, refresh_method=list_models), fn=lambda path: set_pretrained_model_name_or_path_input(
path, refresh_method=list_models
),
inputs=[ inputs=[
self.pretrained_model_name_or_path, self.pretrained_model_name_or_path,
], ],
@ -181,7 +219,7 @@ class SourceModel:
) )
self.train_data_dir.change( self.train_data_dir.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_train_dirs(path)), fn=lambda path: gr.Dropdown(choices=[""] + list_train_data_dirs(path)),
inputs=self.train_data_dir, inputs=self.train_data_dir,
outputs=self.train_data_dir, outputs=self.train_data_dir,
show_progress=False, show_progress=False,

View File

@ -9,6 +9,7 @@ import gradio as gr
import shutil import shutil
import sys import sys
import json import json
import toml
# Set up logging # Set up logging
log = setup_logging() log = setup_logging()
@ -55,6 +56,22 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDX
ENV_EXCLUSION = ["COLAB_GPU", "RUNPOD_POD_ID"] ENV_EXCLUSION = ["COLAB_GPU", "RUNPOD_POD_ID"]
def load_kohya_ss_gui_config() -> dict:
"""
Loads the Kohya SS GUI configuration from a TOML file.
Returns:
dict: The configuration data loaded from the TOML file.
"""
try:
# Attempt to load the TOML configuration file from the specified directory.
config = toml.load(fr"{scriptdir}/config.toml")
except FileNotFoundError:
# If the config file is not found, initialize `config` as an empty dictionary to handle missing configurations gracefully.
config = {}
return config
def check_if_model_exist( def check_if_model_exist(
output_name: str, output_dir: str, save_model_as: str, headless: bool = False output_name: str, output_dir: str, save_model_as: str, headless: bool = False
) -> bool: ) -> bool:

0
models/.keep Normal file
View File