Update LoRA GUI

Various improvements
pull/22/head
bmaltais 2023-01-01 14:14:58 -05:00
parent ee2499d834
commit af46ce4c47
6 changed files with 61 additions and 34 deletions

View File

@ -12,18 +12,38 @@ You can find the dreambooth solution spercific [Dreambooth README](README_dreamb
You can find the finetune solution specific [Finetune README](README_finetune.md)
## LoRA
You can create LoRA network by running the dedicated GUI with:
```
python lora_gui.py
```
or via the all in one GUI:
```
python kohya_gui.py
```
Once you have created the LoRA network you can generate images via auto1111 by installing the extension found here: https://github.com/kohya-ss/sd-webui-additional-networks
## Change history
* 12/30 (v19) update:
* 2023/01/01 (v19.1) update:
- merge kohya_ss upstream code updates
- rework Dreambooth LoRA GUI
- fix bug where LoRA network weights were not loaded to properly resume training
* 2022/12/30 (v19) update:
- support for LoRA network training in kohya_gui.py.
* 12/23 (v18.8) update:
* 2022/12/23 (v18.8) update:
- Fix for conversion tool issue when the source was an sd1.x diffuser model
- Other minor code and GUI fix
* 12/22 (v18.7) update:
* 2022/12/22 (v18.7) update:
- Merge dreambooth and finetune is a common GUI
- General bug fixes and code improvements
* 12/21 (v18.6.1) update:
* 2022/12/21 (v18.6.1) update:
- fix issue with dataset balancing when the number of detected images in the folder is 0
* 12/21 (v18.6) update:
* 2022/12/21 (v18.6) update:
- add optional GUI authentication support via: `python fine_tune.py --username=<name> --password=<password>`

View File

@ -15,6 +15,7 @@ from library.common_gui import (
get_folder_path,
remove_doublequote,
get_file_path,
get_any_file_path,
get_saveasfile_path,
)
from library.dreambooth_folder_creation_gui import (
@ -236,7 +237,7 @@ def train_model(
seed,
num_cpu_threads_per_process,
cache_latent,
caption_extention,
caption_extension,
enable_bucket,
gradient_checkpointing,
full_fp16,
@ -396,7 +397,8 @@ def train_model(
run_cmd += f' --seed={seed}'
run_cmd += f' --save_precision={save_precision}'
run_cmd += f' --logging_dir={logging_dir}'
run_cmd += f' --caption_extention={caption_extention}'
if not caption_extension == '':
run_cmd += f' --caption_extension={caption_extension}'
if not stop_text_encoder_training == 0:
run_cmd += (
f' --stop_text_encoder_training={stop_text_encoder_training}'
@ -542,7 +544,7 @@ def dreambooth_tab(
document_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_fille.click(
get_file_path,
get_any_file_path,
inputs=[pretrained_model_name_or_path_input],
outputs=pretrained_model_name_or_path_input,
)

View File

@ -1,12 +1,12 @@
$txt_files_folder = "D:\dreambooth\training_twq\mad_hatter\all"
$txt_prefix_to_ignore = "asd"
$txt_postfix_ti_ignore = "asd"
$txt_files_folder = "D:\dataset\metart_g1\img\100_asd girl"
$txt_prefix_to_ignore = "asds"
$txt_postfix_ti_ignore = "asds"
# Should not need to touch anything below
# (Get-Content $txt_files_folder"\*.txt" ).Replace(",", "") -Split '\W' | Group-Object -NoElement | Sort-Object -Descending -Property Count
$combined_txt = Get-Content $txt_files_folder"\*.txt"
$combined_txt = Get-Content $txt_files_folder"\*.cap"
$combined_txt = $combined_txt.Replace(",", "")
$combined_txt = $combined_txt.Replace("$txt_prefix_to_ignore", "")
$combined_txt = $combined_txt.Replace("$txt_postfix_ti_ignore", "") -Split '\W' | Group-Object -NoElement | Sort-Object -Descending -Property Count

View File

@ -9,6 +9,7 @@ import argparse
from library.common_gui import (
get_folder_path,
get_file_path,
get_any_file_path,
get_saveasfile_path,
)
from library.utilities import utilities_tab
@ -436,7 +437,7 @@ def finetune_tab():
document_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_file.click(
get_file_path,
get_any_file_path,
inputs=pretrained_model_name_or_path_input,
outputs=pretrained_model_name_or_path_input,
)

View File

@ -75,7 +75,7 @@ def gradio_basic_caption_gui_tab():
)
with gr.Row():
prefix = gr.Textbox(
label='Prefix to add to txt caption',
label='Prefix to add to caption',
placeholder='(Optional)',
interactive=True,
)
@ -85,7 +85,7 @@ def gradio_basic_caption_gui_tab():
interactive=True,
)
postfix = gr.Textbox(
label='Postfix to add to txt caption',
label='Postfix to add to caption',
placeholder='(Optional)',
interactive=True,
)

View File

@ -64,7 +64,7 @@ def save_configuration(
shuffle_caption,
save_state,
resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
):
original_file_path = file_path
@ -118,7 +118,8 @@ def save_configuration(
'prior_loss_weight': prior_loss_weight,
'text_encoder_lr': text_encoder_lr,
'unet_lr': unet_lr,
'network_dim': network_dim
'network_dim': network_dim,
'lora_network_weights': lora_network_weights,
}
# Save the data to the selected file
@ -160,7 +161,7 @@ def open_configuration(
shuffle_caption,
save_state,
resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
):
original_file_path = file_path
@ -216,6 +217,7 @@ def open_configuration(
my_data.get('text_encoder_lr', text_encoder_lr),
my_data.get('unet_lr', unet_lr),
my_data.get('network_dim', network_dim),
my_data.get('lora_network_weights', lora_network_weights),
)
@ -250,7 +252,7 @@ def train_model(
shuffle_caption,
save_state,
resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
):
def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required
@ -432,6 +434,7 @@ def train_model(
# elif network_train == 'Unet only':
# run_cmd += f' --network_train_unet_only'
run_cmd += f' --network_dim={network_dim}'
run_cmd += f' --network_weights={lora_network_weights}'
print(run_cmd)
@ -568,7 +571,7 @@ def lora_tab(
document_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_file.click(
get_file_path,
get_any_file_path,
inputs=[pretrained_model_name_or_path_input],
outputs=pretrained_model_name_or_path_input,
)
@ -602,19 +605,7 @@ def lora_tab(
],
value='same as source model',
)
with gr.Row():
lora_network_weights = gr.Textbox(
label='LoRA network weights',
placeholder='{Optional) Path to existing LoRA network weights to resume training}',
)
lora_network_weights_file = gr.Button(
document_symbol, elem_id='open_folder_small'
)
lora_network_weights_file.click(
get_any_file_path,
inputs=[lora_network_weights],
outputs=lora_network_weights,
)
with gr.Row():
v2_input = gr.Checkbox(label='v2', value=True)
v_parameterization_input = gr.Checkbox(
@ -699,6 +690,19 @@ def lora_tab(
outputs=[logging_dir_input],
)
with gr.Tab('Training parameters'):
with gr.Row():
lora_network_weights = gr.Textbox(
label='LoRA network weights',
placeholder='{Optional) Path to existing LoRA network weights to resume training',
)
lora_network_weights_file = gr.Button(
document_symbol, elem_id='open_folder_small'
)
lora_network_weights_file.click(
get_any_file_path,
inputs=[lora_network_weights],
outputs=lora_network_weights,
)
with gr.Row():
# learning_rate_input = gr.Textbox(label='Learning rate', value=1e-4, visible=False)
lr_scheduler_input = gr.Dropdown(
@ -874,7 +878,7 @@ def lora_tab(
shuffle_caption,
save_state,
resume,
prior_loss_weight, text_encoder_lr, unet_lr, network_dim
prior_loss_weight, text_encoder_lr, unet_lr, network_dim, lora_network_weights
]
button_open_config.click(