mirror of https://github.com/bmaltais/kohya_ss
Add support for user managed path config
parent
0e4582cd52
commit
6386a720b8
|
|
@ -51,3 +51,5 @@ outputs
|
|||
dataset/**
|
||||
!dataset/**/
|
||||
!dataset/**/.gitkeep
|
||||
# models
|
||||
# data
|
||||
|
|
|
|||
|
|
@ -373,6 +373,7 @@ The documentation in this section will be moved to a separate document later.
|
|||
- Add support for `wandb_run_name`, `log_tracker_name` and `log_tracker_config` parameters under the advanced section.
|
||||
- Update sd-scripts to v0.8.5
|
||||
- Improve code
|
||||
- Add support for custom path defaults. Simply copy the `config example.toml` file found in the root of the repo to `config.toml` and edit the different values to your taste.
|
||||
|
||||
### 2024/03/13 (v23.0.11)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
# Copy this file and name it config.toml
|
||||
# Edit the values to suit your needs
|
||||
|
||||
# Default folders location
|
||||
models_dir = "./models" # Pretrained model name or path
|
||||
train_data_dir = "./data" # Image folder (containing training images subfolders) / Image folder (containing training images)
|
||||
output_dir = "./outputs" # Output directory for trained model
|
||||
reg_data_dir = "./data/reg" # Regularisation directory
|
||||
logging_dir = "./logs" # Logging directory
|
||||
config_dir = "./presets" # Load/Save Config file
|
||||
|
||||
# Example custom folder location
|
||||
# models_dir = "e:/models" # Pretrained model name or path
|
||||
|
|
@ -1,7 +1,11 @@
|
|||
import gradio as gr
|
||||
import os
|
||||
import toml
|
||||
from .common_gui import list_files, scriptdir, create_refresh_button, load_kohya_ss_gui_config
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
from .common_gui import list_files, scriptdir, create_refresh_button
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
||||
|
||||
class ConfigurationFile:
|
||||
|
|
@ -19,11 +23,11 @@ class ConfigurationFile:
|
|||
"""
|
||||
|
||||
self.headless = headless
|
||||
|
||||
config = load_kohya_ss_gui_config()
|
||||
|
||||
# Sets the directory for storing configuration files, defaults to a 'presets' folder within the script directory.
|
||||
self.current_config_dir = (
|
||||
config_dir if config_dir is not None else os.path.join(scriptdir, "presets")
|
||||
)
|
||||
self.current_config_dir = config.get('config_dir', os.path.join(scriptdir, "presets"))
|
||||
|
||||
# Initialize the GUI components for configuration.
|
||||
self.create_config_gui()
|
||||
|
|
@ -38,7 +42,7 @@ class ConfigurationFile:
|
|||
Returns:
|
||||
- list: A list of directories.
|
||||
"""
|
||||
self.current_config_dir = path
|
||||
self.current_config_dir = path if not path == "" else "."
|
||||
# Lists all .json files in the current configuration directory, used for populating dropdown choices.
|
||||
return list(list_files(self.current_config_dir, exts=[".json"], all=True))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,31 +1,29 @@
|
|||
import gradio as gr
|
||||
import os
|
||||
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button
|
||||
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button, load_kohya_ss_gui_config
|
||||
|
||||
class Folders:
|
||||
"""
|
||||
A class to handle folder operations in the GUI.
|
||||
"""
|
||||
def __init__(self, finetune: bool = False, data_dir: str = None, output_dir: str = None, logging_dir: str = None, reg_data_dir: str = None, headless: bool = False):
|
||||
def __init__(self, finetune: bool = False, headless: bool = False):
|
||||
"""
|
||||
Initialize the Folders class.
|
||||
|
||||
Parameters:
|
||||
- finetune (bool): Whether to finetune the model.
|
||||
- data_dir (str): The directory for data.
|
||||
- output_dir (str): The directory for output.
|
||||
- logging_dir (str): The directory for logging.
|
||||
- reg_data_dir (str): The directory for regularization data.
|
||||
- headless (bool): Whether to run in headless mode.
|
||||
"""
|
||||
self.headless = headless
|
||||
self.finetune = finetune
|
||||
|
||||
# Load kohya_ss GUI configs from config.toml if it exist
|
||||
config = load_kohya_ss_gui_config()
|
||||
|
||||
# Set default directories if not provided
|
||||
self.current_data_dir = data_dir if data_dir is not None else os.path.join(scriptdir, "data")
|
||||
self.current_output_dir = output_dir if output_dir is not None else os.path.join(scriptdir, "outputs")
|
||||
self.current_logging_dir = logging_dir if logging_dir is not None else os.path.join(scriptdir, "logs")
|
||||
self.current_reg_data_dir = reg_data_dir if reg_data_dir is not None else os.path.join(scriptdir, "reg")
|
||||
self.current_output_dir = config.get('output_dir', os.path.join(scriptdir, "outputs"))
|
||||
self.current_logging_dir = config.get('logging_dir', os.path.join(scriptdir, "logs"))
|
||||
self.current_reg_data_dir = config.get('reg_data_dir', os.path.join(scriptdir, "reg"))
|
||||
|
||||
# Create directories if they don't exist
|
||||
self.create_directory_if_not_exists(self.current_output_dir)
|
||||
|
|
@ -44,18 +42,6 @@ class Folders:
|
|||
if directory is not None and directory.strip() != "" and not os.path.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
def list_data_dirs(self, path: str) -> list:
|
||||
"""
|
||||
List directories in the data directory.
|
||||
|
||||
Parameters:
|
||||
- path (str): The path to list directories from.
|
||||
|
||||
Returns:
|
||||
- list: A list of directories.
|
||||
"""
|
||||
self.current_data_dir = path
|
||||
return list(list_dirs(path))
|
||||
|
||||
def list_output_dirs(self, path: str) -> list:
|
||||
"""
|
||||
|
|
@ -67,7 +53,7 @@ class Folders:
|
|||
Returns:
|
||||
- list: A list of directories.
|
||||
"""
|
||||
self.current_output_dir = path
|
||||
self.current_output_dir = path if not path == "" else "."
|
||||
return list(list_dirs(path))
|
||||
|
||||
def list_logging_dirs(self, path: str) -> list:
|
||||
|
|
@ -80,7 +66,7 @@ class Folders:
|
|||
Returns:
|
||||
- list: A list of directories.
|
||||
"""
|
||||
self.current_logging_dir = path
|
||||
self.current_logging_dir = path if not path == "" else "."
|
||||
return list(list_dirs(path))
|
||||
|
||||
def list_reg_data_dirs(self, path: str) -> list:
|
||||
|
|
@ -93,7 +79,7 @@ class Folders:
|
|||
Returns:
|
||||
- list: A list of directories.
|
||||
"""
|
||||
self.current_reg_data_dir = path
|
||||
self.current_reg_data_dir = path if not path == "" else "."
|
||||
return list(list_dirs(path))
|
||||
|
||||
def create_folders_gui(self) -> None:
|
||||
|
|
@ -131,7 +117,7 @@ class Folders:
|
|||
allow_custom_value=True,
|
||||
)
|
||||
# Refresh button for regularisation directory
|
||||
create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_data_dirs(self.current_data_dir)}, "open_folder_small")
|
||||
create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": [""] + self.list_reg_data_dirs(self.current_reg_data_dir)}, "open_folder_small")
|
||||
# Regularisation directory button
|
||||
self.reg_data_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
|
||||
|
|
@ -173,7 +159,7 @@ class Folders:
|
|||
)
|
||||
# Change event for regularisation directory dropdown
|
||||
self.reg_data_dir.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + self.list_data_dirs(path)),
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + self.list_reg_data_dirs(path)),
|
||||
inputs=self.reg_data_dir,
|
||||
outputs=self.reg_data_dir,
|
||||
show_progress=False,
|
||||
|
|
|
|||
|
|
@ -8,35 +8,38 @@ from .common_gui import (
|
|||
scriptdir,
|
||||
list_dirs,
|
||||
list_files,
|
||||
create_refresh_button,
|
||||
load_kohya_ss_gui_config,
|
||||
)
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
default_models = [
|
||||
'stabilityai/stable-diffusion-xl-base-1.0',
|
||||
'stabilityai/stable-diffusion-xl-refiner-1.0',
|
||||
'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
|
||||
'stabilityai/stable-diffusion-2-1-base',
|
||||
'stabilityai/stable-diffusion-2-base',
|
||||
'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned',
|
||||
'stabilityai/stable-diffusion-2-1',
|
||||
'stabilityai/stable-diffusion-2',
|
||||
'runwayml/stable-diffusion-v1-5',
|
||||
'CompVis/stable-diffusion-v1-4',
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
"stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"stabilityai/stable-diffusion-2-base",
|
||||
"stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-2",
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
]
|
||||
|
||||
|
||||
class SourceModel:
|
||||
def __init__(
|
||||
self,
|
||||
save_model_as_choices=[
|
||||
'same as source model',
|
||||
'ckpt',
|
||||
'diffusers',
|
||||
'diffusers_safetensors',
|
||||
'safetensors',
|
||||
"same as source model",
|
||||
"ckpt",
|
||||
"diffusers",
|
||||
"diffusers_safetensors",
|
||||
"safetensors",
|
||||
],
|
||||
save_precision_choices=[
|
||||
"float",
|
||||
|
|
@ -44,113 +47,146 @@ class SourceModel:
|
|||
"bf16",
|
||||
],
|
||||
headless=False,
|
||||
default_data_dir=None,
|
||||
finetuning=False,
|
||||
):
|
||||
self.headless = headless
|
||||
self.save_model_as_choices = save_model_as_choices
|
||||
self.finetuning = finetuning
|
||||
|
||||
from .common_gui import create_refresh_button
|
||||
config = load_kohya_ss_gui_config()
|
||||
|
||||
default_data_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "outputs")
|
||||
default_train_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "data")
|
||||
model_checkpoints = list(list_files(default_data_dir, exts=[".ckpt", ".safetensors"], all=True))
|
||||
self.current_data_dir = default_data_dir
|
||||
self.current_train_dir = default_train_dir
|
||||
# Set default directories if not provided
|
||||
self.current_models_dir = config.get(
|
||||
"models_dir", os.path.join(scriptdir, "models")
|
||||
)
|
||||
self.current_train_data_dir = config.get(
|
||||
"train_data_dir", os.path.join(scriptdir, "data")
|
||||
)
|
||||
|
||||
model_checkpoints = list(
|
||||
list_files(
|
||||
self.current_models_dir, exts=[".ckpt", ".safetensors"], all=True
|
||||
)
|
||||
)
|
||||
|
||||
def list_models(path):
|
||||
self.current_data_dir = path if os.path.isdir(path) else os.path.dirname(path)
|
||||
return default_models + list(list_files(path, exts=[".ckpt", ".safetensors"], all=True))
|
||||
self.current_models_dir = (
|
||||
path if os.path.isdir(path) else os.path.dirname(path)
|
||||
)
|
||||
return default_models + list(
|
||||
list_files(path, exts=[".ckpt", ".safetensors"], all=True)
|
||||
)
|
||||
|
||||
def list_train_dirs(path):
|
||||
self.current_train_dir = path if os.path.isdir(path) else os.path.dirname(path)
|
||||
def list_train_data_dirs(path):
|
||||
self.current_train_data_dir = path if not path == "" else "."
|
||||
return list(list_dirs(path))
|
||||
|
||||
if default_data_dir is not None and default_data_dir.strip() != "" and not os.path.exists(default_data_dir):
|
||||
os.makedirs(default_data_dir, exist_ok=True)
|
||||
|
||||
with gr.Column(), gr.Group():
|
||||
# Define the input elements
|
||||
with gr.Row():
|
||||
with gr.Column(), gr.Row():
|
||||
self.model_list = gr.Textbox(visible=False, value="")
|
||||
self.pretrained_model_name_or_path = gr.Dropdown(
|
||||
label='Pretrained model name or path',
|
||||
choices=default_models + model_checkpoints,
|
||||
value='runwayml/stable-diffusion-v1-5',
|
||||
allow_custom_value=True,
|
||||
visible=True,
|
||||
min_width=100,
|
||||
)
|
||||
create_refresh_button(self.pretrained_model_name_or_path, lambda: None, lambda: {"choices": list_models(self.current_data_dir)},"open_folder_small")
|
||||
with gr.Column(), gr.Row():
|
||||
self.model_list = gr.Textbox(visible=False, value="")
|
||||
self.pretrained_model_name_or_path = gr.Dropdown(
|
||||
label="Pretrained model name or path",
|
||||
choices=default_models + model_checkpoints,
|
||||
value="runwayml/stable-diffusion-v1-5",
|
||||
allow_custom_value=True,
|
||||
visible=True,
|
||||
min_width=100,
|
||||
)
|
||||
create_refresh_button(
|
||||
self.pretrained_model_name_or_path,
|
||||
lambda: None,
|
||||
lambda: {"choices": list_models(self.current_models_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
|
||||
self.pretrained_model_name_or_path_file = gr.Button(
|
||||
document_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
visible=(not headless),
|
||||
)
|
||||
self.pretrained_model_name_or_path_file.click(
|
||||
get_any_file_path,
|
||||
inputs=self.pretrained_model_name_or_path,
|
||||
outputs=self.pretrained_model_name_or_path,
|
||||
show_progress=False,
|
||||
)
|
||||
self.pretrained_model_name_or_path_folder = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_classes=['tool'],
|
||||
visible=(not headless),
|
||||
)
|
||||
self.pretrained_model_name_or_path_folder.click(
|
||||
get_folder_path,
|
||||
inputs=self.pretrained_model_name_or_path,
|
||||
outputs=self.pretrained_model_name_or_path,
|
||||
show_progress=False,
|
||||
)
|
||||
self.pretrained_model_name_or_path_file = gr.Button(
|
||||
document_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
self.pretrained_model_name_or_path_file.click(
|
||||
get_any_file_path,
|
||||
inputs=self.pretrained_model_name_or_path,
|
||||
outputs=self.pretrained_model_name_or_path,
|
||||
show_progress=False,
|
||||
)
|
||||
self.pretrained_model_name_or_path_folder = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not headless),
|
||||
)
|
||||
self.pretrained_model_name_or_path_folder.click(
|
||||
get_folder_path,
|
||||
inputs=self.pretrained_model_name_or_path,
|
||||
outputs=self.pretrained_model_name_or_path,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Column(), gr.Row():
|
||||
self.train_data_dir = gr.Dropdown(
|
||||
label='Image folder (containing training images subfolders)' if not finetuning else 'Image folder (containing training images)',
|
||||
choices=[""] + list_train_dirs(default_train_dir),
|
||||
value="",
|
||||
interactive=True,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(self.train_data_dir, lambda: None, lambda: {"choices": list_train_dirs(self.current_train_dir)}, "open_folder_small")
|
||||
self.train_data_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
|
||||
)
|
||||
self.train_data_dir_folder.click(
|
||||
get_folder_path,
|
||||
outputs=self.train_data_dir,
|
||||
show_progress=False,
|
||||
)
|
||||
with gr.Column(), gr.Row():
|
||||
self.train_data_dir = gr.Dropdown(
|
||||
label=(
|
||||
"Image folder (containing training images subfolders)"
|
||||
if not finetuning
|
||||
else "Image folder (containing training images)"
|
||||
),
|
||||
choices=[""] + list_train_data_dirs(self.current_train_data_dir),
|
||||
value="",
|
||||
interactive=True,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
create_refresh_button(
|
||||
self.train_data_dir,
|
||||
lambda: None,
|
||||
lambda: {"choices": [""] + list_train_data_dirs(self.current_train_data_dir)},
|
||||
"open_folder_small",
|
||||
)
|
||||
self.train_data_dir_folder = gr.Button(
|
||||
"📂",
|
||||
elem_id="open_folder_small",
|
||||
elem_classes=["tool"],
|
||||
visible=(not self.headless),
|
||||
)
|
||||
self.train_data_dir_folder.click(
|
||||
get_folder_path,
|
||||
outputs=self.train_data_dir,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
self.v2 = gr.Checkbox(label='v2', value=False, visible=False, min_width=60)
|
||||
self.v2 = gr.Checkbox(
|
||||
label="v2", value=False, visible=False, min_width=60
|
||||
)
|
||||
self.v_parameterization = gr.Checkbox(
|
||||
label='v_parameterization', value=False, visible=False, min_width=130,
|
||||
label="v_parameterization",
|
||||
value=False,
|
||||
visible=False,
|
||||
min_width=130,
|
||||
)
|
||||
self.sdxl_checkbox = gr.Checkbox(
|
||||
label='SDXL', value=False, visible=False, min_width=60,
|
||||
label="SDXL",
|
||||
value=False,
|
||||
visible=False,
|
||||
min_width=60,
|
||||
)
|
||||
with gr.Column():
|
||||
gr.Box(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
self.output_name = gr.Textbox(
|
||||
label='Trained Model output name',
|
||||
placeholder='(Name of the model to output)',
|
||||
value='last',
|
||||
label="Trained Model output name",
|
||||
placeholder="(Name of the model to output)",
|
||||
value="last",
|
||||
interactive=True,
|
||||
)
|
||||
self.training_comment = gr.Textbox(
|
||||
label='Training comment',
|
||||
placeholder='(Optional) Add training comment to be included in metadata',
|
||||
label="Training comment",
|
||||
placeholder="(Optional) Add training comment to be included in metadata",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
|
|
@ -167,7 +203,9 @@ class SourceModel:
|
|||
)
|
||||
|
||||
self.pretrained_model_name_or_path.change(
|
||||
fn=lambda path: set_pretrained_model_name_or_path_input(path, refresh_method=list_models),
|
||||
fn=lambda path: set_pretrained_model_name_or_path_input(
|
||||
path, refresh_method=list_models
|
||||
),
|
||||
inputs=[
|
||||
self.pretrained_model_name_or_path,
|
||||
],
|
||||
|
|
@ -181,7 +219,7 @@ class SourceModel:
|
|||
)
|
||||
|
||||
self.train_data_dir.change(
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_train_dirs(path)),
|
||||
fn=lambda path: gr.Dropdown(choices=[""] + list_train_data_dirs(path)),
|
||||
inputs=self.train_data_dir,
|
||||
outputs=self.train_data_dir,
|
||||
show_progress=False,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import gradio as gr
|
|||
import shutil
|
||||
import sys
|
||||
import json
|
||||
import toml
|
||||
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
|
|
@ -55,6 +56,22 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS + SDX
|
|||
ENV_EXCLUSION = ["COLAB_GPU", "RUNPOD_POD_ID"]
|
||||
|
||||
|
||||
def load_kohya_ss_gui_config() -> dict:
|
||||
"""
|
||||
Loads the Kohya SS GUI configuration from a TOML file.
|
||||
|
||||
Returns:
|
||||
dict: The configuration data loaded from the TOML file.
|
||||
"""
|
||||
try:
|
||||
# Attempt to load the TOML configuration file from the specified directory.
|
||||
config = toml.load(fr"{scriptdir}/config.toml")
|
||||
except FileNotFoundError:
|
||||
# If the config file is not found, initialize `config` as an empty dictionary to handle missing configurations gracefully.
|
||||
config = {}
|
||||
|
||||
return config
|
||||
|
||||
def check_if_model_exist(
|
||||
output_name: str, output_dir: str, save_model_as: str, headless: bool = False
|
||||
) -> bool:
|
||||
|
|
|
|||
Loading…
Reference in New Issue