mirror of https://github.com/bmaltais/kohya_ss
Implement --headless mode
parent
30386a704b
commit
103a9919c7
13
README.md
13
README.md
|
|
@ -331,6 +331,19 @@ This will store a backup file with your current locally installed pip packages a
|
|||
|
||||
## Change History
|
||||
|
||||
* 2023/04/06 (v21.5.9)
|
||||
- Inplement headless mode to enable easier support under headless services like vast.ai. To make use of it start the gui with the `--headless` argument like:
|
||||
|
||||
`.\gui.ps1 --headless` or `.\gui.bat --headless` or `./gui.sh --headless`
|
||||
- Added the option for the user to put the wandb api key in a textbox under the advanced configuration dropdown and a checkbox to toggle for using wandb logging. @x-CK-x
|
||||
- Docker build image @Trojaner
|
||||
- Updated README to use docker compose run instead of docker compose up to fix broken tqdm
|
||||
- Related: Doesn't work with docker-compose tqdm/tqdm#771
|
||||
- Fixed build for latest release
|
||||
- Replace pillow with pillow-simd
|
||||
- Removed --no-cache again as pip cache is not enabled anyway
|
||||
- While overwriting .txt files with prefix and postfix including different encodings you might encounter this decoder error. This small fix gets rid of it... @ertugrul-dmr
|
||||
- Docker Add --no-cache-dir to reduce image size @chiragjn
|
||||
* 2023/04/05 (v21.5.8)
|
||||
- Add `Cache latents to disk` option to the gui.
|
||||
- When saving v2 models in Diffusers format in training scripts and conversion scripts, it was found that the U-Net configuration is different from those of Hugging Face's stabilityai models (this repository is `"use_linear_projection": false`, stabilityai is `true`). Please note that the weight shapes are different, so please be careful when using the weight files directly. We apologize for the inconvenience.
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from library.common_gui import (
|
|||
# set_legacy_8bitadam,
|
||||
update_my_data,
|
||||
check_if_model_exist,
|
||||
output_message,
|
||||
)
|
||||
from library.tensorboard_gui import (
|
||||
gradio_tensorboard,
|
||||
|
|
@ -38,7 +39,7 @@ from library.dreambooth_folder_creation_gui import (
|
|||
)
|
||||
from library.utilities import utilities_tab
|
||||
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
||||
from easygui import msgbox
|
||||
# from easygui import msgbox
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
|
|
@ -267,6 +268,7 @@ def open_configuration(
|
|||
|
||||
|
||||
def train_model(
|
||||
headless,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization,
|
||||
|
|
@ -337,39 +339,40 @@ def train_model(
|
|||
use_wandb,
|
||||
wandb_api_key,
|
||||
):
|
||||
headless_bool = True if headless.get('label') == 'True' else False
|
||||
|
||||
if pretrained_model_name_or_path == '':
|
||||
msgbox('Source model information is missing')
|
||||
output_message(msg='Source model information is missing', headless=headless_bool)
|
||||
return
|
||||
|
||||
if train_data_dir == '':
|
||||
msgbox('Image folder path is missing')
|
||||
output_message(msg='Image folder path is missing', headless=headless_bool)
|
||||
return
|
||||
|
||||
if not os.path.exists(train_data_dir):
|
||||
msgbox('Image folder does not exist')
|
||||
output_message(msg='Image folder does not exist', headless=headless_bool)
|
||||
return
|
||||
|
||||
if reg_data_dir != '':
|
||||
if not os.path.exists(reg_data_dir):
|
||||
msgbox('Regularisation folder does not exist')
|
||||
output_message(msg='Regularisation folder does not exist', headless=headless_bool)
|
||||
return
|
||||
|
||||
if output_dir == '':
|
||||
msgbox('Output folder path is missing')
|
||||
output_message(msg='Output folder path is missing', headless=headless_bool)
|
||||
return
|
||||
|
||||
if check_if_model_exist(output_name, output_dir, save_model_as):
|
||||
if check_if_model_exist(output_name, output_dir, save_model_as, headless=headless_bool):
|
||||
return
|
||||
|
||||
if optimizer == 'Adafactor' and lr_warmup != '0':
|
||||
msgbox(
|
||||
"Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
|
||||
title='Warning',
|
||||
output_message(msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
|
||||
title='Warning', headless=headless_bool
|
||||
)
|
||||
lr_warmup = '0'
|
||||
|
||||
if float(noise_offset) > 0 and (multires_noise_iterations > 0 or multires_noise_discount > 0):
|
||||
msgbox(msg='noise offset and multires_noise can\'t be set at the same time. Only use one or the other.', title='Error')
|
||||
output_message(msg='noise offset and multires_noise can\'t be set at the same time. Only use one or the other.', title='Error', headless=headless_bool)
|
||||
return
|
||||
|
||||
# Get a list of all subfolders in train_data_dir, excluding hidden folders
|
||||
|
|
@ -599,9 +602,11 @@ def dreambooth_tab(
|
|||
reg_data_dir=gr.Textbox(),
|
||||
output_dir=gr.Textbox(),
|
||||
logging_dir=gr.Textbox(),
|
||||
headless=False
|
||||
):
|
||||
dummy_db_true = gr.Label(value=True, visible=False)
|
||||
dummy_db_false = gr.Label(value=False, visible=False)
|
||||
dummy_headless = gr.Label(value=headless, visible=False)
|
||||
gr.Markdown('Train a custom model using kohya dreambooth python code...')
|
||||
(
|
||||
button_open_config,
|
||||
|
|
@ -609,7 +614,7 @@ def dreambooth_tab(
|
|||
button_save_as_config,
|
||||
config_file_name,
|
||||
button_load_config,
|
||||
) = gradio_config()
|
||||
) = gradio_config(headless=headless)
|
||||
|
||||
(
|
||||
pretrained_model_name_or_path,
|
||||
|
|
@ -617,7 +622,7 @@ def dreambooth_tab(
|
|||
v_parameterization,
|
||||
save_model_as,
|
||||
model_list,
|
||||
) = gradio_source_model()
|
||||
) = gradio_source_model(headless=headless)
|
||||
|
||||
with gr.Tab('Folders'):
|
||||
with gr.Row():
|
||||
|
|
@ -626,7 +631,7 @@ def dreambooth_tab(
|
|||
placeholder='Folder where the training folders containing the images are located',
|
||||
)
|
||||
train_data_dir_input_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
train_data_dir_input_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -638,7 +643,7 @@ def dreambooth_tab(
|
|||
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
||||
)
|
||||
reg_data_dir_input_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
reg_data_dir_input_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -651,7 +656,7 @@ def dreambooth_tab(
|
|||
placeholder='Folder to output trained model',
|
||||
)
|
||||
output_dir_input_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
output_dir_input_folder.click(get_folder_path, outputs=output_dir)
|
||||
logging_dir = gr.Textbox(
|
||||
|
|
@ -659,7 +664,7 @@ def dreambooth_tab(
|
|||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||
)
|
||||
logging_dir_input_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
logging_dir_input_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -748,7 +753,7 @@ def dreambooth_tab(
|
|||
label='VAE',
|
||||
placeholder='(Optiona) path to checkpoint of vae to replace for training',
|
||||
)
|
||||
vae_button = gr.Button('📂', elem_id='open_folder_small')
|
||||
vae_button = gr.Button('📂', elem_id='open_folder_small', visible=(not headless))
|
||||
vae_button.click(
|
||||
get_any_file_path,
|
||||
outputs=vae,
|
||||
|
|
@ -787,7 +792,7 @@ def dreambooth_tab(
|
|||
save_last_n_steps_state,
|
||||
use_wandb,
|
||||
wandb_api_key,
|
||||
) = gradio_advanced_training()
|
||||
) = gradio_advanced_training(headless=headless)
|
||||
color_aug.change(
|
||||
color_aug_changed,
|
||||
inputs=[color_aug],
|
||||
|
|
@ -810,6 +815,7 @@ def dreambooth_tab(
|
|||
reg_data_dir_input=reg_data_dir,
|
||||
output_dir_input=output_dir,
|
||||
logging_dir_input=logging_dir,
|
||||
headless=headless,
|
||||
)
|
||||
|
||||
button_run = gr.Button('Train model', variant='primary')
|
||||
|
|
@ -930,7 +936,7 @@ def dreambooth_tab(
|
|||
|
||||
button_run.click(
|
||||
train_model,
|
||||
inputs=settings_list,
|
||||
inputs=[dummy_headless] + settings_list,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
|
@ -944,13 +950,18 @@ def dreambooth_tab(
|
|||
|
||||
def UI(**kwargs):
|
||||
css = ''
|
||||
|
||||
headless = kwargs.get('headless', False)
|
||||
print(f'headless: {headless}')
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
print('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
interface = gr.Blocks(css=css)
|
||||
interface = gr.Blocks(
|
||||
css=css, title='Kohya_ss GUI', theme=gr.themes.Default()
|
||||
)
|
||||
|
||||
with interface:
|
||||
with gr.Tab('Dreambooth'):
|
||||
|
|
@ -959,7 +970,7 @@ def UI(**kwargs):
|
|||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = dreambooth_tab()
|
||||
) = dreambooth_tab(headless=headless)
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
|
|
@ -967,26 +978,39 @@ def UI(**kwargs):
|
|||
output_dir_input=output_dir_input,
|
||||
logging_dir_input=logging_dir_input,
|
||||
enable_copy_info_button=True,
|
||||
headless=headless
|
||||
)
|
||||
|
||||
# Show the interface
|
||||
launch_kwargs = {}
|
||||
if not kwargs.get('username', None) == '':
|
||||
launch_kwargs['auth'] = (
|
||||
kwargs.get('username', None),
|
||||
kwargs.get('password', None),
|
||||
)
|
||||
if kwargs.get('server_port', 0) > 0:
|
||||
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
||||
if kwargs.get('inbrowser', False):
|
||||
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
||||
print(launch_kwargs)
|
||||
username = kwargs.get('username')
|
||||
password = kwargs.get('password')
|
||||
server_port = kwargs.get('server_port', 0)
|
||||
inbrowser = kwargs.get('inbrowser', False)
|
||||
share = kwargs.get('share', False)
|
||||
server_name = kwargs.get('listen')
|
||||
|
||||
launch_kwargs['server_name'] = server_name
|
||||
if username and password:
|
||||
launch_kwargs['auth'] = (username, password)
|
||||
if server_port > 0:
|
||||
launch_kwargs['server_port'] = server_port
|
||||
if inbrowser:
|
||||
launch_kwargs['inbrowser'] = inbrowser
|
||||
if share:
|
||||
launch_kwargs['share'] = share
|
||||
interface.launch(**launch_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--listen',
|
||||
type=str,
|
||||
default='127.0.0.1',
|
||||
help='IP to listen on for connections to Gradio',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--username', type=str, default='', help='Username for authentication'
|
||||
)
|
||||
|
|
@ -1002,6 +1026,12 @@ if __name__ == '__main__':
|
|||
parser.add_argument(
|
||||
'--inbrowser', action='store_true', help='Open in browser'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--share', action='store_true', help='Share the gradio UI'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--headless', action='store_true', help='Is the server headless'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -1010,4 +1040,7 @@ if __name__ == '__main__':
|
|||
password=args.password,
|
||||
inbrowser=args.inbrowser,
|
||||
server_port=args.server_port,
|
||||
share=args.share,
|
||||
listen=args.listen,
|
||||
headless=args.headless,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from library.common_gui import (
|
|||
# set_legacy_8bitadam,
|
||||
update_my_data,
|
||||
check_if_model_exist,
|
||||
output_message,
|
||||
)
|
||||
from library.tensorboard_gui import (
|
||||
gradio_tensorboard,
|
||||
|
|
@ -29,7 +30,7 @@ from library.tensorboard_gui import (
|
|||
)
|
||||
from library.utilities import utilities_tab
|
||||
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
||||
from easygui import msgbox
|
||||
# from easygui import msgbox
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
|
|
@ -272,6 +273,7 @@ def open_configuration(
|
|||
|
||||
|
||||
def train_model(
|
||||
headless,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization,
|
||||
|
|
@ -348,17 +350,19 @@ def train_model(
|
|||
use_wandb,
|
||||
wandb_api_key,
|
||||
):
|
||||
if check_if_model_exist(output_name, output_dir, save_model_as):
|
||||
headless_bool = True if headless.get('label') == 'True' else False
|
||||
|
||||
if check_if_model_exist(output_name, output_dir, save_model_as, headless_bool):
|
||||
return
|
||||
|
||||
if float(noise_offset) > 0 and (multires_noise_iterations > 0 or multires_noise_discount > 0):
|
||||
msgbox(msg='noise offset and multires_noise can\'t be set at the same time. Only use one or the other.', title='Error')
|
||||
output_message(msg='noise offset and multires_noise can\'t be set at the same time. Only use one or the other.', title='Error', headless=headless_bool)
|
||||
return
|
||||
|
||||
if optimizer == 'Adafactor' and lr_warmup != '0':
|
||||
msgbox(
|
||||
"Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
|
||||
title='Warning',
|
||||
output_message(
|
||||
msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
|
||||
title='Warning', headless=headless_bool
|
||||
)
|
||||
lr_warmup = '0'
|
||||
|
||||
|
|
@ -562,9 +566,10 @@ def remove_doublequote(file_path):
|
|||
return file_path
|
||||
|
||||
|
||||
def finetune_tab():
|
||||
def finetune_tab(headless=False):
|
||||
dummy_db_true = gr.Label(value=True, visible=False)
|
||||
dummy_db_false = gr.Label(value=False, visible=False)
|
||||
dummy_headless = gr.Label(value=headless, visible=False)
|
||||
gr.Markdown('Train a custom model using kohya finetune python code...')
|
||||
|
||||
(
|
||||
|
|
@ -573,7 +578,7 @@ def finetune_tab():
|
|||
button_save_as_config,
|
||||
config_file_name,
|
||||
button_load_config,
|
||||
) = gradio_config()
|
||||
) = gradio_config(headless=headless)
|
||||
|
||||
(
|
||||
pretrained_model_name_or_path,
|
||||
|
|
@ -581,7 +586,7 @@ def finetune_tab():
|
|||
v_parameterization,
|
||||
save_model_as,
|
||||
model_list,
|
||||
) = gradio_source_model()
|
||||
) = gradio_source_model(headless=headless)
|
||||
|
||||
with gr.Tab('Folders'):
|
||||
with gr.Row():
|
||||
|
|
@ -590,7 +595,7 @@ def finetune_tab():
|
|||
placeholder='folder where the training configuration files will be saved',
|
||||
)
|
||||
train_dir_folder = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
train_dir_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -603,7 +608,7 @@ def finetune_tab():
|
|||
placeholder='folder where the training images are located',
|
||||
)
|
||||
image_folder_input_folder = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
image_folder_input_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -616,7 +621,7 @@ def finetune_tab():
|
|||
placeholder='folder where the model will be saved',
|
||||
)
|
||||
output_dir_input_folder = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
output_dir_input_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -629,7 +634,7 @@ def finetune_tab():
|
|||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||
)
|
||||
logging_dir_input_folder = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
logging_dir_input_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -759,7 +764,7 @@ def finetune_tab():
|
|||
save_last_n_steps_state,
|
||||
use_wandb,
|
||||
wandb_api_key,
|
||||
) = gradio_advanced_training()
|
||||
) = gradio_advanced_training(headless=headless)
|
||||
color_aug.change(
|
||||
color_aug_changed,
|
||||
inputs=[color_aug],
|
||||
|
|
@ -866,7 +871,7 @@ def finetune_tab():
|
|||
wandb_api_key,
|
||||
]
|
||||
|
||||
button_run.click(train_model, inputs=settings_list)
|
||||
button_run.click(train_model, inputs=[dummy_headless] + settings_list)
|
||||
|
||||
button_open_config.click(
|
||||
open_configuration,
|
||||
|
|
@ -898,40 +903,56 @@ def finetune_tab():
|
|||
|
||||
|
||||
def UI(**kwargs):
|
||||
|
||||
css = ''
|
||||
|
||||
headless = kwargs.get('headless', False)
|
||||
print(f'headless: {headless}')
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
print('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
interface = gr.Blocks(css=css)
|
||||
interface = gr.Blocks(
|
||||
css=css, title='Kohya_ss GUI', theme=gr.themes.Default()
|
||||
)
|
||||
|
||||
with interface:
|
||||
with gr.Tab('Finetune'):
|
||||
finetune_tab()
|
||||
finetune_tab(headless=headless)
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(enable_dreambooth_tab=False)
|
||||
utilities_tab(enable_dreambooth_tab=False, headless=headless)
|
||||
|
||||
# Show the interface
|
||||
launch_kwargs = {}
|
||||
if not kwargs.get('username', None) == '':
|
||||
launch_kwargs['auth'] = (
|
||||
kwargs.get('username', None),
|
||||
kwargs.get('password', None),
|
||||
)
|
||||
if kwargs.get('server_port', 0) > 0:
|
||||
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
||||
if kwargs.get('inbrowser', False):
|
||||
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
||||
print(launch_kwargs)
|
||||
username = kwargs.get('username')
|
||||
password = kwargs.get('password')
|
||||
server_port = kwargs.get('server_port', 0)
|
||||
inbrowser = kwargs.get('inbrowser', False)
|
||||
share = kwargs.get('share', False)
|
||||
server_name = kwargs.get('listen')
|
||||
|
||||
launch_kwargs['server_name'] = server_name
|
||||
if username and password:
|
||||
launch_kwargs['auth'] = (username, password)
|
||||
if server_port > 0:
|
||||
launch_kwargs['server_port'] = server_port
|
||||
if inbrowser:
|
||||
launch_kwargs['inbrowser'] = inbrowser
|
||||
if share:
|
||||
launch_kwargs['share'] = share
|
||||
interface.launch(**launch_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--listen',
|
||||
type=str,
|
||||
default='127.0.0.1',
|
||||
help='IP to listen on for connections to Gradio',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--username', type=str, default='', help='Username for authentication'
|
||||
)
|
||||
|
|
@ -947,6 +968,12 @@ if __name__ == '__main__':
|
|||
parser.add_argument(
|
||||
'--inbrowser', action='store_true', help='Open in browser'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--share', action='store_true', help='Share the gradio UI'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--headless', action='store_true', help='Is the server headless'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -955,4 +982,7 @@ if __name__ == '__main__':
|
|||
password=args.password,
|
||||
inbrowser=args.inbrowser,
|
||||
server_port=args.server_port,
|
||||
share=args.share,
|
||||
listen=args.listen,
|
||||
headless=args.headless,
|
||||
)
|
||||
|
|
|
|||
28
kohya_gui.py
28
kohya_gui.py
|
|
@ -82,6 +82,9 @@ def setup_logging(clean=False):
|
|||
|
||||
def UI(**kwargs):
|
||||
css = ''
|
||||
|
||||
headless = kwargs.get('headless', False)
|
||||
print(f'headless: {headless}')
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
|
|
@ -99,13 +102,13 @@ def UI(**kwargs):
|
|||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = dreambooth_tab()
|
||||
) = dreambooth_tab(headless=headless)
|
||||
with gr.Tab('Dreambooth LoRA'):
|
||||
lora_tab()
|
||||
lora_tab(headless=headless)
|
||||
with gr.Tab('Dreambooth TI'):
|
||||
ti_tab()
|
||||
ti_tab(headless=headless)
|
||||
with gr.Tab('Finetune'):
|
||||
finetune_tab()
|
||||
finetune_tab(headless=headless)
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
|
|
@ -113,13 +116,14 @@ def UI(**kwargs):
|
|||
output_dir_input=output_dir_input,
|
||||
logging_dir_input=logging_dir_input,
|
||||
enable_copy_info_button=True,
|
||||
headless=headless
|
||||
)
|
||||
gradio_extract_dylora_tab()
|
||||
gradio_extract_lora_tab()
|
||||
gradio_extract_lycoris_locon_tab()
|
||||
gradio_merge_lora_tab()
|
||||
gradio_merge_lycoris_tab()
|
||||
gradio_resize_lora_tab()
|
||||
gradio_extract_dylora_tab(headless=headless)
|
||||
gradio_extract_lora_tab(headless=headless)
|
||||
gradio_extract_lycoris_locon_tab(headless=headless)
|
||||
gradio_merge_lora_tab(headless=headless)
|
||||
gradio_merge_lycoris_tab(headless=headless)
|
||||
gradio_resize_lora_tab(headless=headless)
|
||||
|
||||
# Show the interface
|
||||
launch_kwargs = {}
|
||||
|
|
@ -169,6 +173,9 @@ if __name__ == '__main__':
|
|||
parser.add_argument(
|
||||
'--share', action='store_true', help='Share the gradio UI'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--headless', action='store_true', help='Is the server headless'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -179,4 +186,5 @@ if __name__ == '__main__':
|
|||
server_port=args.server_port,
|
||||
share=args.share,
|
||||
listen=args.listen,
|
||||
headless=args.headless,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ def caption_images(
|
|||
|
||||
|
||||
# Gradio UI
|
||||
def gradio_basic_caption_gui_tab():
|
||||
def gradio_basic_caption_gui_tab(headless=False):
|
||||
with gr.Tab('Basic Captioning'):
|
||||
gr.Markdown(
|
||||
'This utility will allow the creation of simple caption files for each image in a folder.'
|
||||
|
|
@ -79,7 +79,7 @@ def gradio_basic_caption_gui_tab():
|
|||
placeholder='Directory containing the images to caption',
|
||||
interactive=True,
|
||||
)
|
||||
folder_button = gr.Button('📂', elem_id='open_folder_small')
|
||||
folder_button = gr.Button('📂', elem_id='open_folder_small', visible=(not headless))
|
||||
folder_button.click(
|
||||
get_folder_path,
|
||||
outputs=images_dir,
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ def caption_images(
|
|||
###
|
||||
|
||||
|
||||
def gradio_blip_caption_gui_tab():
|
||||
def gradio_blip_caption_gui_tab(headless=False):
|
||||
with gr.Tab('BLIP Captioning'):
|
||||
gr.Markdown(
|
||||
'This utility will use BLIP to caption files for each images in a folder.'
|
||||
|
|
@ -83,7 +83,7 @@ def gradio_blip_caption_gui_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_train_data_dir_input = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_train_data_dir_input.click(
|
||||
get_folder_path,
|
||||
|
|
|
|||
|
|
@ -35,7 +35,11 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
|
|||
ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
|
||||
|
||||
|
||||
def check_if_model_exist(output_name, output_dir, save_model_as):
|
||||
def check_if_model_exist(output_name, output_dir, save_model_as, headless=False):
|
||||
if headless:
|
||||
print('Headless mode, skipping verification if model already exist... if model already exist it will be overwritten...')
|
||||
return False
|
||||
|
||||
if save_model_as in ['diffusers', 'diffusers_safetendors']:
|
||||
ckpt_folder = os.path.join(output_dir, output_name)
|
||||
if os.path.isdir(ckpt_folder):
|
||||
|
|
@ -62,6 +66,11 @@ def check_if_model_exist(output_name, output_dir, save_model_as):
|
|||
|
||||
return False
|
||||
|
||||
def output_message(msg='', title='', headless=False):
|
||||
if headless:
|
||||
print(msg)
|
||||
else:
|
||||
msgbox(msg=msg, title=title)
|
||||
|
||||
def update_my_data(my_data):
|
||||
# Update the optimizer based on the use_8bit_adam flag
|
||||
|
|
@ -529,13 +538,13 @@ def set_model_list(
|
|||
###
|
||||
|
||||
|
||||
def gradio_config():
|
||||
def gradio_config(headless=False):
|
||||
with gr.Accordion('Configuration file', open=False):
|
||||
with gr.Row():
|
||||
button_open_config = gr.Button('Open 📂', elem_id='open_folder')
|
||||
button_save_config = gr.Button('Save 💾', elem_id='open_folder')
|
||||
button_open_config = gr.Button('Open 📂', elem_id='open_folder', visible=(not headless))
|
||||
button_save_config = gr.Button('Save 💾', elem_id='open_folder', visible=(not headless))
|
||||
button_save_as_config = gr.Button(
|
||||
'Save as... 💾', elem_id='open_folder'
|
||||
'Save as... 💾', elem_id='open_folder', visible=(not headless)
|
||||
)
|
||||
config_file_name = gr.Textbox(
|
||||
label='',
|
||||
|
|
@ -573,7 +582,8 @@ def gradio_source_model(
|
|||
'diffusers',
|
||||
'diffusers_safetensors',
|
||||
'safetensors',
|
||||
]
|
||||
],
|
||||
headless=False
|
||||
):
|
||||
with gr.Tab('Source model'):
|
||||
# Define the input elements
|
||||
|
|
@ -584,7 +594,7 @@ def gradio_source_model(
|
|||
value='runwayml/stable-diffusion-v1-5',
|
||||
)
|
||||
pretrained_model_name_or_path_file = gr.Button(
|
||||
document_symbol, elem_id='open_folder_small'
|
||||
document_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
pretrained_model_name_or_path_file.click(
|
||||
get_any_file_path,
|
||||
|
|
@ -593,7 +603,7 @@ def gradio_source_model(
|
|||
show_progress=False,
|
||||
)
|
||||
pretrained_model_name_or_path_folder = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
pretrained_model_name_or_path_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -835,7 +845,7 @@ def run_cmd_training(**kwargs):
|
|||
return run_cmd
|
||||
|
||||
|
||||
def gradio_advanced_training():
|
||||
def gradio_advanced_training(headless=False):
|
||||
with gr.Row():
|
||||
additional_parameters = gr.Textbox(
|
||||
label='Additional parameters',
|
||||
|
|
@ -937,7 +947,7 @@ def gradio_advanced_training():
|
|||
label='Resume from saved training state',
|
||||
placeholder='path to "last-state" state folder to resume from',
|
||||
)
|
||||
resume_button = gr.Button('📂', elem_id='open_folder_small')
|
||||
resume_button = gr.Button('📂', elem_id='open_folder_small', visible=(not headless))
|
||||
resume_button.click(
|
||||
get_folder_path,
|
||||
outputs=resume,
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ def convert_model(
|
|||
###
|
||||
|
||||
|
||||
def gradio_convert_model_tab():
|
||||
def gradio_convert_model_tab(headless=False):
|
||||
with gr.Tab('Convert model'):
|
||||
gr.Markdown(
|
||||
'This utility can be used to convert from one stable diffusion model format to another.'
|
||||
|
|
@ -176,7 +176,7 @@ def gradio_convert_model_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_source_model_dir = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_source_model_dir.click(
|
||||
get_folder_path,
|
||||
|
|
@ -185,7 +185,7 @@ def gradio_convert_model_tab():
|
|||
)
|
||||
|
||||
button_source_model_file = gr.Button(
|
||||
document_symbol, elem_id='open_folder_small'
|
||||
document_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_source_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -212,7 +212,7 @@ def gradio_convert_model_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_target_model_folder = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_target_model_folder.click(
|
||||
get_folder_path,
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ def warning(insecure):
|
|||
return False
|
||||
|
||||
|
||||
def gradio_dataset_balancing_tab():
|
||||
def gradio_dataset_balancing_tab(headless=False):
|
||||
with gr.Tab('Dreambooth/LoRA Dataset balancing'):
|
||||
gr.Markdown(
|
||||
'This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.'
|
||||
|
|
@ -120,7 +120,7 @@ def gradio_dataset_balancing_tab():
|
|||
)
|
||||
|
||||
select_dataset_folder_button = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
select_dataset_folder_button.click(
|
||||
get_folder_path,
|
||||
|
|
|
|||
|
|
@ -114,6 +114,7 @@ def gradio_dreambooth_folder_creation_tab(
|
|||
reg_data_dir_input=gr.Textbox(),
|
||||
output_dir_input=gr.Textbox(),
|
||||
logging_dir_input=gr.Textbox(),
|
||||
headless=False
|
||||
):
|
||||
with gr.Tab('Dreambooth/LoRA Folder preparation'):
|
||||
gr.Markdown(
|
||||
|
|
@ -137,7 +138,7 @@ def gradio_dreambooth_folder_creation_tab(
|
|||
interactive=True,
|
||||
)
|
||||
button_util_training_images_dir_input = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_util_training_images_dir_input.click(
|
||||
get_folder_path,
|
||||
|
|
@ -157,7 +158,7 @@ def gradio_dreambooth_folder_creation_tab(
|
|||
interactive=True,
|
||||
)
|
||||
button_util_regularization_images_dir_input = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_util_regularization_images_dir_input.click(
|
||||
get_folder_path,
|
||||
|
|
@ -177,7 +178,7 @@ def gradio_dreambooth_folder_creation_tab(
|
|||
interactive=True,
|
||||
)
|
||||
button_util_training_dir_output = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_util_training_dir_output.click(
|
||||
get_folder_path, outputs=util_training_dir_output
|
||||
|
|
|
|||
|
|
@ -52,10 +52,10 @@ def extract_dylora(
|
|||
###
|
||||
|
||||
|
||||
def gradio_extract_dylora_tab():
|
||||
def gradio_extract_dylora_tab(headless=False):
|
||||
with gr.Tab('Extract DyLoRA'):
|
||||
gr.Markdown(
|
||||
'This utility can extract a LoRA network from a finetuned model.'
|
||||
'This utility can extract a DyLoRA network from a finetuned model.'
|
||||
)
|
||||
lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
|
||||
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
|
||||
|
|
@ -67,7 +67,7 @@ def gradio_extract_dylora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -82,7 +82,7 @@ def gradio_extract_dylora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_save_to = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_save_to.click(
|
||||
get_saveasfilename_path,
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ def extract_lora(
|
|||
###
|
||||
|
||||
|
||||
def gradio_extract_lora_tab():
|
||||
def gradio_extract_lora_tab(headless=False):
|
||||
with gr.Tab('Extract LoRA'):
|
||||
gr.Markdown(
|
||||
'This utility can extract a LoRA network from a finetuned model.'
|
||||
|
|
@ -88,7 +88,7 @@ def gradio_extract_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_model_tuned_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_model_tuned_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -103,7 +103,7 @@ def gradio_extract_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_model_org_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_model_org_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -118,7 +118,7 @@ def gradio_extract_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_save_to = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_save_to.click(
|
||||
get_saveasfilename_path,
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ def update_mode(mode):
|
|||
return tuple(updates)
|
||||
|
||||
|
||||
def gradio_extract_lycoris_locon_tab():
|
||||
def gradio_extract_lycoris_locon_tab(headless=False):
|
||||
with gr.Tab('Extract LyCORIS LoCON'):
|
||||
gr.Markdown(
|
||||
'This utility can extract a LyCORIS LoCon network from a finetuned model.'
|
||||
|
|
@ -138,7 +138,7 @@ def gradio_extract_lycoris_locon_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_db_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_db_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -153,7 +153,7 @@ def gradio_extract_lycoris_locon_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_base_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_base_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -168,7 +168,7 @@ def gradio_extract_lycoris_locon_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_output_name = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_output_name.click(
|
||||
get_saveasfilename_path,
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ def caption_images(
|
|||
###
|
||||
|
||||
|
||||
def gradio_git_caption_gui_tab():
|
||||
def gradio_git_caption_gui_tab(headless=False):
|
||||
with gr.Tab('GIT Captioning'):
|
||||
gr.Markdown(
|
||||
'This utility will use GIT to caption files for each images in a folder.'
|
||||
|
|
@ -75,7 +75,7 @@ def gradio_git_caption_gui_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_train_data_dir_input = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_train_data_dir_input.click(
|
||||
get_folder_path,
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ def merge_lora(
|
|||
###
|
||||
|
||||
|
||||
def gradio_merge_lora_tab():
|
||||
def gradio_merge_lora_tab(headless=False):
|
||||
with gr.Tab('Merge LoRA'):
|
||||
gr.Markdown(
|
||||
'This utility can merge up to 4 LoRA together or alternativelly merge up to 4 LoRA into a SD checkpoint.'
|
||||
|
|
@ -115,7 +115,7 @@ def gradio_merge_lora_tab():
|
|||
info='Provide a SD file path IF you want to merge it with LoRA files',
|
||||
)
|
||||
sd_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
sd_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -131,7 +131,7 @@ def gradio_merge_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_lora_a_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -146,7 +146,7 @@ def gradio_merge_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_lora_b_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_lora_b_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -181,7 +181,7 @@ def gradio_merge_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_lora_c_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_lora_c_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -196,7 +196,7 @@ def gradio_merge_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_lora_d_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_lora_d_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -231,7 +231,7 @@ def gradio_merge_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_save_to = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_save_to.click(
|
||||
get_saveasfilename_path,
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ def merge_lycoris(
|
|||
###
|
||||
|
||||
|
||||
def gradio_merge_lycoris_tab():
|
||||
def gradio_merge_lycoris_tab(headless=False):
|
||||
with gr.Tab('Merge LyCORIS'):
|
||||
gr.Markdown(
|
||||
'This utility can merge a LyCORIS model into a SD checkpoint.'
|
||||
|
|
@ -70,7 +70,7 @@ def gradio_merge_lycoris_tab():
|
|||
info='Provide a SD file path that you want to merge with the LyCORIS file',
|
||||
)
|
||||
base_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
base_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -86,7 +86,7 @@ def gradio_merge_lycoris_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_lycoris_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_lycoris_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -112,7 +112,7 @@ def gradio_merge_lycoris_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_output_name = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_output_name.click(
|
||||
get_saveasfilename_path,
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ def resize_lora(
|
|||
###
|
||||
|
||||
|
||||
def gradio_resize_lora_tab():
|
||||
def gradio_resize_lora_tab(headless=False):
|
||||
with gr.Tab('Resize LoRA'):
|
||||
gr.Markdown('This utility can resize a LoRA.')
|
||||
|
||||
|
|
@ -94,7 +94,7 @@ def gradio_resize_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_lora_a_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -133,7 +133,7 @@ def gradio_resize_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_save_to = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_save_to.click(
|
||||
get_saveasfilename_path,
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ def svd_merge_lora(
|
|||
###
|
||||
|
||||
|
||||
def gradio_svd_merge_lora_tab():
|
||||
def gradio_svd_merge_lora_tab(headless=False):
|
||||
with gr.Tab('Merge LoRA (SVD)'):
|
||||
gr.Markdown('This utility can merge two LoRA networks together.')
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ def gradio_svd_merge_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_lora_a_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_lora_a_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -100,7 +100,7 @@ def gradio_svd_merge_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_lora_b_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_lora_b_model_file.click(
|
||||
get_file_path,
|
||||
|
|
@ -141,7 +141,7 @@ def gradio_svd_merge_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_save_to = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_save_to.click(
|
||||
get_saveasfilename_path,
|
||||
|
|
|
|||
|
|
@ -20,13 +20,14 @@ def utilities_tab(
|
|||
logging_dir_input=gr.Textbox(),
|
||||
enable_copy_info_button=bool(False),
|
||||
enable_dreambooth_tab=True,
|
||||
headless=False
|
||||
):
|
||||
with gr.Tab('Captioning'):
|
||||
gradio_basic_caption_gui_tab()
|
||||
gradio_blip_caption_gui_tab()
|
||||
gradio_git_caption_gui_tab()
|
||||
gradio_wd14_caption_gui_tab()
|
||||
gradio_convert_model_tab()
|
||||
gradio_basic_caption_gui_tab(headless=headless)
|
||||
gradio_blip_caption_gui_tab(headless=headless)
|
||||
gradio_git_caption_gui_tab(headless=headless)
|
||||
gradio_wd14_caption_gui_tab(headless=headless)
|
||||
gradio_convert_model_tab(headless=headless)
|
||||
|
||||
return (
|
||||
train_data_dir_input,
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ def verify_lora(
|
|||
###
|
||||
|
||||
|
||||
def gradio_verify_lora_tab():
|
||||
def gradio_verify_lora_tab(headless=False):
|
||||
with gr.Tab('Verify LoRA'):
|
||||
gr.Markdown(
|
||||
'This utility can verify a LoRA network to make sure it is properly trained.'
|
||||
|
|
@ -66,7 +66,7 @@ def gradio_verify_lora_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_lora_model_file = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
folder_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_lora_model_file.click(
|
||||
get_file_path,
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ def caption_images(
|
|||
###
|
||||
|
||||
|
||||
def gradio_wd14_caption_gui_tab():
|
||||
def gradio_wd14_caption_gui_tab(headless=False):
|
||||
with gr.Tab('WD14 Captioning'):
|
||||
gr.Markdown(
|
||||
'This utility will use WD14 to caption files for each images in a folder.'
|
||||
|
|
@ -83,7 +83,7 @@ def gradio_wd14_caption_gui_tab():
|
|||
interactive=True,
|
||||
)
|
||||
button_train_data_dir_input = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
button_train_data_dir_input.click(
|
||||
get_folder_path,
|
||||
|
|
|
|||
126
lora_gui.py
126
lora_gui.py
|
|
@ -4,7 +4,7 @@
|
|||
# v3.1: Adding captionning of images to utilities
|
||||
|
||||
import gradio as gr
|
||||
import easygui
|
||||
# import easygui
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
|
@ -28,6 +28,7 @@ from library.common_gui import (
|
|||
# set_legacy_8bitadam,
|
||||
update_my_data,
|
||||
check_if_model_exist,
|
||||
output_message,
|
||||
)
|
||||
from library.dreambooth_folder_creation_gui import (
|
||||
gradio_dreambooth_folder_creation_tab,
|
||||
|
|
@ -44,7 +45,7 @@ from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
|
|||
from library.verify_lora_gui import gradio_verify_lora_tab
|
||||
from library.resize_lora_gui import gradio_resize_lora_tab
|
||||
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
||||
from easygui import msgbox
|
||||
# from easygui import msgbox
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
|
|
@ -317,7 +318,9 @@ def open_configuration(
|
|||
return tuple(values)
|
||||
|
||||
|
||||
|
||||
def train_model(
|
||||
headless,
|
||||
print_only,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
|
|
@ -408,59 +411,60 @@ def train_model(
|
|||
wandb_api_key,
|
||||
):
|
||||
print_only_bool = True if print_only.get('label') == 'True' else False
|
||||
headless_bool = True if headless.get('label') == 'True' else False
|
||||
|
||||
if pretrained_model_name_or_path == '':
|
||||
msgbox('Source model information is missing')
|
||||
output_message(msg='Source model information is missing', headless=headless_bool)
|
||||
return
|
||||
|
||||
if train_data_dir == '':
|
||||
msgbox('Image folder path is missing')
|
||||
output_message(msg='Image folder path is missing', headless=headless_bool)
|
||||
return
|
||||
|
||||
if not os.path.exists(train_data_dir):
|
||||
msgbox('Image folder does not exist')
|
||||
output_message(msg='Image folder does not exist', headless=headless_bool)
|
||||
return
|
||||
|
||||
if reg_data_dir != '':
|
||||
if not os.path.exists(reg_data_dir):
|
||||
msgbox('Regularisation folder does not exist')
|
||||
output_message(msg='Regularisation folder does not exist', headless=headless_bool)
|
||||
return
|
||||
|
||||
if output_dir == '':
|
||||
msgbox('Output folder path is missing')
|
||||
output_message(msg='Output folder path is missing', headless=headless_bool)
|
||||
return
|
||||
|
||||
if int(bucket_reso_steps) < 1:
|
||||
msgbox('Bucket resolution steps need to be greater than 0')
|
||||
output_message(msg='Bucket resolution steps need to be greater than 0', headless=headless_bool)
|
||||
return
|
||||
|
||||
if noise_offset == '':
|
||||
noise_offset = 0
|
||||
|
||||
if float(noise_offset) > 1 or float(noise_offset) < 0:
|
||||
msgbox('Noise offset need to be a value between 0 and 1')
|
||||
output_message(msg='Noise offset need to be a value between 0 and 1', headless=headless_bool)
|
||||
return
|
||||
|
||||
if float(noise_offset) > 0 and (multires_noise_iterations > 0 or multires_noise_discount > 0):
|
||||
msgbox(msg='noise offset and multires_noise can\'t be set at the same time. Only use one or the other.', title='Error')
|
||||
output_message(msg='noise offset and multires_noise can\'t be set at the same time. Only use one or the other.', title='Error', headless=headless_bool)
|
||||
return
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
if stop_text_encoder_training_pct > 0:
|
||||
msgbox(
|
||||
'Output "stop text encoder training" is not yet supported. Ignoring'
|
||||
output_message(
|
||||
msg='Output "stop text encoder training" is not yet supported. Ignoring', headless=headless_bool
|
||||
)
|
||||
stop_text_encoder_training_pct = 0
|
||||
|
||||
if check_if_model_exist(output_name, output_dir, save_model_as):
|
||||
if check_if_model_exist(output_name, output_dir, save_model_as, headless=headless_bool):
|
||||
return
|
||||
|
||||
if optimizer == 'Adafactor' and lr_warmup != '0':
|
||||
msgbox(
|
||||
"Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
|
||||
title='Warning',
|
||||
output_message(
|
||||
msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
|
||||
title='Warning', headless=headless_bool
|
||||
)
|
||||
lr_warmup = '0'
|
||||
|
||||
|
|
@ -673,7 +677,7 @@ def train_model(
|
|||
run_cmd += f' --network_train_unet_only'
|
||||
else:
|
||||
if float(learning_rate) == 0:
|
||||
msgbox('Please input learning rate values.')
|
||||
output_message(msg='Please input learning rate values.', headless=headless_bool)
|
||||
return
|
||||
|
||||
run_cmd += f' --network_dim={network_dim}'
|
||||
|
|
@ -796,9 +800,12 @@ def lora_tab(
|
|||
reg_data_dir_input=gr.Textbox(),
|
||||
output_dir_input=gr.Textbox(),
|
||||
logging_dir_input=gr.Textbox(),
|
||||
headless=False
|
||||
):
|
||||
dummy_db_true = gr.Label(value=True, visible=False)
|
||||
dummy_db_false = gr.Label(value=False, visible=False)
|
||||
dummy_headless = gr.Label(value=headless, visible=False)
|
||||
|
||||
gr.Markdown(
|
||||
'Train a custom model using kohya train network LoRA python code...'
|
||||
)
|
||||
|
|
@ -808,7 +815,7 @@ def lora_tab(
|
|||
button_save_as_config,
|
||||
config_file_name,
|
||||
button_load_config,
|
||||
) = gradio_config()
|
||||
) = gradio_config(headless=headless)
|
||||
|
||||
(
|
||||
pretrained_model_name_or_path,
|
||||
|
|
@ -820,7 +827,7 @@ def lora_tab(
|
|||
save_model_as_choices=[
|
||||
'ckpt',
|
||||
'safetensors',
|
||||
]
|
||||
], headless=headless
|
||||
)
|
||||
|
||||
with gr.Tab('Folders'):
|
||||
|
|
@ -829,7 +836,7 @@ def lora_tab(
|
|||
label='Image folder',
|
||||
placeholder='Folder where the training folders containing the images are located',
|
||||
)
|
||||
train_data_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
||||
train_data_dir_folder = gr.Button('📂', elem_id='open_folder_small', visible=(not headless))
|
||||
train_data_dir_folder.click(
|
||||
get_folder_path,
|
||||
outputs=train_data_dir,
|
||||
|
|
@ -839,7 +846,7 @@ def lora_tab(
|
|||
label='Regularisation folder',
|
||||
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
||||
)
|
||||
reg_data_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
||||
reg_data_dir_folder = gr.Button('📂', elem_id='open_folder_small', visible=(not headless))
|
||||
reg_data_dir_folder.click(
|
||||
get_folder_path,
|
||||
outputs=reg_data_dir,
|
||||
|
|
@ -850,7 +857,7 @@ def lora_tab(
|
|||
label='Output folder',
|
||||
placeholder='Folder to output trained model',
|
||||
)
|
||||
output_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
||||
output_dir_folder = gr.Button('📂', elem_id='open_folder_small', visible=(not headless))
|
||||
output_dir_folder.click(
|
||||
get_folder_path,
|
||||
outputs=output_dir,
|
||||
|
|
@ -860,7 +867,7 @@ def lora_tab(
|
|||
label='Logging folder',
|
||||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||
)
|
||||
logging_dir_folder = gr.Button('📂', elem_id='open_folder_small')
|
||||
logging_dir_folder = gr.Button('📂', elem_id='open_folder_small', visible=(not headless))
|
||||
logging_dir_folder.click(
|
||||
get_folder_path,
|
||||
outputs=logging_dir,
|
||||
|
|
@ -917,7 +924,7 @@ def lora_tab(
|
|||
placeholder='{Optional) Path to existing LoRA network weights to resume training',
|
||||
)
|
||||
lora_network_weights_file = gr.Button(
|
||||
document_symbol, elem_id='open_folder_small'
|
||||
document_symbol, elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
lora_network_weights_file.click(
|
||||
get_any_file_path,
|
||||
|
|
@ -1159,7 +1166,7 @@ def lora_tab(
|
|||
save_last_n_steps_state,
|
||||
use_wandb,
|
||||
wandb_api_key,
|
||||
) = gradio_advanced_training()
|
||||
) = gradio_advanced_training(headless=headless)
|
||||
color_aug.change(
|
||||
color_aug_changed,
|
||||
inputs=[color_aug],
|
||||
|
|
@ -1188,12 +1195,13 @@ def lora_tab(
|
|||
reg_data_dir_input=reg_data_dir,
|
||||
output_dir_input=output_dir,
|
||||
logging_dir_input=logging_dir,
|
||||
headless=headless,
|
||||
)
|
||||
gradio_dataset_balancing_tab()
|
||||
gradio_merge_lora_tab()
|
||||
gradio_svd_merge_lora_tab()
|
||||
gradio_resize_lora_tab()
|
||||
gradio_verify_lora_tab()
|
||||
gradio_dataset_balancing_tab(headless=headless)
|
||||
gradio_merge_lora_tab(headless=headless)
|
||||
gradio_svd_merge_lora_tab(headless=headless)
|
||||
gradio_resize_lora_tab(headless=headless)
|
||||
gradio_verify_lora_tab(headless=headless)
|
||||
|
||||
button_run = gr.Button('Train model', variant='primary')
|
||||
|
||||
|
|
@ -1333,13 +1341,13 @@ def lora_tab(
|
|||
|
||||
button_run.click(
|
||||
train_model,
|
||||
inputs=[dummy_db_false] + settings_list,
|
||||
inputs=[dummy_headless] + [dummy_db_false] + settings_list,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
button_print.click(
|
||||
train_model,
|
||||
inputs=[dummy_db_true] + settings_list,
|
||||
inputs=[dummy_headless] + [dummy_db_true] + settings_list,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
|
@ -1353,13 +1361,18 @@ def lora_tab(
|
|||
|
||||
def UI(**kwargs):
|
||||
css = ''
|
||||
|
||||
headless = kwargs.get('headless', False)
|
||||
print(f'headless: {headless}')
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
print('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
interface = gr.Blocks(css=css)
|
||||
interface = gr.Blocks(
|
||||
css=css, title='Kohya_ss GUI', theme=gr.themes.Default()
|
||||
)
|
||||
|
||||
with interface:
|
||||
with gr.Tab('LoRA'):
|
||||
|
|
@ -1368,7 +1381,7 @@ def UI(**kwargs):
|
|||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = lora_tab()
|
||||
) = lora_tab(headless=headless)
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
|
|
@ -1376,28 +1389,39 @@ def UI(**kwargs):
|
|||
output_dir_input=output_dir_input,
|
||||
logging_dir_input=logging_dir_input,
|
||||
enable_copy_info_button=True,
|
||||
headless=headless
|
||||
)
|
||||
|
||||
# Show the interface
|
||||
launch_kwargs = {}
|
||||
if not kwargs.get('username', None) == '':
|
||||
launch_kwargs['auth'] = (
|
||||
kwargs.get('username', None),
|
||||
kwargs.get('password', None),
|
||||
)
|
||||
if kwargs.get('server_port', 0) > 0:
|
||||
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
||||
if kwargs.get('inbrowser', False):
|
||||
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
||||
if kwargs.get('listen', True):
|
||||
launch_kwargs['server_name'] = '0.0.0.0'
|
||||
print(launch_kwargs)
|
||||
username = kwargs.get('username')
|
||||
password = kwargs.get('password')
|
||||
server_port = kwargs.get('server_port', 0)
|
||||
inbrowser = kwargs.get('inbrowser', False)
|
||||
share = kwargs.get('share', False)
|
||||
server_name = kwargs.get('listen')
|
||||
|
||||
launch_kwargs['server_name'] = server_name
|
||||
if username and password:
|
||||
launch_kwargs['auth'] = (username, password)
|
||||
if server_port > 0:
|
||||
launch_kwargs['server_port'] = server_port
|
||||
if inbrowser:
|
||||
launch_kwargs['inbrowser'] = inbrowser
|
||||
if share:
|
||||
launch_kwargs['share'] = share
|
||||
interface.launch(**launch_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--listen',
|
||||
type=str,
|
||||
default='127.0.0.1',
|
||||
help='IP to listen on for connections to Gradio',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--username', type=str, default='', help='Username for authentication'
|
||||
)
|
||||
|
|
@ -1414,9 +1438,10 @@ if __name__ == '__main__':
|
|||
'--inbrowser', action='store_true', help='Open in browser'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--listen',
|
||||
action='store_true',
|
||||
help='Launch gradio with server name 0.0.0.0, allowing LAN access',
|
||||
'--share', action='store_true', help='Share the gradio UI'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--headless', action='store_true', help='Is the server headless'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
@ -1426,4 +1451,7 @@ if __name__ == '__main__':
|
|||
password=args.password,
|
||||
inbrowser=args.inbrowser,
|
||||
server_port=args.server_port,
|
||||
share=args.share,
|
||||
listen=args.listen,
|
||||
headless=args.headless,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from library.common_gui import (
|
|||
# set_legacy_8bitadam,
|
||||
update_my_data,
|
||||
check_if_model_exist,
|
||||
output_message,
|
||||
)
|
||||
from library.tensorboard_gui import (
|
||||
gradio_tensorboard,
|
||||
|
|
@ -38,7 +39,7 @@ from library.dreambooth_folder_creation_gui import (
|
|||
)
|
||||
from library.utilities import utilities_tab
|
||||
from library.sampler_gui import sample_gradio_config, run_cmd_sample
|
||||
from easygui import msgbox
|
||||
# from easygui import msgbox
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
|
|
@ -275,6 +276,7 @@ def open_configuration(
|
|||
|
||||
|
||||
def train_model(
|
||||
headless,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization,
|
||||
|
|
@ -349,49 +351,51 @@ def train_model(
|
|||
use_wandb,
|
||||
wandb_api_key,
|
||||
):
|
||||
headless_bool = True if headless.get('label') == 'True' else False
|
||||
|
||||
if pretrained_model_name_or_path == '':
|
||||
msgbox('Source model information is missing')
|
||||
output_message(msg='Source model information is missing', headless=headless_bool)
|
||||
return
|
||||
|
||||
if train_data_dir == '':
|
||||
msgbox('Image folder path is missing')
|
||||
output_message(msg='Image folder path is missing', headless=headless_bool)
|
||||
return
|
||||
|
||||
if not os.path.exists(train_data_dir):
|
||||
msgbox('Image folder does not exist')
|
||||
output_message(msg='Image folder does not exist', headless=headless_bool)
|
||||
return
|
||||
|
||||
if reg_data_dir != '':
|
||||
if not os.path.exists(reg_data_dir):
|
||||
msgbox('Regularisation folder does not exist')
|
||||
output_message(msg='Regularisation folder does not exist', headless=headless_bool)
|
||||
return
|
||||
|
||||
if output_dir == '':
|
||||
msgbox('Output folder path is missing')
|
||||
output_message(msg='Output folder path is missing', headless=headless_bool)
|
||||
return
|
||||
|
||||
if token_string == '':
|
||||
msgbox('Token string is missing')
|
||||
output_message(msg='Token string is missing', headless=headless_bool)
|
||||
return
|
||||
|
||||
if init_word == '':
|
||||
msgbox('Init word is missing')
|
||||
output_message(msg='Init word is missing', headless=headless_bool)
|
||||
return
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
if check_if_model_exist(output_name, output_dir, save_model_as):
|
||||
if check_if_model_exist(output_name, output_dir, save_model_as, headless_bool):
|
||||
return
|
||||
|
||||
if float(noise_offset) > 0 and (multires_noise_iterations > 0 or multires_noise_discount > 0):
|
||||
msgbox(msg='noise offset and multires_noise can\'t be set at the same time. Only use one or the other.', title='Error')
|
||||
output_message(msg='noise offset and multires_noise can\'t be set at the same time. Only use one or the other.', title='Error', headless=headless_bool)
|
||||
return
|
||||
|
||||
if optimizer == 'Adafactor' and lr_warmup != '0':
|
||||
msgbox(
|
||||
"Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
|
||||
title='Warning',
|
||||
output_message(
|
||||
msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
|
||||
title='Warning', headless=headless_bool
|
||||
)
|
||||
lr_warmup = '0'
|
||||
|
||||
|
|
@ -602,9 +606,11 @@ def ti_tab(
|
|||
reg_data_dir=gr.Textbox(),
|
||||
output_dir=gr.Textbox(),
|
||||
logging_dir=gr.Textbox(),
|
||||
headless=False
|
||||
):
|
||||
dummy_db_true = gr.Label(value=True, visible=False)
|
||||
dummy_db_false = gr.Label(value=False, visible=False)
|
||||
dummy_headless = gr.Label(value=headless, visible=False)
|
||||
gr.Markdown('Train a TI using kohya textual inversion python code...')
|
||||
(
|
||||
button_open_config,
|
||||
|
|
@ -612,7 +618,7 @@ def ti_tab(
|
|||
button_save_as_config,
|
||||
config_file_name,
|
||||
button_load_config,
|
||||
) = gradio_config()
|
||||
) = gradio_config(headless=headless)
|
||||
|
||||
(
|
||||
pretrained_model_name_or_path,
|
||||
|
|
@ -624,7 +630,7 @@ def ti_tab(
|
|||
save_model_as_choices=[
|
||||
'ckpt',
|
||||
'safetensors',
|
||||
]
|
||||
], headless=headless
|
||||
)
|
||||
|
||||
with gr.Tab('Folders'):
|
||||
|
|
@ -634,7 +640,7 @@ def ti_tab(
|
|||
placeholder='Folder where the training folders containing the images are located',
|
||||
)
|
||||
train_data_dir_input_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
train_data_dir_input_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -646,7 +652,7 @@ def ti_tab(
|
|||
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
||||
)
|
||||
reg_data_dir_input_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
reg_data_dir_input_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -659,7 +665,7 @@ def ti_tab(
|
|||
placeholder='Folder to output trained model',
|
||||
)
|
||||
output_dir_input_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
output_dir_input_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -671,7 +677,7 @@ def ti_tab(
|
|||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||
)
|
||||
logging_dir_input_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
'📂', elem_id='open_folder_small', visible=(not headless)
|
||||
)
|
||||
logging_dir_input_folder.click(
|
||||
get_folder_path,
|
||||
|
|
@ -711,7 +717,7 @@ def ti_tab(
|
|||
label='Resume TI training',
|
||||
placeholder='(Optional) Path to existing TI embeding file to keep training',
|
||||
)
|
||||
weights_file_input = gr.Button('📂', elem_id='open_folder_small')
|
||||
weights_file_input = gr.Button('📂', elem_id='open_folder_small', visible=(not headless))
|
||||
weights_file_input.click(
|
||||
get_file_path,
|
||||
outputs=weights,
|
||||
|
|
@ -796,7 +802,7 @@ def ti_tab(
|
|||
label='VAE',
|
||||
placeholder='(Optiona) path to checkpoint of vae to replace for training',
|
||||
)
|
||||
vae_button = gr.Button('📂', elem_id='open_folder_small')
|
||||
vae_button = gr.Button('📂', elem_id='open_folder_small', visible=(not headless))
|
||||
vae_button.click(
|
||||
get_any_file_path,
|
||||
outputs=vae,
|
||||
|
|
@ -835,7 +841,7 @@ def ti_tab(
|
|||
save_last_n_steps_state,
|
||||
use_wandb,
|
||||
wandb_api_key,
|
||||
) = gradio_advanced_training()
|
||||
) = gradio_advanced_training(headless=headless)
|
||||
color_aug.change(
|
||||
color_aug_changed,
|
||||
inputs=[color_aug],
|
||||
|
|
@ -858,6 +864,7 @@ def ti_tab(
|
|||
reg_data_dir_input=reg_data_dir,
|
||||
output_dir_input=output_dir,
|
||||
logging_dir_input=logging_dir,
|
||||
headless=headless,
|
||||
)
|
||||
|
||||
button_run = gr.Button('Train model', variant='primary')
|
||||
|
|
@ -982,7 +989,7 @@ def ti_tab(
|
|||
|
||||
button_run.click(
|
||||
train_model,
|
||||
inputs=settings_list,
|
||||
inputs=[dummy_headless] + settings_list,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
|
@ -996,13 +1003,18 @@ def ti_tab(
|
|||
|
||||
def UI(**kwargs):
|
||||
css = ''
|
||||
|
||||
headless = kwargs.get('headless', False)
|
||||
print(f'headless: {headless}')
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
print('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
interface = gr.Blocks(css=css)
|
||||
interface = gr.Blocks(
|
||||
css=css, title='Kohya_ss GUI', theme=gr.themes.Default()
|
||||
)
|
||||
|
||||
with interface:
|
||||
with gr.Tab('Dreambooth TI'):
|
||||
|
|
@ -1011,7 +1023,7 @@ def UI(**kwargs):
|
|||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = ti_tab()
|
||||
) = ti_tab(headless=headless)
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
|
|
@ -1019,26 +1031,39 @@ def UI(**kwargs):
|
|||
output_dir_input=output_dir_input,
|
||||
logging_dir_input=logging_dir_input,
|
||||
enable_copy_info_button=True,
|
||||
headless=headless
|
||||
)
|
||||
|
||||
# Show the interface
|
||||
launch_kwargs = {}
|
||||
if not kwargs.get('username', None) == '':
|
||||
launch_kwargs['auth'] = (
|
||||
kwargs.get('username', None),
|
||||
kwargs.get('password', None),
|
||||
)
|
||||
if kwargs.get('server_port', 0) > 0:
|
||||
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
||||
if kwargs.get('inbrowser', False):
|
||||
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
|
||||
print(launch_kwargs)
|
||||
username = kwargs.get('username')
|
||||
password = kwargs.get('password')
|
||||
server_port = kwargs.get('server_port', 0)
|
||||
inbrowser = kwargs.get('inbrowser', False)
|
||||
share = kwargs.get('share', False)
|
||||
server_name = kwargs.get('listen')
|
||||
|
||||
launch_kwargs['server_name'] = server_name
|
||||
if username and password:
|
||||
launch_kwargs['auth'] = (username, password)
|
||||
if server_port > 0:
|
||||
launch_kwargs['server_port'] = server_port
|
||||
if inbrowser:
|
||||
launch_kwargs['inbrowser'] = inbrowser
|
||||
if share:
|
||||
launch_kwargs['share'] = share
|
||||
interface.launch(**launch_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--listen',
|
||||
type=str,
|
||||
default='127.0.0.1',
|
||||
help='IP to listen on for connections to Gradio',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--username', type=str, default='', help='Username for authentication'
|
||||
)
|
||||
|
|
@ -1054,6 +1079,12 @@ if __name__ == '__main__':
|
|||
parser.add_argument(
|
||||
'--inbrowser', action='store_true', help='Open in browser'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--share', action='store_true', help='Share the gradio UI'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--headless', action='store_true', help='Is the server headless'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -1062,4 +1093,7 @@ if __name__ == '__main__':
|
|||
password=args.password,
|
||||
inbrowser=args.inbrowser,
|
||||
server_port=args.server_port,
|
||||
share=args.share,
|
||||
listen=args.listen,
|
||||
headless=args.headless,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue