# Mirror of https://github.com/bmaltais/kohya_ss (Python source, 174 lines, 6.9 KiB).
import gradio as gr
|
|
import os
|
|
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button, load_kohya_ss_gui_config
|
|
|
|
class Folders:
    """
    Build and manage the folder-selection widgets of the training GUI.

    Three editable dropdowns are created: the output directory, the
    regularisation (or, in finetune mode, train-config) directory and the
    logging directory. Each dropdown gets a refresh button and a folder-picker
    button, and re-lists sub-directories whenever its value changes.
    """

    def __init__(self, finetune: bool = False, headless: bool = False):
        """
        Initialize the Folders class and build its GUI.

        Parameters:
        - finetune (bool): When True, the second dropdown is labelled as a
          train-config directory instead of a regularisation-image directory.
        - headless (bool): When True, the folder-picker buttons are hidden.
        """
        self.headless = headless
        self.finetune = finetune

        # Load kohya_ss GUI configs from config.toml if it exists.
        config = load_kohya_ss_gui_config()

        # Use the configured directories, falling back to sub-directories of
        # `scriptdir` when a key is missing from the config.
        self.current_output_dir = config.get('output_dir', os.path.join(scriptdir, "outputs"))
        self.current_logging_dir = config.get('logging_dir', os.path.join(scriptdir, "logs"))
        self.current_reg_data_dir = config.get('reg_data_dir', os.path.join(scriptdir, "reg"))

        # Only the output and logging directories are created up front; the
        # regularisation directory is left untouched.
        self.create_directory_if_not_exists(self.current_output_dir)
        self.create_directory_if_not_exists(self.current_logging_dir)

        # Create the GUI for folder selection.
        self.create_folders_gui()

    def create_directory_if_not_exists(self, directory: str) -> None:
        """
        Create ``directory`` (including missing parents) if nothing exists there yet.

        Parameters:
        - directory (str): The directory to create. ``None``, empty and
          whitespace-only values are ignored. If any filesystem entry (directory
          or file) already exists at that path, nothing is done.
        """
        if directory is not None and directory.strip() != "" and not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)

    @staticmethod
    def _normalize_dir(path):
        """
        Return the value to remember as the "current" directory for ``path``.

        An empty string is remembered as "." (the current working directory);
        any other value, including ``None``, is returned unchanged.
        """
        return "." if path == "" else path

    def list_output_dirs(self, path: str) -> list:
        """
        Remember ``path`` as the current output directory and list its directories.

        Parameters:
        - path (str): The path to list directories from.

        Returns:
        - list: The entries produced by ``list_dirs(path)``.
        """
        # NOTE(review): the remembered value is normalised ("" -> ".") but the
        # listing uses the raw `path`; confirm this asymmetry is intended.
        self.current_output_dir = self._normalize_dir(path)
        return list(list_dirs(path))

    def list_logging_dirs(self, path: str) -> list:
        """
        Remember ``path`` as the current logging directory and list its directories.

        Parameters:
        - path (str): The path to list directories from.

        Returns:
        - list: The entries produced by ``list_dirs(path)``.
        """
        self.current_logging_dir = self._normalize_dir(path)
        return list(list_dirs(path))

    def list_reg_data_dirs(self, path: str) -> list:
        """
        Remember ``path`` as the current regularisation-data directory and list its directories.

        Parameters:
        - path (str): The path to list directories from.

        Returns:
        - list: The entries produced by ``list_dirs(path)``.
        """
        self.current_reg_data_dir = self._normalize_dir(path)
        return list(list_dirs(path))

    def _add_dir_selector(self, label: str, list_fn, current_attr: str):
        """
        Add a directory dropdown, its refresh button and a folder-picker button
        to the current Gradio layout context.

        Parameters:
        - label (str): Label shown on the dropdown.
        - list_fn (callable): One of the ``list_*_dirs`` methods; called with a
          path and returning the directory choices.
        - current_attr (str): Name of the attribute holding the directory that
          the dropdown is initially populated from and that the refresh button
          re-lists (read at click time, not captured now).

        Returns:
        - tuple: ``(dropdown, folder_button)``.
        """
        dropdown = gr.Dropdown(
            label=label,
            choices=[""] + list_fn(getattr(self, current_attr)),
            value="",
            interactive=True,
            allow_custom_value=True,
        )
        # Refresh button: re-list whatever directory is currently remembered.
        create_refresh_button(
            dropdown,
            lambda: None,
            lambda: {"choices": [""] + list_fn(getattr(self, current_attr))},
            "open_folder_small",
        )
        # Folder-picker button; hidden in headless mode.
        folder_button = gr.Button(
            '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
        )
        folder_button.click(
            get_folder_path,
            outputs=dropdown,
            show_progress=False,
        )
        return dropdown, folder_button

    @staticmethod
    def _relist_on_change(dropdown, list_fn) -> None:
        """
        Rebuild ``dropdown``'s choices from ``list_fn`` whenever its value changes.

        Parameters:
        - dropdown: The Gradio dropdown to watch and update.
        - list_fn (callable): Called with the new dropdown value; returns the choices.
        """
        dropdown.change(
            fn=lambda path: gr.Dropdown(choices=[""] + list_fn(path)),
            inputs=dropdown,
            outputs=dropdown,
            show_progress=False,
        )

    def create_folders_gui(self) -> None:
        """
        Create the GUI for folder selection.

        The output and regularisation/train-config selectors share the first
        row, the logging selector sits on a second row, and each dropdown is
        then wired to re-list sub-directories when its value changes.
        """
        if self.finetune:
            reg_label = 'Train config directory (Optional. where config files will be saved)'
        else:
            reg_label = 'Regularisation directory (Optional. containing regularisation images)'

        with gr.Row():
            self.output_dir, self.output_dir_folder = self._add_dir_selector(
                'Output directory for trained model',
                self.list_output_dirs,
                "current_output_dir",
            )
            self.reg_data_dir, self.reg_data_dir_folder = self._add_dir_selector(
                reg_label,
                self.list_reg_data_dirs,
                "current_reg_data_dir",
            )

        with gr.Row():
            self.logging_dir, self.logging_dir_folder = self._add_dir_selector(
                'Logging directory (Optional. to enable logging and output Tensorboard log)',
                self.list_logging_dirs,
                "current_logging_dir",
            )

        # Change events are registered only after every widget exists:
        # output, then regularisation/train-config, then logging.
        self._relist_on_change(self.output_dir, self.list_output_dirs)
        self._relist_on_change(self.reg_data_dir, self.list_reg_data_dirs)
        self._relist_on_change(self.logging_dir, self.list_logging_dirs)