mirror of https://github.com/bmaltais/kohya_ss
Format code
parent
08ce96f33b
commit
0b217a4cf8
|
|
@ -551,14 +551,15 @@ def train_model(
|
|||
|
||||
# run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"'
|
||||
run_cmd = "accelerate launch"
|
||||
|
||||
|
||||
run_cmd += run_cmd_advanced_training(
|
||||
num_processes=num_processes,
|
||||
num_machines=num_machines,
|
||||
multi_gpu=multi_gpu,
|
||||
gpu_ids=gpu_ids,
|
||||
num_cpu_threads_per_process=num_cpu_threads_per_process)
|
||||
|
||||
num_cpu_threads_per_process=num_cpu_threads_per_process,
|
||||
)
|
||||
|
||||
if sdxl:
|
||||
run_cmd += f' "./sdxl_train.py"'
|
||||
else:
|
||||
|
|
|
|||
374
finetune_gui.py
374
finetune_gui.py
|
|
@ -43,12 +43,12 @@ executor = CommandExecutor()
|
|||
|
||||
# from easygui import msgbox
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
document_symbol = '\U0001F4C4' # 📄
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
save_style_symbol = "\U0001f4be" # 💾
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
|
||||
PYTHON = "python3" if os.name == "posix" else "./venv/Scripts/python.exe"
|
||||
|
||||
|
||||
def save_configuration(
|
||||
|
|
@ -100,7 +100,8 @@ def save_configuration(
|
|||
save_state,
|
||||
resume,
|
||||
gradient_checkpointing,
|
||||
gradient_accumulation_steps,block_lr,
|
||||
gradient_accumulation_steps,
|
||||
block_lr,
|
||||
mem_eff_attn,
|
||||
shuffle_caption,
|
||||
output_name,
|
||||
|
|
@ -153,19 +154,19 @@ def save_configuration(
|
|||
|
||||
original_file_path = file_path
|
||||
|
||||
save_as_bool = True if save_as.get('label') == 'True' else False
|
||||
save_as_bool = True if save_as.get("label") == "True" else False
|
||||
|
||||
if save_as_bool:
|
||||
log.info('Save as...')
|
||||
log.info("Save as...")
|
||||
file_path = get_saveasfile_path(file_path)
|
||||
else:
|
||||
log.info('Save...')
|
||||
if file_path == None or file_path == '':
|
||||
log.info("Save...")
|
||||
if file_path == None or file_path == "":
|
||||
file_path = get_saveasfile_path(file_path)
|
||||
|
||||
# log.info(file_path)
|
||||
|
||||
if file_path == None or file_path == '':
|
||||
if file_path == None or file_path == "":
|
||||
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
||||
|
||||
# Extract the destination directory from the file path
|
||||
|
|
@ -178,7 +179,7 @@ def save_configuration(
|
|||
SaveConfigFile(
|
||||
parameters=parameters,
|
||||
file_path=file_path,
|
||||
exclusion=['file_path', 'save_as'],
|
||||
exclusion=["file_path", "save_as"],
|
||||
)
|
||||
|
||||
return file_path
|
||||
|
|
@ -234,7 +235,8 @@ def open_configuration(
|
|||
save_state,
|
||||
resume,
|
||||
gradient_checkpointing,
|
||||
gradient_accumulation_steps,block_lr,
|
||||
gradient_accumulation_steps,
|
||||
block_lr,
|
||||
mem_eff_attn,
|
||||
shuffle_caption,
|
||||
output_name,
|
||||
|
|
@ -286,33 +288,31 @@ def open_configuration(
|
|||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
||||
ask_for_file = True if ask_for_file.get('label') == 'True' else False
|
||||
apply_preset = True if apply_preset.get('label') == 'True' else False
|
||||
ask_for_file = True if ask_for_file.get("label") == "True" else False
|
||||
apply_preset = True if apply_preset.get("label") == "True" else False
|
||||
|
||||
# Check if we are "applying" a preset or a config
|
||||
if apply_preset:
|
||||
log.info(f'Applying preset {training_preset}...')
|
||||
file_path = f'./presets/finetune/{training_preset}.json'
|
||||
log.info(f"Applying preset {training_preset}...")
|
||||
file_path = f"./presets/finetune/{training_preset}.json"
|
||||
else:
|
||||
# If not applying a preset, set the `training_preset` field to an empty string
|
||||
# Find the index of the `training_preset` parameter using the `index()` method
|
||||
training_preset_index = parameters.index(
|
||||
('training_preset', training_preset)
|
||||
)
|
||||
training_preset_index = parameters.index(("training_preset", training_preset))
|
||||
|
||||
# Update the value of `training_preset` by directly assigning an empty string value
|
||||
parameters[training_preset_index] = ('training_preset', '')
|
||||
parameters[training_preset_index] = ("training_preset", "")
|
||||
|
||||
original_file_path = file_path
|
||||
|
||||
if ask_for_file:
|
||||
file_path = get_file_path(file_path)
|
||||
|
||||
if not file_path == '' and not file_path == None:
|
||||
if not file_path == "" and not file_path == None:
|
||||
# load variables from JSON file
|
||||
with open(file_path, 'r') as f:
|
||||
with open(file_path, "r") as f:
|
||||
my_data = json.load(f)
|
||||
log.info('Loading config...')
|
||||
log.info("Loading config...")
|
||||
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
|
||||
my_data = update_my_data(my_data)
|
||||
else:
|
||||
|
|
@ -323,7 +323,7 @@ def open_configuration(
|
|||
for key, value in parameters:
|
||||
json_value = my_data.get(key)
|
||||
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
||||
if not key in ['ask_for_file', 'apply_preset', 'file_path']:
|
||||
if not key in ["ask_for_file", "apply_preset", "file_path"]:
|
||||
values.append(json_value if json_value is not None else value)
|
||||
return tuple(values)
|
||||
|
||||
|
|
@ -377,7 +377,8 @@ def train_model(
|
|||
save_state,
|
||||
resume,
|
||||
gradient_checkpointing,
|
||||
gradient_accumulation_steps,block_lr,
|
||||
gradient_accumulation_steps,
|
||||
block_lr,
|
||||
mem_eff_attn,
|
||||
shuffle_caption,
|
||||
output_name,
|
||||
|
|
@ -428,14 +429,12 @@ def train_model(
|
|||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
||||
print_only_bool = True if print_only.get('label') == 'True' else False
|
||||
log.info(f'Start Finetuning...')
|
||||
print_only_bool = True if print_only.get("label") == "True" else False
|
||||
log.info(f"Start Finetuning...")
|
||||
|
||||
headless_bool = True if headless.get('label') == 'True' else False
|
||||
headless_bool = True if headless.get("label") == "True" else False
|
||||
|
||||
if check_if_model_exist(
|
||||
output_name, output_dir, save_model_as, headless_bool
|
||||
):
|
||||
if check_if_model_exist(output_name, output_dir, save_model_as, headless_bool):
|
||||
return
|
||||
|
||||
# if float(noise_offset) > 0 and (
|
||||
|
|
@ -461,52 +460,50 @@ def train_model(
|
|||
if not os.path.exists(train_dir):
|
||||
os.mkdir(train_dir)
|
||||
|
||||
run_cmd = f'{PYTHON} finetune/merge_captions_to_metadata.py'
|
||||
if caption_extension == '':
|
||||
run_cmd = f"{PYTHON} finetune/merge_captions_to_metadata.py"
|
||||
if caption_extension == "":
|
||||
run_cmd += f' --caption_extension=".caption"'
|
||||
else:
|
||||
run_cmd += f' --caption_extension={caption_extension}'
|
||||
run_cmd += f" --caption_extension={caption_extension}"
|
||||
run_cmd += f' "{image_folder}"'
|
||||
run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
|
||||
if full_path:
|
||||
run_cmd += f' --full_path'
|
||||
run_cmd += f" --full_path"
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
if not print_only_bool:
|
||||
# Run the command
|
||||
if os.name == 'posix':
|
||||
if os.name == "posix":
|
||||
os.system(run_cmd)
|
||||
else:
|
||||
subprocess.run(run_cmd)
|
||||
|
||||
# create images buckets
|
||||
if generate_image_buckets:
|
||||
run_cmd = f'{PYTHON} finetune/prepare_buckets_latents.py'
|
||||
run_cmd = f"{PYTHON} finetune/prepare_buckets_latents.py"
|
||||
run_cmd += f' "{image_folder}"'
|
||||
run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
|
||||
run_cmd += f' "{train_dir}/{latent_metadata_filename}"'
|
||||
run_cmd += f' "{pretrained_model_name_or_path}"'
|
||||
run_cmd += f' --batch_size={batch_size}'
|
||||
run_cmd += f' --max_resolution={max_resolution}'
|
||||
run_cmd += f' --min_bucket_reso={min_bucket_reso}'
|
||||
run_cmd += f' --max_bucket_reso={max_bucket_reso}'
|
||||
run_cmd += f' --mixed_precision={mixed_precision}'
|
||||
run_cmd += f" --batch_size={batch_size}"
|
||||
run_cmd += f" --max_resolution={max_resolution}"
|
||||
run_cmd += f" --min_bucket_reso={min_bucket_reso}"
|
||||
run_cmd += f" --max_bucket_reso={max_bucket_reso}"
|
||||
run_cmd += f" --mixed_precision={mixed_precision}"
|
||||
# if flip_aug:
|
||||
# run_cmd += f' --flip_aug'
|
||||
if full_path:
|
||||
run_cmd += f' --full_path'
|
||||
run_cmd += f" --full_path"
|
||||
if sdxl_checkbox and sdxl_no_half_vae:
|
||||
log.info(
|
||||
'Using mixed_precision = no because no half vae is selected...'
|
||||
)
|
||||
log.info("Using mixed_precision = no because no half vae is selected...")
|
||||
run_cmd += f' --mixed_precision="no"'
|
||||
|
||||
log.info(run_cmd)
|
||||
|
||||
if not print_only_bool:
|
||||
# Run the command
|
||||
if os.name == 'posix':
|
||||
if os.name == "posix":
|
||||
os.system(run_cmd)
|
||||
else:
|
||||
subprocess.run(run_cmd)
|
||||
|
|
@ -517,13 +514,13 @@ def train_model(
|
|||
for f, lower_f in (
|
||||
(file, file.lower()) for file in os.listdir(image_folder)
|
||||
)
|
||||
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
||||
if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp"))
|
||||
]
|
||||
)
|
||||
log.info(f'image_num = {image_num}')
|
||||
log.info(f"image_num = {image_num}")
|
||||
|
||||
repeats = int(image_num) * int(dataset_repeats)
|
||||
log.info(f'repeats = {str(repeats)}')
|
||||
log.info(f"repeats = {str(repeats)}")
|
||||
|
||||
# calculate max_train_steps
|
||||
max_train_steps = int(
|
||||
|
|
@ -539,26 +536,31 @@ def train_model(
|
|||
if flip_aug:
|
||||
max_train_steps = int(math.ceil(float(max_train_steps) / 2))
|
||||
|
||||
log.info(f'max_train_steps = {max_train_steps}')
|
||||
log.info(f"max_train_steps = {max_train_steps}")
|
||||
|
||||
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
|
||||
log.info(f'lr_warmup_steps = {lr_warmup_steps}')
|
||||
log.info(f"lr_warmup_steps = {lr_warmup_steps}")
|
||||
|
||||
run_cmd = "accelerate launch"
|
||||
|
||||
|
||||
run_cmd += run_cmd_advanced_training(
|
||||
num_processes=num_processes,
|
||||
num_machines=num_machines,
|
||||
multi_gpu=multi_gpu,
|
||||
gpu_ids=gpu_ids,
|
||||
num_cpu_threads_per_process=num_cpu_threads_per_process)
|
||||
|
||||
num_cpu_threads_per_process=num_cpu_threads_per_process,
|
||||
)
|
||||
|
||||
if sdxl_checkbox:
|
||||
run_cmd += f' "./sdxl_train.py"'
|
||||
else:
|
||||
run_cmd += f' "./fine_tune.py"'
|
||||
|
||||
in_json = f'{train_dir}/{latent_metadata_filename}' if use_latent_files == 'Yes' else f'{train_dir}/{caption_metadata_filename}'
|
||||
|
||||
in_json = (
|
||||
f"{train_dir}/{latent_metadata_filename}"
|
||||
if use_latent_files == "Yes"
|
||||
else f"{train_dir}/{caption_metadata_filename}"
|
||||
)
|
||||
cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs
|
||||
no_half_vae = sdxl_checkbox and sdxl_no_half_vae
|
||||
|
||||
|
|
@ -570,7 +572,9 @@ def train_model(
|
|||
bucket_reso_steps=bucket_reso_steps,
|
||||
cache_latents=cache_latents,
|
||||
cache_latents_to_disk=cache_latents_to_disk,
|
||||
cache_text_encoder_outputs=cache_text_encoder_outputs if sdxl_checkbox else None,
|
||||
cache_text_encoder_outputs=cache_text_encoder_outputs
|
||||
if sdxl_checkbox
|
||||
else None,
|
||||
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
|
||||
caption_dropout_rate=caption_dropout_rate,
|
||||
caption_extension=caption_extension,
|
||||
|
|
@ -651,7 +655,7 @@ def train_model(
|
|||
|
||||
if print_only_bool:
|
||||
log.warning(
|
||||
'Here is the trainer command as a reference. It will not be executed:\n'
|
||||
"Here is the trainer command as a reference. It will not be executed:\n"
|
||||
)
|
||||
print(run_cmd)
|
||||
|
||||
|
|
@ -659,17 +663,15 @@ def train_model(
|
|||
else:
|
||||
# Saving config file for model
|
||||
current_datetime = datetime.now()
|
||||
formatted_datetime = current_datetime.strftime('%Y%m%d-%H%M%S')
|
||||
file_path = os.path.join(
|
||||
output_dir, f'{output_name}_{formatted_datetime}.json'
|
||||
)
|
||||
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
|
||||
file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json")
|
||||
|
||||
log.info(f'Saving training config to {file_path}...')
|
||||
log.info(f"Saving training config to {file_path}...")
|
||||
|
||||
SaveConfigFile(
|
||||
parameters=parameters,
|
||||
file_path=file_path,
|
||||
exclusion=['file_path', 'save_as', 'headless', 'print_only'],
|
||||
exclusion=["file_path", "save_as", "headless", "print_only"],
|
||||
)
|
||||
|
||||
log.info(run_cmd)
|
||||
|
|
@ -678,18 +680,16 @@ def train_model(
|
|||
executor.execute_command(run_cmd=run_cmd)
|
||||
|
||||
# check if output_dir/last is a folder... therefore it is a diffuser model
|
||||
last_dir = pathlib.Path(f'{output_dir}/{output_name}')
|
||||
last_dir = pathlib.Path(f"{output_dir}/{output_name}")
|
||||
|
||||
if not last_dir.is_dir():
|
||||
# Copy inference model for v2 if required
|
||||
save_inference_file(
|
||||
output_dir, v2, v_parameterization, output_name
|
||||
)
|
||||
save_inference_file(output_dir, v2, v_parameterization, output_name)
|
||||
|
||||
|
||||
def remove_doublequote(file_path):
    """Return ``file_path`` with every double-quote character removed.

    Args:
        file_path: A path string, or ``None``.

    Returns:
        The path without any ``"`` characters, or ``None`` when ``None``
        was passed in (the value is handed back untouched).
    """
    # Identity test instead of ``!= None``: None is a singleton and ``is``
    # cannot be fooled by objects that override comparison operators.
    if file_path is None:
        return file_path

    return file_path.replace('"', "")
|
||||
|
||||
|
|
@ -698,23 +698,23 @@ def finetune_tab(headless=False):
|
|||
dummy_db_true = gr.Label(value=True, visible=False)
|
||||
dummy_db_false = gr.Label(value=False, visible=False)
|
||||
dummy_headless = gr.Label(value=headless, visible=False)
|
||||
with gr.Tab('Training'):
|
||||
gr.Markdown('Train a custom model using kohya finetune python code...')
|
||||
with gr.Tab("Training"):
|
||||
gr.Markdown("Train a custom model using kohya finetune python code...")
|
||||
|
||||
# Setup Configuration Files Gradio
|
||||
config = ConfigurationFile(headless)
|
||||
|
||||
source_model = SourceModel(headless=headless)
|
||||
|
||||
with gr.Tab('Folders'):
|
||||
with gr.Tab("Folders"):
|
||||
with gr.Row():
|
||||
train_dir = gr.Textbox(
|
||||
label='Training config folder',
|
||||
placeholder='folder where the training configuration files will be saved',
|
||||
label="Training config folder",
|
||||
placeholder="folder where the training configuration files will be saved",
|
||||
)
|
||||
train_dir_folder = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_id="open_folder_small",
|
||||
visible=(not headless),
|
||||
)
|
||||
train_dir_folder.click(
|
||||
|
|
@ -724,12 +724,12 @@ def finetune_tab(headless=False):
|
|||
)
|
||||
|
||||
image_folder = gr.Textbox(
|
||||
label='Training Image folder',
|
||||
placeholder='folder where the training images are located',
|
||||
label="Training Image folder",
|
||||
placeholder="folder where the training images are located",
|
||||
)
|
||||
image_folder_input_folder = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_id="open_folder_small",
|
||||
visible=(not headless),
|
||||
)
|
||||
image_folder_input_folder.click(
|
||||
|
|
@ -739,12 +739,12 @@ def finetune_tab(headless=False):
|
|||
)
|
||||
with gr.Row():
|
||||
output_dir = gr.Textbox(
|
||||
label='Model output folder',
|
||||
placeholder='folder where the model will be saved',
|
||||
label="Model output folder",
|
||||
placeholder="folder where the model will be saved",
|
||||
)
|
||||
output_dir_input_folder = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_id="open_folder_small",
|
||||
visible=(not headless),
|
||||
)
|
||||
output_dir_input_folder.click(
|
||||
|
|
@ -754,12 +754,12 @@ def finetune_tab(headless=False):
|
|||
)
|
||||
|
||||
logging_dir = gr.Textbox(
|
||||
label='Logging folder',
|
||||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||
label="Logging folder",
|
||||
placeholder="Optional: enable logging and output TensorBoard log to this folder",
|
||||
)
|
||||
logging_dir_input_folder = gr.Button(
|
||||
folder_symbol,
|
||||
elem_id='open_folder_small',
|
||||
elem_id="open_folder_small",
|
||||
visible=(not headless),
|
||||
)
|
||||
logging_dir_input_folder.click(
|
||||
|
|
@ -769,9 +769,9 @@ def finetune_tab(headless=False):
|
|||
)
|
||||
with gr.Row():
|
||||
output_name = gr.Textbox(
|
||||
label='Model output name',
|
||||
placeholder='Name of the model to output',
|
||||
value='last',
|
||||
label="Model output name",
|
||||
placeholder="Name of the model to output",
|
||||
value="last",
|
||||
interactive=True,
|
||||
)
|
||||
train_dir.change(
|
||||
|
|
@ -789,102 +789,96 @@ def finetune_tab(headless=False):
|
|||
inputs=[output_dir],
|
||||
outputs=[output_dir],
|
||||
)
|
||||
with gr.Tab('Dataset preparation'):
|
||||
with gr.Tab("Dataset preparation"):
|
||||
with gr.Row():
|
||||
max_resolution = gr.Textbox(
|
||||
label='Resolution (width,height)', value='512,512'
|
||||
)
|
||||
min_bucket_reso = gr.Textbox(
|
||||
label='Min bucket resolution', value='256'
|
||||
label="Resolution (width,height)", value="512,512"
|
||||
)
|
||||
min_bucket_reso = gr.Textbox(label="Min bucket resolution", value="256")
|
||||
max_bucket_reso = gr.Textbox(
|
||||
label='Max bucket resolution', value='1024'
|
||||
label="Max bucket resolution", value="1024"
|
||||
)
|
||||
batch_size = gr.Textbox(label='Batch size', value='1')
|
||||
batch_size = gr.Textbox(label="Batch size", value="1")
|
||||
with gr.Row():
|
||||
create_caption = gr.Checkbox(
|
||||
label='Generate caption metadata', value=True
|
||||
label="Generate caption metadata", value=True
|
||||
)
|
||||
create_buckets = gr.Checkbox(
|
||||
label='Generate image buckets metadata', value=True
|
||||
label="Generate image buckets metadata", value=True
|
||||
)
|
||||
use_latent_files = gr.Dropdown(
|
||||
label='Use latent files',
|
||||
label="Use latent files",
|
||||
choices=[
|
||||
'No',
|
||||
'Yes',
|
||||
"No",
|
||||
"Yes",
|
||||
],
|
||||
value='Yes',
|
||||
value="Yes",
|
||||
)
|
||||
with gr.Accordion('Advanced parameters', open=False):
|
||||
with gr.Accordion("Advanced parameters", open=False):
|
||||
with gr.Row():
|
||||
caption_metadata_filename = gr.Textbox(
|
||||
label='Caption metadata filename',
|
||||
value='meta_cap.json',
|
||||
label="Caption metadata filename",
|
||||
value="meta_cap.json",
|
||||
)
|
||||
latent_metadata_filename = gr.Textbox(
|
||||
label='Latent metadata filename', value='meta_lat.json'
|
||||
label="Latent metadata filename", value="meta_lat.json"
|
||||
)
|
||||
with gr.Row():
|
||||
full_path = gr.Checkbox(label='Use full path', value=True)
|
||||
full_path = gr.Checkbox(label="Use full path", value=True)
|
||||
weighted_captions = gr.Checkbox(
|
||||
label='Weighted captions', value=False
|
||||
label="Weighted captions", value=False
|
||||
)
|
||||
with gr.Tab('Parameters'):
|
||||
|
||||
with gr.Tab("Parameters"):
|
||||
|
||||
def list_presets(path):
    """Collect the preset names available under ``path``.

    A preset is any entry whose name ends in ``.json``; the extension is
    stripped from the returned name. Entries found in the optional
    ``user_presets`` sub-folder are returned prefixed with
    ``user_presets`` (joined with the platform path separator).

    Args:
        path: Directory that holds the preset JSON files.

    Returns:
        A list of preset names: the presets located directly in ``path``
        first, followed by the user presets. Each group is sorted by name,
        because ``os.listdir`` returns entries in arbitrary order and the
        list would otherwise differ between platforms and filesystems.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``path`` cannot be
            listed (for example when it does not exist).
    """
    presets = sorted(
        os.path.splitext(entry)[0]
        for entry in os.listdir(path)
        if entry.endswith(".json")
    )

    # The user preset folder is optional; it is only scanned when present.
    user_presets_path = os.path.join(path, "user_presets")
    if os.path.isdir(user_presets_path):
        presets.extend(
            sorted(
                os.path.join("user_presets", os.path.splitext(entry)[0])
                for entry in os.listdir(user_presets_path)
                if entry.endswith(".json")
            )
        )

    return presets
|
||||
|
||||
training_preset = gr.Dropdown(
|
||||
label='Presets',
|
||||
choices=list_presets('./presets/finetune'),
|
||||
elem_id='myDropdown',
|
||||
label="Presets",
|
||||
choices=list_presets("./presets/finetune"),
|
||||
elem_id="myDropdown",
|
||||
)
|
||||
|
||||
with gr.Tab('Basic', elem_id='basic_tab'):
|
||||
|
||||
with gr.Tab("Basic", elem_id="basic_tab"):
|
||||
basic_training = BasicTraining(
|
||||
learning_rate_value='1e-5', finetuning=True, sdxl_checkbox=source_model.sdxl_checkbox,
|
||||
learning_rate_value="1e-5",
|
||||
finetuning=True,
|
||||
sdxl_checkbox=source_model.sdxl_checkbox,
|
||||
)
|
||||
|
||||
# Add SDXL Parameters
|
||||
sdxl_params = SDXLParameters(source_model.sdxl_checkbox)
|
||||
|
||||
with gr.Row():
|
||||
dataset_repeats = gr.Textbox(
|
||||
label='Dataset repeats', value=40
|
||||
)
|
||||
dataset_repeats = gr.Textbox(label="Dataset repeats", value=40)
|
||||
train_text_encoder = gr.Checkbox(
|
||||
label='Train text encoder', value=True
|
||||
label="Train text encoder", value=True
|
||||
)
|
||||
|
||||
with gr.Tab('Advanced', elem_id='advanced_tab'):
|
||||
with gr.Tab("Advanced", elem_id="advanced_tab"):
|
||||
with gr.Row():
|
||||
gradient_accumulation_steps = gr.Number(
|
||||
label='Gradient accumulate steps', value='1'
|
||||
label="Gradient accumulate steps", value="1"
|
||||
)
|
||||
block_lr = gr.Textbox(
|
||||
label='Block LR',
|
||||
placeholder='(Optional)',
|
||||
info='Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3',
|
||||
label="Block LR",
|
||||
placeholder="(Optional)",
|
||||
info="Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3",
|
||||
)
|
||||
advanced_training = AdvancedTraining(
|
||||
headless=headless, finetuning=True
|
||||
)
|
||||
advanced_training = AdvancedTraining(headless=headless, finetuning=True)
|
||||
advanced_training.color_aug.change(
|
||||
color_aug_changed,
|
||||
inputs=[advanced_training.color_aug],
|
||||
|
|
@ -893,15 +887,15 @@ def finetune_tab(headless=False):
|
|||
], # Not applicable to fine_tune.py
|
||||
)
|
||||
|
||||
with gr.Tab('Samples', elem_id='samples_tab'):
|
||||
with gr.Tab("Samples", elem_id="samples_tab"):
|
||||
sample = SampleImages()
|
||||
|
||||
with gr.Row():
|
||||
button_run = gr.Button('Start training', variant='primary')
|
||||
button_run = gr.Button("Start training", variant="primary")
|
||||
|
||||
button_stop_training = gr.Button('Stop training')
|
||||
button_stop_training = gr.Button("Stop training")
|
||||
|
||||
button_print = gr.Button('Print training command')
|
||||
button_print = gr.Button("Print training command")
|
||||
|
||||
# Setup gradio tensorboard buttons
|
||||
(
|
||||
|
|
@ -1020,9 +1014,7 @@ def finetune_tab(headless=False):
|
|||
inputs=[dummy_db_true, dummy_db_false, config.config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
outputs=[config.config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
outputs=[config.config_file_name] + settings_list + [training_preset],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
|
@ -1038,9 +1030,7 @@ def finetune_tab(headless=False):
|
|||
inputs=[dummy_db_false, dummy_db_false, config.config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
outputs=[config.config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
outputs=[config.config_file_name] + settings_list + [training_preset],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
|
@ -1056,9 +1046,7 @@ def finetune_tab(headless=False):
|
|||
inputs=[dummy_db_false, dummy_db_true, config.config_file_name]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
outputs=[gr.Textbox()]
|
||||
+ settings_list
|
||||
+ [training_preset],
|
||||
outputs=[gr.Textbox()] + settings_list + [training_preset],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
|
@ -1090,94 +1078,84 @@ def finetune_tab(headless=False):
|
|||
show_progress=False,
|
||||
)
|
||||
|
||||
with gr.Tab('Guides'):
|
||||
gr.Markdown(
|
||||
'This section provide Various Finetuning guides and information...'
|
||||
)
|
||||
top_level_path = './docs/Finetuning/top_level.md'
|
||||
with gr.Tab("Guides"):
|
||||
gr.Markdown("This section provide Various Finetuning guides and information...")
|
||||
top_level_path = "./docs/Finetuning/top_level.md"
|
||||
if os.path.exists(top_level_path):
|
||||
with open(
|
||||
os.path.join(top_level_path), 'r', encoding='utf8'
|
||||
) as file:
|
||||
guides_top_level = file.read() + '\n'
|
||||
with open(os.path.join(top_level_path), "r", encoding="utf8") as file:
|
||||
guides_top_level = file.read() + "\n"
|
||||
gr.Markdown(guides_top_level)
|
||||
|
||||
|
||||
def UI(**kwargs):
    """Build the standalone Finetune Gradio interface and launch it.

    Keyword Args:
        language: Passed straight to ``add_javascript``.
        headless: Forwarded to the Finetune and Utilities tabs
            (defaults to ``False``).
        listen: Used as the ``server_name`` launch option (always passed,
            even when ``None``).
        username, password: When both are truthy they are passed to
            ``launch`` as the ``auth`` tuple.
        server_port: Passed to ``launch`` only when greater than zero
            (defaults to ``0``).
        inbrowser, share: Passed to ``launch`` only when truthy.
    """
    add_javascript(kwargs.get("language"))

    headless = kwargs.get("headless", False)
    log.info(f"headless: {headless}")

    # Optional custom stylesheet, appended to an initially empty CSS string.
    stylesheet = ""
    css_path = "./style.css"
    if os.path.exists(css_path):
        with open(css_path, "r", encoding="utf8") as css_file:
            log.info("Load CSS...")
            stylesheet += css_file.read() + "\n"

    interface = gr.Blocks(
        css=stylesheet,
        title="Kohya_ss GUI",
        theme=gr.themes.Default(),
    )

    with interface:
        with gr.Tab("Finetune"):
            finetune_tab(headless=headless)
        with gr.Tab("Utilities"):
            utilities_tab(enable_dreambooth_tab=False, headless=headless)

    # Assemble the launch options; everything except server_name is only
    # included when the caller supplied a meaningful value.
    launch_kwargs = {"server_name": kwargs.get("listen")}

    username = kwargs.get("username")
    password = kwargs.get("password")
    if username and password:
        launch_kwargs["auth"] = (username, password)

    server_port = kwargs.get("server_port", 0)
    if server_port > 0:
        launch_kwargs["server_port"] = server_port

    for flag in ("inbrowser", "share"):
        flag_value = kwargs.get(flag, False)
        if flag_value:
            launch_kwargs[flag] = flag_value

    # Show the interface
    interface.launch(**launch_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--listen',
|
||||
"--listen",
|
||||
type=str,
|
||||
default='127.0.0.1',
|
||||
help='IP to listen on for connections to Gradio',
|
||||
default="127.0.0.1",
|
||||
help="IP to listen on for connections to Gradio",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--username', type=str, default='', help='Username for authentication'
|
||||
"--username", type=str, default="", help="Username for authentication"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--password', type=str, default='', help='Password for authentication'
|
||||
"--password", type=str, default="", help="Password for authentication"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--server_port',
|
||||
"--server_port",
|
||||
type=int,
|
||||
default=0,
|
||||
help='Port to run the server listener on',
|
||||
help="Port to run the server listener on",
|
||||
)
|
||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
||||
parser.add_argument("--share", action="store_true", help="Share the gradio UI")
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Is the server headless"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--inbrowser', action='store_true', help='Open in browser'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--share', action='store_true', help='Share the gradio UI'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--headless', action='store_true', help='Is the server headless'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--language', type=str, default=None, help='Set custom language'
|
||||
"--language", type=str, default=None, help="Set custom language"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -678,11 +678,11 @@ def get_str_or_default(kwargs, key, default_value=""):
|
|||
|
||||
def run_cmd_advanced_training(**kwargs):
|
||||
run_cmd = ""
|
||||
|
||||
|
||||
additional_parameters = kwargs.get("additional_parameters")
|
||||
if additional_parameters:
|
||||
run_cmd += f" {additional_parameters}"
|
||||
|
||||
|
||||
# Pass the user-supplied per-block learning rates through to the trainer.
# The value has to be interpolated with braces; parentheses inside the
# f-string would emit the literal text "(block_lr)" on the command line.
block_lr = kwargs.get("block_lr")
if block_lr:
    run_cmd += f' --block_lr="{block_lr}"'
|
||||
|
|
@ -702,7 +702,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
cache_latents_to_disk = kwargs.get("cache_latents_to_disk")
|
||||
if cache_latents_to_disk:
|
||||
run_cmd += " --cache_latents_to_disk"
|
||||
|
||||
|
||||
cache_text_encoder_outputs = kwargs.get("cache_text_encoder_outputs")
|
||||
if cache_text_encoder_outputs:
|
||||
run_cmd += " --cache_text_encoder_outputs"
|
||||
|
|
@ -728,18 +728,18 @@ def run_cmd_advanced_training(**kwargs):
|
|||
color_aug = kwargs.get("color_aug")
|
||||
if color_aug:
|
||||
run_cmd += " --color_aug"
|
||||
|
||||
|
||||
dataset_repeats = kwargs.get("dataset_repeats")
|
||||
if dataset_repeats:
|
||||
run_cmd += f' --dataset_repeats="{dataset_repeats}"'
|
||||
|
||||
|
||||
enable_bucket = kwargs.get("enable_bucket")
|
||||
if enable_bucket:
|
||||
min_bucket_reso = kwargs.get("min_bucket_reso")
|
||||
max_bucket_reso = kwargs.get("max_bucket_reso")
|
||||
if min_bucket_reso and max_bucket_reso:
|
||||
run_cmd += f" --enable_bucket --min_bucket_reso={min_bucket_reso} --max_bucket_reso={max_bucket_reso}"
|
||||
|
||||
|
||||
in_json = kwargs.get("in_json")
|
||||
if in_json:
|
||||
run_cmd += f' --in_json="{in_json}"'
|
||||
|
|
@ -751,7 +751,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
fp8_base = kwargs.get("fp8_base")
|
||||
if fp8_base:
|
||||
run_cmd += " --fp8_base"
|
||||
|
||||
|
||||
full_bf16 = kwargs.get("full_bf16")
|
||||
if full_bf16:
|
||||
run_cmd += " --full_bf16"
|
||||
|
|
@ -759,7 +759,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
full_fp16 = kwargs.get("full_fp16")
|
||||
if full_fp16:
|
||||
run_cmd += " --full_fp16"
|
||||
|
||||
|
||||
gradient_accumulation_steps = kwargs.get("gradient_accumulation_steps")
|
||||
if gradient_accumulation_steps and int(gradient_accumulation_steps) > 1:
|
||||
run_cmd += f" --gradient_accumulation_steps={int(gradient_accumulation_steps)}"
|
||||
|
|
@ -775,19 +775,19 @@ def run_cmd_advanced_training(**kwargs):
|
|||
learning_rate = kwargs.get("learning_rate")
|
||||
if learning_rate:
|
||||
run_cmd += f' --learning_rate="{learning_rate}"'
|
||||
|
||||
|
||||
learning_rate_te = kwargs.get("learning_rate_te")
|
||||
if learning_rate_te:
|
||||
run_cmd += f' --learning_rate_te="{learning_rate_te}"'
|
||||
|
||||
|
||||
learning_rate_te1 = kwargs.get("learning_rate_te1")
|
||||
if learning_rate_te1:
|
||||
run_cmd += f' --learning_rate_te1="{learning_rate_te1}"'
|
||||
|
||||
|
||||
learning_rate_te2 = kwargs.get("learning_rate_te2")
|
||||
if learning_rate_te2:
|
||||
run_cmd += f' --learning_rate_te2="{learning_rate_te2}"'
|
||||
|
||||
|
||||
logging_dir = kwargs.get("logging_dir")
|
||||
if logging_dir:
|
||||
run_cmd += f' --logging_dir="{logging_dir}"'
|
||||
|
|
@ -799,7 +799,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
lr_scheduler_args = kwargs.get("lr_scheduler_args")
|
||||
if lr_scheduler_args and lr_scheduler_args != "":
|
||||
run_cmd += f" --lr_scheduler_args {lr_scheduler_args}"
|
||||
|
||||
|
||||
lr_scheduler_num_cycles = kwargs.get("lr_scheduler_num_cycles")
|
||||
if lr_scheduler_num_cycles and not lr_scheduler_num_cycles == "":
|
||||
run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"'
|
||||
|
|
@ -807,7 +807,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
epoch = kwargs.get("epoch")
|
||||
if epoch:
|
||||
run_cmd += f' --lr_scheduler_num_cycles="{epoch}"'
|
||||
|
||||
|
||||
lr_scheduler_power = kwargs.get("lr_scheduler_power")
|
||||
if lr_scheduler_power and not lr_scheduler_power == "":
|
||||
run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"'
|
||||
|
|
@ -830,7 +830,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
max_grad_norm = kwargs.get("max_grad_norm")
|
||||
if max_grad_norm and max_grad_norm != "":
|
||||
run_cmd += f' --max_grad_norm="{max_grad_norm}"'
|
||||
|
||||
|
||||
max_resolution = kwargs.get("max_resolution")
|
||||
if max_resolution:
|
||||
run_cmd += f' --resolution="{max_resolution}"'
|
||||
|
|
@ -844,7 +844,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
run_cmd += f" --max_token_length={int(max_token_length)}"
|
||||
|
||||
max_train_epochs = kwargs.get("max_train_epochs")
|
||||
if max_train_epochs and not max_train_epochs == '':
|
||||
if max_train_epochs and not max_train_epochs == "":
|
||||
run_cmd += f" --max_train_epochs={max_train_epochs}"
|
||||
|
||||
max_train_steps = kwargs.get("max_train_steps")
|
||||
|
|
@ -870,15 +870,15 @@ def run_cmd_advanced_training(**kwargs):
|
|||
multi_gpu = kwargs.get("multi_gpu")
|
||||
if multi_gpu:
|
||||
run_cmd += " --multi_gpu"
|
||||
|
||||
|
||||
no_half_vae = kwargs.get("no_half_vae")
|
||||
if no_half_vae:
|
||||
run_cmd += " --no_half_vae"
|
||||
|
||||
|
||||
no_token_padding = kwargs.get("no_token_padding")
|
||||
if no_token_padding:
|
||||
run_cmd += " --no_token_padding"
|
||||
|
||||
|
||||
noise_offset_type = kwargs.get("noise_offset_type")
|
||||
if noise_offset_type and noise_offset_type == "Original":
|
||||
noise_offset = kwargs.get("noise_offset")
|
||||
|
|
@ -886,17 +886,23 @@ def run_cmd_advanced_training(**kwargs):
|
|||
run_cmd += f" --noise_offset={float(noise_offset)}"
|
||||
|
||||
adaptive_noise_scale = kwargs.get("adaptive_noise_scale")
|
||||
if adaptive_noise_scale and float(adaptive_noise_scale) != 0 and float(noise_offset) > 0:
|
||||
if (
|
||||
adaptive_noise_scale
|
||||
and float(adaptive_noise_scale) != 0
|
||||
and float(noise_offset) > 0
|
||||
):
|
||||
run_cmd += f" --adaptive_noise_scale={float(adaptive_noise_scale)}"
|
||||
elif noise_offset_type and noise_offset_type == "Multires":
|
||||
multires_noise_iterations = kwargs.get("multires_noise_iterations")
|
||||
if int(multires_noise_iterations) > 0:
|
||||
run_cmd += f' --multires_noise_iterations="{int(multires_noise_iterations)}"'
|
||||
run_cmd += (
|
||||
f' --multires_noise_iterations="{int(multires_noise_iterations)}"'
|
||||
)
|
||||
|
||||
multires_noise_discount = kwargs.get("multires_noise_discount")
|
||||
if multires_noise_discount and float(multires_noise_discount) > 0:
|
||||
run_cmd += f' --multires_noise_discount="{float(multires_noise_discount)}"'
|
||||
|
||||
|
||||
num_machines = kwargs.get("num_machines")
|
||||
if num_machines and int(num_machines) > 1:
|
||||
run_cmd += f" --num_machines={int(num_machines)}"
|
||||
|
|
@ -916,12 +922,12 @@ def run_cmd_advanced_training(**kwargs):
|
|||
optimizer_type = kwargs.get("optimizer")
|
||||
if optimizer_type:
|
||||
run_cmd += f' --optimizer_type="{optimizer_type}"'
|
||||
|
||||
|
||||
output_dir = kwargs.get("output_dir")
|
||||
if output_dir:
|
||||
run_cmd += f' --output_dir="{output_dir}"'
|
||||
|
||||
output_name = kwargs.get("output_name")
|
||||
|
||||
output_name = kwargs.get("output_name")
|
||||
if output_name and not output_name == "":
|
||||
run_cmd += f' --output_name="{output_name}"'
|
||||
|
||||
|
|
@ -932,7 +938,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path")
|
||||
if pretrained_model_name_or_path:
|
||||
run_cmd += f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
|
||||
|
||||
|
||||
prior_loss_weight = kwargs.get("prior_loss_weight")
|
||||
if prior_loss_weight and not float(prior_loss_weight) == 1.0:
|
||||
run_cmd += f" --prior_loss_weight={prior_loss_weight}"
|
||||
|
|
@ -964,7 +970,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
save_last_n_steps_state = kwargs.get("save_last_n_steps_state")
|
||||
if save_last_n_steps_state and int(save_last_n_steps_state) > 0:
|
||||
run_cmd += f' --save_last_n_steps_state="{int(save_last_n_steps_state)}"'
|
||||
|
||||
|
||||
save_model_as = kwargs.get("save_model_as")
|
||||
if save_model_as and not save_model_as == "same as source model":
|
||||
run_cmd += f" --save_model_as={save_model_as}"
|
||||
|
|
@ -988,7 +994,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
shuffle_caption = kwargs.get("shuffle_caption")
|
||||
if shuffle_caption:
|
||||
run_cmd += " --shuffle_caption"
|
||||
|
||||
|
||||
stop_text_encoder_training = kwargs.get("stop_text_encoder_training")
|
||||
if stop_text_encoder_training and stop_text_encoder_training > 0:
|
||||
run_cmd += f' --stop_text_encoder_training="{stop_text_encoder_training}"'
|
||||
|
|
@ -996,11 +1002,11 @@ def run_cmd_advanced_training(**kwargs):
|
|||
train_batch_size = kwargs.get("train_batch_size")
|
||||
if train_batch_size:
|
||||
run_cmd += f' --train_batch_size="{train_batch_size}"'
|
||||
|
||||
|
||||
train_data_dir = kwargs.get("train_data_dir")
|
||||
if train_data_dir:
|
||||
run_cmd += f' --train_data_dir="{train_data_dir}"'
|
||||
|
||||
|
||||
train_text_encoder = kwargs.get("train_text_encoder")
|
||||
if train_text_encoder:
|
||||
run_cmd += " --train_text_encoder"
|
||||
|
|
@ -1008,7 +1014,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
use_wandb = kwargs.get("use_wandb")
|
||||
if use_wandb:
|
||||
run_cmd += " --log_with wandb"
|
||||
|
||||
|
||||
v_parameterization = kwargs.get("v_parameterization")
|
||||
if v_parameterization:
|
||||
run_cmd += " --v_parameterization"
|
||||
|
|
@ -1016,7 +1022,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
v_pred_like_loss = kwargs.get("v_pred_like_loss")
|
||||
if v_pred_like_loss and float(v_pred_like_loss) > 0:
|
||||
run_cmd += f' --v_pred_like_loss="{float(v_pred_like_loss)}"'
|
||||
|
||||
|
||||
v2 = kwargs.get("v2")
|
||||
if v2:
|
||||
run_cmd += " --v2"
|
||||
|
|
@ -1032,7 +1038,7 @@ def run_cmd_advanced_training(**kwargs):
|
|||
wandb_api_key = kwargs.get("wandb_api_key")
|
||||
if wandb_api_key:
|
||||
run_cmd += f' --wandb_api_key="{wandb_api_key}"'
|
||||
|
||||
|
||||
weighted_captions = kwargs.get("weighted_captions")
|
||||
if weighted_captions:
|
||||
run_cmd += " --weighted_captions"
|
||||
|
|
|
|||
11
lora_gui.py
11
lora_gui.py
|
|
@ -732,14 +732,15 @@ def train_model(
|
|||
log.info(f"lr_warmup_steps = {lr_warmup_steps}")
|
||||
|
||||
run_cmd = "accelerate launch"
|
||||
|
||||
|
||||
run_cmd += run_cmd_advanced_training(
|
||||
num_processes=num_processes,
|
||||
num_machines=num_machines,
|
||||
multi_gpu=multi_gpu,
|
||||
gpu_ids=gpu_ids,
|
||||
num_cpu_threads_per_process=num_cpu_threads_per_process)
|
||||
|
||||
num_cpu_threads_per_process=num_cpu_threads_per_process,
|
||||
)
|
||||
|
||||
if sdxl:
|
||||
run_cmd += f' "./sdxl_train_network.py"'
|
||||
else:
|
||||
|
|
@ -1803,7 +1804,9 @@ def lora_tab(
|
|||
placeholder="(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",
|
||||
info="Specify the alpha of each block when expanding LoRA to Conv2d 3x3. Specify 25 numbers. If omitted, the value of conv_alpha is used.",
|
||||
)
|
||||
advanced_training = AdvancedTraining(headless=headless, training_type="lora")
|
||||
advanced_training = AdvancedTraining(
|
||||
headless=headless, training_type="lora"
|
||||
)
|
||||
advanced_training.color_aug.change(
|
||||
color_aug_changed,
|
||||
inputs=[advanced_training.color_aug],
|
||||
|
|
|
|||
|
|
@ -153,19 +153,19 @@ def save_configuration(
|
|||
|
||||
original_file_path = file_path
|
||||
|
||||
save_as_bool = True if save_as.get('label') == 'True' else False
|
||||
save_as_bool = True if save_as.get("label") == "True" else False
|
||||
|
||||
if save_as_bool:
|
||||
log.info('Save as...')
|
||||
log.info("Save as...")
|
||||
file_path = get_saveasfile_path(file_path)
|
||||
else:
|
||||
log.info('Save...')
|
||||
if file_path == None or file_path == '':
|
||||
log.info("Save...")
|
||||
if file_path == None or file_path == "":
|
||||
file_path = get_saveasfile_path(file_path)
|
||||
|
||||
# log.info(file_path)
|
||||
|
||||
if file_path == None or file_path == '':
|
||||
if file_path == None or file_path == "":
|
||||
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
||||
|
||||
# Extract the destination directory from the file path
|
||||
|
|
@ -178,7 +178,7 @@ def save_configuration(
|
|||
SaveConfigFile(
|
||||
parameters=parameters,
|
||||
file_path=file_path,
|
||||
exclusion=['file_path', 'save_as'],
|
||||
exclusion=["file_path", "save_as"],
|
||||
)
|
||||
|
||||
return file_path
|
||||
|
|
@ -282,18 +282,18 @@ def open_configuration(
|
|||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
||||
ask_for_file = True if ask_for_file.get('label') == 'True' else False
|
||||
ask_for_file = True if ask_for_file.get("label") == "True" else False
|
||||
|
||||
original_file_path = file_path
|
||||
|
||||
if ask_for_file:
|
||||
file_path = get_file_path(file_path)
|
||||
|
||||
if not file_path == '' and not file_path == None:
|
||||
if not file_path == "" and not file_path == None:
|
||||
# load variables from JSON file
|
||||
with open(file_path, 'r') as f:
|
||||
with open(file_path, "r") as f:
|
||||
my_data = json.load(f)
|
||||
log.info('Loading config...')
|
||||
log.info("Loading config...")
|
||||
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
|
||||
my_data = update_my_data(my_data)
|
||||
else:
|
||||
|
|
@ -303,7 +303,7 @@ def open_configuration(
|
|||
values = [file_path]
|
||||
for key, value in parameters:
|
||||
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
||||
if not key in ['ask_for_file', 'file_path']:
|
||||
if not key in ["ask_for_file", "file_path"]:
|
||||
values.append(my_data.get(key, value))
|
||||
return tuple(values)
|
||||
|
||||
|
|
@ -406,36 +406,32 @@ def train_model(
|
|||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
||||
print_only_bool = True if print_only.get('label') == 'True' else False
|
||||
log.info(f'Start training TI...')
|
||||
print_only_bool = True if print_only.get("label") == "True" else False
|
||||
log.info(f"Start training TI...")
|
||||
|
||||
headless_bool = True if headless.get('label') == 'True' else False
|
||||
headless_bool = True if headless.get("label") == "True" else False
|
||||
|
||||
if pretrained_model_name_or_path == '':
|
||||
if pretrained_model_name_or_path == "":
|
||||
output_message(
|
||||
msg='Source model information is missing', headless=headless_bool
|
||||
msg="Source model information is missing", headless=headless_bool
|
||||
)
|
||||
return
|
||||
|
||||
if train_data_dir == '':
|
||||
output_message(
|
||||
msg='Image folder path is missing', headless=headless_bool
|
||||
)
|
||||
if train_data_dir == "":
|
||||
output_message(msg="Image folder path is missing", headless=headless_bool)
|
||||
return
|
||||
|
||||
if not os.path.exists(train_data_dir):
|
||||
output_message(
|
||||
msg='Image folder does not exist', headless=headless_bool
|
||||
)
|
||||
output_message(msg="Image folder does not exist", headless=headless_bool)
|
||||
return
|
||||
|
||||
if not verify_image_folder_pattern(train_data_dir):
|
||||
return
|
||||
|
||||
if reg_data_dir != '':
|
||||
if reg_data_dir != "":
|
||||
if not os.path.exists(reg_data_dir):
|
||||
output_message(
|
||||
msg='Regularisation folder does not exist',
|
||||
msg="Regularisation folder does not exist",
|
||||
headless=headless_bool,
|
||||
)
|
||||
return
|
||||
|
|
@ -443,26 +439,22 @@ def train_model(
|
|||
if not verify_image_folder_pattern(reg_data_dir):
|
||||
return
|
||||
|
||||
if output_dir == '':
|
||||
output_message(
|
||||
msg='Output folder path is missing', headless=headless_bool
|
||||
)
|
||||
if output_dir == "":
|
||||
output_message(msg="Output folder path is missing", headless=headless_bool)
|
||||
return
|
||||
|
||||
if token_string == '':
|
||||
output_message(msg='Token string is missing', headless=headless_bool)
|
||||
if token_string == "":
|
||||
output_message(msg="Token string is missing", headless=headless_bool)
|
||||
return
|
||||
|
||||
if init_word == '':
|
||||
output_message(msg='Init word is missing', headless=headless_bool)
|
||||
if init_word == "":
|
||||
output_message(msg="Init word is missing", headless=headless_bool)
|
||||
return
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
if check_if_model_exist(
|
||||
output_name, output_dir, save_model_as, headless_bool
|
||||
):
|
||||
if check_if_model_exist(output_name, output_dir, save_model_as, headless_bool):
|
||||
return
|
||||
|
||||
# if float(noise_offset) > 0 and (
|
||||
|
|
@ -495,7 +487,7 @@ def train_model(
|
|||
# Loop through each subfolder and extract the number of repeats
|
||||
for folder in subfolders:
|
||||
# Extract the number of repeats from the folder name
|
||||
repeats = int(folder.split('_')[0])
|
||||
repeats = int(folder.split("_")[0])
|
||||
|
||||
# Count the number of images in the folder
|
||||
num_images = len(
|
||||
|
|
@ -503,11 +495,9 @@ def train_model(
|
|||
f
|
||||
for f, lower_f in (
|
||||
(file, file.lower())
|
||||
for file in os.listdir(
|
||||
os.path.join(train_data_dir, folder)
|
||||
)
|
||||
for file in os.listdir(os.path.join(train_data_dir, folder))
|
||||
)
|
||||
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
||||
if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp"))
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -516,21 +506,21 @@ def train_model(
|
|||
total_steps += steps
|
||||
|
||||
# Print the result
|
||||
log.info(f'Folder {folder}: {steps} steps')
|
||||
log.info(f"Folder {folder}: {steps} steps")
|
||||
|
||||
# Print the result
|
||||
# log.info(f"{total_steps} total steps")
|
||||
|
||||
if reg_data_dir == '':
|
||||
if reg_data_dir == "":
|
||||
reg_factor = 1
|
||||
else:
|
||||
log.info(
|
||||
'Regularisation images are used... Will double the number of steps required...'
|
||||
"Regularisation images are used... Will double the number of steps required..."
|
||||
)
|
||||
reg_factor = 2
|
||||
|
||||
# calculate max_train_steps
|
||||
if max_train_steps == '' or max_train_steps == '0':
|
||||
if max_train_steps == "" or max_train_steps == "0":
|
||||
max_train_steps = int(
|
||||
math.ceil(
|
||||
float(total_steps)
|
||||
|
|
@ -543,7 +533,7 @@ def train_model(
|
|||
else:
|
||||
max_train_steps = int(max_train_steps)
|
||||
|
||||
log.info(f'max_train_steps = {max_train_steps}')
|
||||
log.info(f"max_train_steps = {max_train_steps}")
|
||||
|
||||
# calculate stop encoder training
|
||||
if stop_text_encoder_training_pct == None:
|
||||
|
|
@ -552,20 +542,21 @@ def train_model(
|
|||
stop_text_encoder_training = math.ceil(
|
||||
float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
|
||||
)
|
||||
log.info(f'stop_text_encoder_training = {stop_text_encoder_training}')
|
||||
log.info(f"stop_text_encoder_training = {stop_text_encoder_training}")
|
||||
|
||||
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
|
||||
log.info(f'lr_warmup_steps = {lr_warmup_steps}')
|
||||
log.info(f"lr_warmup_steps = {lr_warmup_steps}")
|
||||
|
||||
run_cmd = "accelerate launch"
|
||||
|
||||
|
||||
run_cmd += run_cmd_advanced_training(
|
||||
num_processes=num_processes,
|
||||
num_machines=num_machines,
|
||||
multi_gpu=multi_gpu,
|
||||
gpu_ids=gpu_ids,
|
||||
num_cpu_threads_per_process=num_cpu_threads_per_process)
|
||||
|
||||
num_cpu_threads_per_process=num_cpu_threads_per_process,
|
||||
)
|
||||
|
||||
if sdxl:
|
||||
run_cmd += f' "./sdxl_train_textual_inversion.py"'
|
||||
else:
|
||||
|
|
@ -649,13 +640,13 @@ def train_model(
|
|||
)
|
||||
run_cmd += f' --token_string="{token_string}"'
|
||||
run_cmd += f' --init_word="{init_word}"'
|
||||
run_cmd += f' --num_vectors_per_token={num_vectors_per_token}'
|
||||
if not weights == '':
|
||||
run_cmd += f" --num_vectors_per_token={num_vectors_per_token}"
|
||||
if not weights == "":
|
||||
run_cmd += f' --weights="{weights}"'
|
||||
if template == 'object template':
|
||||
run_cmd += f' --use_object_template'
|
||||
elif template == 'style template':
|
||||
run_cmd += f' --use_style_template'
|
||||
if template == "object template":
|
||||
run_cmd += f" --use_object_template"
|
||||
elif template == "style template":
|
||||
run_cmd += f" --use_style_template"
|
||||
|
||||
run_cmd += run_cmd_sample(
|
||||
sample_every_n_steps,
|
||||
|
|
@ -667,7 +658,7 @@ def train_model(
|
|||
|
||||
if print_only_bool:
|
||||
log.warning(
|
||||
'Here is the trainer command as a reference. It will not be executed:\n'
|
||||
"Here is the trainer command as a reference. It will not be executed:\n"
|
||||
)
|
||||
print(run_cmd)
|
||||
|
||||
|
|
@ -675,17 +666,15 @@ def train_model(
|
|||
else:
|
||||
# Saving config file for model
|
||||
current_datetime = datetime.now()
|
||||
formatted_datetime = current_datetime.strftime('%Y%m%d-%H%M%S')
|
||||
file_path = os.path.join(
|
||||
output_dir, f'{output_name}_{formatted_datetime}.json'
|
||||
)
|
||||
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
|
||||
file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json")
|
||||
|
||||
log.info(f'Saving training config to {file_path}...')
|
||||
log.info(f"Saving training config to {file_path}...")
|
||||
|
||||
SaveConfigFile(
|
||||
parameters=parameters,
|
||||
file_path=file_path,
|
||||
exclusion=['file_path', 'save_as', 'headless', 'print_only'],
|
||||
exclusion=["file_path", "save_as", "headless", "print_only"],
|
||||
)
|
||||
|
||||
log.info(run_cmd)
|
||||
|
|
@ -695,13 +684,11 @@ def train_model(
|
|||
executor.execute_command(run_cmd=run_cmd)
|
||||
|
||||
# check if output_dir/last is a folder... therefore it is a diffuser model
|
||||
last_dir = pathlib.Path(f'{output_dir}/{output_name}')
|
||||
last_dir = pathlib.Path(f"{output_dir}/{output_name}")
|
||||
|
||||
if not last_dir.is_dir():
|
||||
# Copy inference model for v2 if required
|
||||
save_inference_file(
|
||||
output_dir, v2, v_parameterization, output_name
|
||||
)
|
||||
save_inference_file(output_dir, v2, v_parameterization, output_name)
|
||||
|
||||
|
||||
def ti_tab(
|
||||
|
|
@ -711,32 +698,32 @@ def ti_tab(
|
|||
dummy_db_false = gr.Label(value=False, visible=False)
|
||||
dummy_headless = gr.Label(value=headless, visible=False)
|
||||
|
||||
with gr.Tab('Training'):
|
||||
gr.Markdown('Train a TI using kohya textual inversion python code...')
|
||||
with gr.Tab("Training"):
|
||||
gr.Markdown("Train a TI using kohya textual inversion python code...")
|
||||
|
||||
# Setup Configuration Files Gradio
|
||||
config = ConfigurationFile(headless)
|
||||
|
||||
source_model = SourceModel(
|
||||
save_model_as_choices=[
|
||||
'ckpt',
|
||||
'safetensors',
|
||||
"ckpt",
|
||||
"safetensors",
|
||||
],
|
||||
headless=headless,
|
||||
)
|
||||
|
||||
with gr.Tab('Folders'):
|
||||
with gr.Tab("Folders"):
|
||||
folders = Folders(headless=headless)
|
||||
with gr.Tab('Parameters'):
|
||||
with gr.Tab('Basic', elem_id='basic_tab'):
|
||||
with gr.Tab("Parameters"):
|
||||
with gr.Tab("Basic", elem_id="basic_tab"):
|
||||
with gr.Row():
|
||||
weights = gr.Textbox(
|
||||
label='Resume TI training',
|
||||
placeholder='(Optional) Path to existing TI embeding file to keep training',
|
||||
label="Resume TI training",
|
||||
placeholder="(Optional) Path to existing TI embeding file to keep training",
|
||||
)
|
||||
weights_file_input = gr.Button(
|
||||
'📂',
|
||||
elem_id='open_folder_small',
|
||||
"",
|
||||
elem_id="open_folder_small",
|
||||
visible=(not headless),
|
||||
)
|
||||
weights_file_input.click(
|
||||
|
|
@ -746,37 +733,37 @@ def ti_tab(
|
|||
)
|
||||
with gr.Row():
|
||||
token_string = gr.Textbox(
|
||||
label='Token string',
|
||||
placeholder='eg: cat',
|
||||
label="Token string",
|
||||
placeholder="eg: cat",
|
||||
)
|
||||
init_word = gr.Textbox(
|
||||
label='Init word',
|
||||
value='*',
|
||||
label="Init word",
|
||||
value="*",
|
||||
)
|
||||
num_vectors_per_token = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=75,
|
||||
value=1,
|
||||
step=1,
|
||||
label='Vectors',
|
||||
label="Vectors",
|
||||
)
|
||||
# max_train_steps = gr.Textbox(
|
||||
# label='Max train steps',
|
||||
# placeholder='(Optional) Maximum number of steps',
|
||||
# )
|
||||
template = gr.Dropdown(
|
||||
label='Template',
|
||||
label="Template",
|
||||
choices=[
|
||||
'caption',
|
||||
'object template',
|
||||
'style template',
|
||||
"caption",
|
||||
"object template",
|
||||
"style template",
|
||||
],
|
||||
value='caption',
|
||||
value="caption",
|
||||
)
|
||||
basic_training = BasicTraining(
|
||||
learning_rate_value='1e-5',
|
||||
lr_scheduler_value='cosine',
|
||||
lr_warmup_value='10',
|
||||
learning_rate_value="1e-5",
|
||||
lr_scheduler_value="cosine",
|
||||
lr_warmup_value="10",
|
||||
sdxl_checkbox=source_model.sdxl_checkbox,
|
||||
)
|
||||
|
||||
|
|
@ -786,7 +773,7 @@ def ti_tab(
|
|||
show_sdxl_cache_text_encoder_outputs=False,
|
||||
)
|
||||
|
||||
with gr.Tab('Advanced', elem_id='advanced_tab'):
|
||||
with gr.Tab("Advanced", elem_id="advanced_tab"):
|
||||
advanced_training = AdvancedTraining(headless=headless)
|
||||
advanced_training.color_aug.change(
|
||||
color_aug_changed,
|
||||
|
|
@ -794,12 +781,12 @@ def ti_tab(
|
|||
outputs=[basic_training.cache_latents],
|
||||
)
|
||||
|
||||
with gr.Tab('Samples', elem_id='samples_tab'):
|
||||
with gr.Tab("Samples", elem_id="samples_tab"):
|
||||
sample = SampleImages()
|
||||
|
||||
with gr.Tab('Dataset Preparation'):
|
||||
with gr.Tab("Dataset Preparation"):
|
||||
gr.Markdown(
|
||||
'This section provide Dreambooth tools to help setup your dataset...'
|
||||
"This section provide Dreambooth tools to help setup your dataset..."
|
||||
)
|
||||
gradio_dreambooth_folder_creation_tab(
|
||||
train_data_dir_input=folders.train_data_dir,
|
||||
|
|
@ -811,11 +798,11 @@ def ti_tab(
|
|||
gradio_dataset_balancing_tab(headless=headless)
|
||||
|
||||
with gr.Row():
|
||||
button_run = gr.Button('Start training', variant='primary')
|
||||
button_run = gr.Button("Start training", variant="primary")
|
||||
|
||||
button_stop_training = gr.Button('Stop training')
|
||||
button_stop_training = gr.Button("Stop training")
|
||||
|
||||
button_print = gr.Button('Print training command')
|
||||
button_print = gr.Button("Print training command")
|
||||
|
||||
# Setup gradio tensorboard buttons
|
||||
(
|
||||
|
|
@ -978,30 +965,28 @@ def ti_tab(
|
|||
|
||||
|
||||
def UI(**kwargs):
|
||||
add_javascript(kwargs.get('language'))
|
||||
css = ''
|
||||
add_javascript(kwargs.get("language"))
|
||||
css = ""
|
||||
|
||||
headless = kwargs.get('headless', False)
|
||||
log.info(f'headless: {headless}')
|
||||
headless = kwargs.get("headless", False)
|
||||
log.info(f"headless: {headless}")
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
log.info('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
if os.path.exists("./style.css"):
|
||||
with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
|
||||
log.info("Load CSS...")
|
||||
css += file.read() + "\n"
|
||||
|
||||
interface = gr.Blocks(
|
||||
css=css, title='Kohya_ss GUI', theme=gr.themes.Default()
|
||||
)
|
||||
interface = gr.Blocks(css=css, title="Kohya_ss GUI", theme=gr.themes.Default())
|
||||
|
||||
with interface:
|
||||
with gr.Tab('Dreambooth TI'):
|
||||
with gr.Tab("Dreambooth TI"):
|
||||
(
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = ti_tab(headless=headless)
|
||||
with gr.Tab('Utilities'):
|
||||
with gr.Tab("Utilities"):
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
reg_data_dir_input=reg_data_dir_input,
|
||||
|
|
@ -1013,57 +998,53 @@ def UI(**kwargs):
|
|||
|
||||
# Show the interface
|
||||
launch_kwargs = {}
|
||||
username = kwargs.get('username')
|
||||
password = kwargs.get('password')
|
||||
server_port = kwargs.get('server_port', 0)
|
||||
inbrowser = kwargs.get('inbrowser', False)
|
||||
share = kwargs.get('share', False)
|
||||
server_name = kwargs.get('listen')
|
||||
username = kwargs.get("username")
|
||||
password = kwargs.get("password")
|
||||
server_port = kwargs.get("server_port", 0)
|
||||
inbrowser = kwargs.get("inbrowser", False)
|
||||
share = kwargs.get("share", False)
|
||||
server_name = kwargs.get("listen")
|
||||
|
||||
launch_kwargs['server_name'] = server_name
|
||||
launch_kwargs["server_name"] = server_name
|
||||
if username and password:
|
||||
launch_kwargs['auth'] = (username, password)
|
||||
launch_kwargs["auth"] = (username, password)
|
||||
if server_port > 0:
|
||||
launch_kwargs['server_port'] = server_port
|
||||
launch_kwargs["server_port"] = server_port
|
||||
if inbrowser:
|
||||
launch_kwargs['inbrowser'] = inbrowser
|
||||
launch_kwargs["inbrowser"] = inbrowser
|
||||
if share:
|
||||
launch_kwargs['share'] = share
|
||||
launch_kwargs["share"] = share
|
||||
interface.launch(**launch_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--listen',
|
||||
"--listen",
|
||||
type=str,
|
||||
default='127.0.0.1',
|
||||
help='IP to listen on for connections to Gradio',
|
||||
default="127.0.0.1",
|
||||
help="IP to listen on for connections to Gradio",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--username', type=str, default='', help='Username for authentication'
|
||||
"--username", type=str, default="", help="Username for authentication"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--password', type=str, default='', help='Password for authentication'
|
||||
"--password", type=str, default="", help="Password for authentication"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--server_port',
|
||||
"--server_port",
|
||||
type=int,
|
||||
default=0,
|
||||
help='Port to run the server listener on',
|
||||
help="Port to run the server listener on",
|
||||
)
|
||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
||||
parser.add_argument("--share", action="store_true", help="Share the gradio UI")
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Is the server headless"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--inbrowser', action='store_true', help='Open in browser'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--share', action='store_true', help='Share the gradio UI'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--headless', action='store_true', help='Is the server headless'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--language', type=str, default=None, help='Set custom language'
|
||||
"--language", type=str, default=None, help="Set custom language"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
Loading…
Reference in New Issue