Add support for user managed path config

pull/2111/head^2
bmaltais 2024-03-16 20:58:22 -04:00
parent 0e4582cd52
commit 6386a720b8
8 changed files with 188 additions and 127 deletions

2
.gitignore vendored
View File

@ -51,3 +51,5 @@ outputs
dataset/**
!dataset/**/
!dataset/**/.gitkeep
# models
# data

View File

@ -373,6 +373,7 @@ The documentation in this section will be moved to a separate document later.
- Add support for `wandb_run_name`, `log_tracker_name` and `log_tracker_config` parameters under the advanced section.
- Update sd-scripts to v0.8.5
- Improve code
- Add support for custom path defaults. Simply copy the `config example.toml` file found in the root of the repo to `config.toml` and edit the different values to your taste.
### 2024/03/13 (v23.0.11)

13
config example.toml Normal file
View File

@ -0,0 +1,13 @@
# Copy this file and name it config.toml
# Edit the values to suit your needs
# Default folders location
models_dir = "./models" # Pretrained model name or path
train_data_dir = "./data" # Image folder (containing training images subfolders) / Image folder (containing training images)
output_dir = "./outputs" # Output directory for trained model
reg_data_dir = "./data/reg" # Regularisation directory
logging_dir = "./logs" # Logging directory
config_dir = "./presets" # Load/Save Config file
# Example custom folder location
# models_dir = "e:/models" # Pretrained model name or path

View File

@ -1,7 +1,11 @@
import gradio as gr
import os
import toml
from .common_gui import list_files, scriptdir, create_refresh_button, load_kohya_ss_gui_config
from .custom_logging import setup_logging
from .common_gui import list_files, scriptdir, create_refresh_button
# Set up logging
log = setup_logging()
class ConfigurationFile:
@ -20,10 +24,10 @@ class ConfigurationFile:
self.headless = headless
config = load_kohya_ss_gui_config()
# Sets the directory for storing configuration files, defaults to a 'presets' folder within the script directory.
self.current_config_dir = (
config_dir if config_dir is not None else os.path.join(scriptdir, "presets")
)
self.current_config_dir = config.get('config_dir', os.path.join(scriptdir, "presets"))
# Initialize the GUI components for configuration.
self.create_config_gui()
@ -38,7 +42,7 @@ class ConfigurationFile:
Returns:
- list: A list of directories.
"""
self.current_config_dir = path
self.current_config_dir = path if not path == "" else "."
# Lists all .json files in the current configuration directory, used for populating dropdown choices.
return list(list_files(self.current_config_dir, exts=[".json"], all=True))

View File

@ -1,31 +1,29 @@
import gradio as gr
import os
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button, load_kohya_ss_gui_config
class Folders:
"""
A class to handle folder operations in the GUI.
"""
def __init__(self, finetune: bool = False, data_dir: str = None, output_dir: str = None, logging_dir: str = None, reg_data_dir: str = None, headless: bool = False):
def __init__(self, finetune: bool = False, headless: bool = False):
"""
Initialize the Folders class.
Parameters:
- finetune (bool): Whether to finetune the model.
- data_dir (str): The directory for data.
- output_dir (str): The directory for output.
- logging_dir (str): The directory for logging.
- reg_data_dir (str): The directory for regularization data.
- headless (bool): Whether to run in headless mode.
"""
self.headless = headless
self.finetune = finetune
# Load kohya_ss GUI configs from config.toml if it exist
config = load_kohya_ss_gui_config()
# Set default directories if not provided
self.current_data_dir = data_dir if data_dir is not None else os.path.join(scriptdir, "data")
self.current_output_dir = output_dir if output_dir is not None else os.path.join(scriptdir, "outputs")
self.current_logging_dir = logging_dir if logging_dir is not None else os.path.join(scriptdir, "logs")
self.current_reg_data_dir = reg_data_dir if reg_data_dir is not None else os.path.join(scriptdir, "reg")
self.current_output_dir = config.get('output_dir', os.path.join(scriptdir, "outputs"))
self.current_logging_dir = config.get('logging_dir', os.path.join(scriptdir, "logs"))
self.current_reg_data_dir = config.get('reg_data_dir', os.path.join(scriptdir, "reg"))
# Create directories if they don't exist
self.create_directory_if_not_exists(self.current_output_dir)
@ -44,18 +42,6 @@ class Folders:
if directory is not None and directory.strip() != "" and not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
def list_data_dirs(self, path: str) -> list:
"""
List directories in the data directory.
Parameters:
- path (str): The path to list directories from.
Returns:
- list: A list of directories.
"""
self.current_data_dir = path
return list(list_dirs(path))
def list_output_dirs(self, path: str) -> list:
"""
@ -67,7 +53,7 @@ class Folders:
Returns:
- list: A list of directories.
"""
self.current_output_dir = path
self.current_output_dir = path if not path == "" else "."
return list(list_dirs(path))
def list_logging_dirs(self, path: str) -> list:
@ -80,7 +66,7 @@ class Folders:
Returns:
- list: A list of directories.
"""
self.current_logging_dir = path
self.current_logging_dir = path if not path == "" else "."
return list(list_dirs(path))
def list_reg_data_dirs(self, path: str) -> list:
@ -93,7 +79,7 @@ class Folders:
Returns:
- list: A list of directories.
"""
self.current_reg_data_dir = path
self.current_reg_data_dir = path if not path == "" else "."
return list(list_dirs(path))
def create_folders_gui(self) -> None:
@ -131,7 +117,7 @@ class Folders:
allow_custom_value=True,
)
# Refresh button for regularisation directory
create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_data_dirs(self.current_data_dir)}, "open_folder_small")
create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_reg_data_dirs(self.current_reg_data_dir)}, "open_folder_small")
# Regularisation directory button
self.reg_data_dir_folder = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
@ -173,7 +159,7 @@ class Folders:
)
# Change event for regularisation directory dropdown
self.reg_data_dir.change(
fn=lambda path: gr.Dropdown(choices=[""] + self.list_data_dirs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + self.list_reg_data_dirs(path)),
inputs=self.reg_data_dir,
outputs=self.reg_data_dir,
show_progress=False,

View File

@ -8,35 +8,38 @@ from .common_gui import (
scriptdir,
list_dirs,
list_files,
create_refresh_button,
load_kohya_ss_gui_config,
)
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
default_models = [
'stabilityai/stable-diffusion-xl-base-1.0',
'stabilityai/stable-diffusion-xl-refiner-1.0',
'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4',
"stabilityai/stable-diffusion-xl-base-1.0",
"stabilityai/stable-diffusion-xl-refiner-1.0",
"stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned",
"stabilityai/stable-diffusion-2-1-base",
"stabilityai/stable-diffusion-2-base",
"stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2",
"runwayml/stable-diffusion-v1-5",
"CompVis/stable-diffusion-v1-4",
]
class SourceModel:
def __init__(
self,
save_model_as_choices=[
'same as source model',
'ckpt',
'diffusers',
'diffusers_safetensors',
'safetensors',
"same as source model",
"ckpt",
"diffusers",
"diffusers_safetensors",
"safetensors",
],
save_precision_choices=[
"float",
@ -44,113 +47,146 @@ class SourceModel:
"bf16",
],
headless=False,
default_data_dir=None,
finetuning=False,
):
self.headless = headless
self.save_model_as_choices = save_model_as_choices
self.finetuning = finetuning
from .common_gui import create_refresh_button
config = load_kohya_ss_gui_config()
default_data_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "outputs")
default_train_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "data")
model_checkpoints = list(list_files(default_data_dir, exts=[".ckpt", ".safetensors"], all=True))
self.current_data_dir = default_data_dir
self.current_train_dir = default_train_dir
# Set default directories if not provided
self.current_models_dir = config.get(
"models_dir", os.path.join(scriptdir, "models")
)
self.current_train_data_dir = config.get(
"train_data_dir", os.path.join(scriptdir, "data")
)
model_checkpoints = list(
list_files(
self.current_models_dir, exts=[".ckpt", ".safetensors"], all=True
)
)
def list_models(path):
self.current_data_dir = path if os.path.isdir(path) else os.path.dirname(path)
return default_models + list(list_files(path, exts=[".ckpt", ".safetensors"], all=True))
self.current_models_dir = (
path if os.path.isdir(path) else os.path.dirname(path)
)
return default_models + list(
list_files(path, exts=[".ckpt", ".safetensors"], all=True)
)
def list_train_dirs(path):
self.current_train_dir = path if os.path.isdir(path) else os.path.dirname(path)
def list_train_data_dirs(path):
self.current_train_data_dir = path if not path == "" else "."
return list(list_dirs(path))
if default_data_dir is not None and default_data_dir.strip() != "" and not os.path.exists(default_data_dir):
os.makedirs(default_data_dir, exist_ok=True)
with gr.Column(), gr.Group():
# Define the input elements
with gr.Row():
with gr.Column(), gr.Row():
self.model_list = gr.Textbox(visible=False, value="")
self.pretrained_model_name_or_path = gr.Dropdown(
label='Pretrained model name or path',
choices=default_models + model_checkpoints,
value='runwayml/stable-diffusion-v1-5',
allow_custom_value=True,
visible=True,
min_width=100,
)
create_refresh_button(self.pretrained_model_name_or_path, lambda: None, lambda: {"choices": list_models(self.current_data_dir)},"open_folder_small")
with gr.Column(), gr.Row():
self.model_list = gr.Textbox(visible=False, value="")
self.pretrained_model_name_or_path = gr.Dropdown(
label="Pretrained model name or path",
choices=default_models + model_checkpoints,
value="runwayml/stable-diffusion-v1-5",
allow_custom_value=True,
visible=True,
min_width=100,
)
create_refresh_button(
self.pretrained_model_name_or_path,
lambda: None,
lambda: {"choices": list_models(self.current_models_dir)},
"open_folder_small",
)
self.pretrained_model_name_or_path_file = gr.Button(
document_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
visible=(not headless),
)
self.pretrained_model_name_or_path_file.click(
get_any_file_path,
inputs=self.pretrained_model_name_or_path,
outputs=self.pretrained_model_name_or_path,
show_progress=False,
)
self.pretrained_model_name_or_path_folder = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_classes=['tool'],
visible=(not headless),
)
self.pretrained_model_name_or_path_folder.click(
get_folder_path,
inputs=self.pretrained_model_name_or_path,
outputs=self.pretrained_model_name_or_path,
show_progress=False,
)
self.pretrained_model_name_or_path_file = gr.Button(
document_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
self.pretrained_model_name_or_path_file.click(
get_any_file_path,
inputs=self.pretrained_model_name_or_path,
outputs=self.pretrained_model_name_or_path,
show_progress=False,
)
self.pretrained_model_name_or_path_folder = gr.Button(
folder_symbol,
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not headless),
)
self.pretrained_model_name_or_path_folder.click(
get_folder_path,
inputs=self.pretrained_model_name_or_path,
outputs=self.pretrained_model_name_or_path,
show_progress=False,
)
with gr.Column(), gr.Row():
self.train_data_dir = gr.Dropdown(
label='Image folder (containing training images subfolders)' if not finetuning else 'Image folder (containing training images)',
choices=[""] + list_train_dirs(default_train_dir),
value="",
interactive=True,
allow_custom_value=True,
)
create_refresh_button(self.train_data_dir, lambda: None, lambda: {"choices": list_train_dirs(self.current_train_dir)}, "open_folder_small")
self.train_data_dir_folder = gr.Button(
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
)
self.train_data_dir_folder.click(
get_folder_path,
outputs=self.train_data_dir,
show_progress=False,
)
with gr.Column(), gr.Row():
self.train_data_dir = gr.Dropdown(
label=(
"Image folder (containing training images subfolders)"
if not finetuning
else "Image folder (containing training images)"
),
choices=[""] + list_train_data_dirs(self.current_train_data_dir),
value="",
interactive=True,
allow_custom_value=True,
)
create_refresh_button(
self.train_data_dir,
lambda: None,
lambda: {"choices": [""] + list_train_data_dirs(self.current_train_data_dir)},
"open_folder_small",
)
self.train_data_dir_folder = gr.Button(
"📂",
elem_id="open_folder_small",
elem_classes=["tool"],
visible=(not self.headless),
)
self.train_data_dir_folder.click(
get_folder_path,
outputs=self.train_data_dir,
show_progress=False,
)
with gr.Row():
with gr.Column():
with gr.Row():
self.v2 = gr.Checkbox(label='v2', value=False, visible=False, min_width=60)
self.v2 = gr.Checkbox(
label="v2", value=False, visible=False, min_width=60
)
self.v_parameterization = gr.Checkbox(
label='v_parameterization', value=False, visible=False, min_width=130,
label="v_parameterization",
value=False,
visible=False,
min_width=130,
)
self.sdxl_checkbox = gr.Checkbox(
label='SDXL', value=False, visible=False, min_width=60,
label="SDXL",
value=False,
visible=False,
min_width=60,
)
with gr.Column():
gr.Box(visible=False)
with gr.Row():
self.output_name = gr.Textbox(
label='Trained Model output name',
placeholder='(Name of the model to output)',
value='last',
label="Trained Model output name",
placeholder="(Name of the model to output)",
value="last",
interactive=True,
)
self.training_comment = gr.Textbox(
label='Training comment',
placeholder='(Optional) Add training comment to be included in metadata',
label="Training comment",
placeholder="(Optional) Add training comment to be included in metadata",
interactive=True,
)
@ -167,7 +203,9 @@ class SourceModel:
)
self.pretrained_model_name_or_path.change(
fn=lambda path: set_pretrained_model_name_or_path_input(path, refresh_method=list_models),
fn=lambda path: set_pretrained_model_name_or_path_input(
path, refresh_method=list_models
),
inputs=[
self.pretrained_model_name_or_path,
],
@ -181,7 +219,7 @@ class SourceModel:
)
self.train_data_dir.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_train_dirs(path)),
fn=lambda path: gr.Dropdown(choices=[""] + list_train_data_dirs(path)),
inputs=self.train_data_dir,
outputs=self.train_data_dir,
show_progress=False,

View File

@ -9,6 +9,7 @@ import gradio as gr
import shutil
import sys
import json
import toml
# Set up logging
log = setup_logging()
@ -55,6 +56,22 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDX
ENV_EXCLUSION = ["COLAB_GPU", "RUNPOD_POD_ID"]
def load_kohya_ss_gui_config() -> dict:
"""
Loads the Kohya SS GUI configuration from a TOML file.
Returns:
dict: The configuration data loaded from the TOML file.
"""
try:
# Attempt to load the TOML configuration file from the specified directory.
config = toml.load(fr"{scriptdir}/config.toml")
except FileNotFoundError:
# If the config file is not found, initialize `config` as an empty dictionary to handle missing configurations gracefully.
config = {}
return config
def check_if_model_exist(
output_name: str, output_dir: str, save_model_as: str, headless: bool = False
) -> bool:

0
models/.keep Normal file
View File