Format code

pull/1966/head^2
bmaltais 2024-02-03 00:04:00 -05:00
parent 08ce96f33b
commit 0b217a4cf8
5 changed files with 348 additions and 379 deletions

View File

@ -551,14 +551,15 @@ def train_model(
# run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"'
run_cmd = "accelerate launch"
run_cmd += run_cmd_advanced_training(
num_processes=num_processes,
num_machines=num_machines,
multi_gpu=multi_gpu,
gpu_ids=gpu_ids,
num_cpu_threads_per_process=num_cpu_threads_per_process)
num_cpu_threads_per_process=num_cpu_threads_per_process,
)
if sdxl:
run_cmd += f' "./sdxl_train.py"'
else:

View File

@ -43,12 +43,12 @@ executor = CommandExecutor()
# from easygui import msgbox
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
folder_symbol = "\U0001f4c2" # 📂
refresh_symbol = "\U0001f504" # 🔄
save_style_symbol = "\U0001f4be" # 💾
document_symbol = "\U0001F4C4" # 📄
PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
PYTHON = "python3" if os.name == "posix" else "./venv/Scripts/python.exe"
def save_configuration(
@ -100,7 +100,8 @@ def save_configuration(
save_state,
resume,
gradient_checkpointing,
gradient_accumulation_steps,block_lr,
gradient_accumulation_steps,
block_lr,
mem_eff_attn,
shuffle_caption,
output_name,
@ -153,19 +154,19 @@ def save_configuration(
original_file_path = file_path
save_as_bool = True if save_as.get('label') == 'True' else False
save_as_bool = True if save_as.get("label") == "True" else False
if save_as_bool:
log.info('Save as...')
log.info("Save as...")
file_path = get_saveasfile_path(file_path)
else:
log.info('Save...')
if file_path == None or file_path == '':
log.info("Save...")
if file_path == None or file_path == "":
file_path = get_saveasfile_path(file_path)
# log.info(file_path)
if file_path == None or file_path == '':
if file_path == None or file_path == "":
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
# Extract the destination directory from the file path
@ -178,7 +179,7 @@ def save_configuration(
SaveConfigFile(
parameters=parameters,
file_path=file_path,
exclusion=['file_path', 'save_as'],
exclusion=["file_path", "save_as"],
)
return file_path
@ -234,7 +235,8 @@ def open_configuration(
save_state,
resume,
gradient_checkpointing,
gradient_accumulation_steps,block_lr,
gradient_accumulation_steps,
block_lr,
mem_eff_attn,
shuffle_caption,
output_name,
@ -286,33 +288,31 @@ def open_configuration(
# Get list of function parameters and values
parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
apply_preset = True if apply_preset.get('label') == 'True' else False
ask_for_file = True if ask_for_file.get("label") == "True" else False
apply_preset = True if apply_preset.get("label") == "True" else False
# Check if we are "applying" a preset or a config
if apply_preset:
log.info(f'Applying preset {training_preset}...')
file_path = f'./presets/finetune/{training_preset}.json'
log.info(f"Applying preset {training_preset}...")
file_path = f"./presets/finetune/{training_preset}.json"
else:
# If not applying a preset, set the `training_preset` field to an empty string
# Find the index of the `training_preset` parameter using the `index()` method
training_preset_index = parameters.index(
('training_preset', training_preset)
)
training_preset_index = parameters.index(("training_preset", training_preset))
# Update the value of `training_preset` by directly assigning an empty string value
parameters[training_preset_index] = ('training_preset', '')
parameters[training_preset_index] = ("training_preset", "")
original_file_path = file_path
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None:
if not file_path == "" and not file_path == None:
# load variables from JSON file
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
my_data = json.load(f)
log.info('Loading config...')
log.info("Loading config...")
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
my_data = update_my_data(my_data)
else:
@ -323,7 +323,7 @@ def open_configuration(
for key, value in parameters:
json_value = my_data.get(key)
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['ask_for_file', 'apply_preset', 'file_path']:
if not key in ["ask_for_file", "apply_preset", "file_path"]:
values.append(json_value if json_value is not None else value)
return tuple(values)
@ -377,7 +377,8 @@ def train_model(
save_state,
resume,
gradient_checkpointing,
gradient_accumulation_steps,block_lr,
gradient_accumulation_steps,
block_lr,
mem_eff_attn,
shuffle_caption,
output_name,
@ -428,14 +429,12 @@ def train_model(
# Get list of function parameters and values
parameters = list(locals().items())
print_only_bool = True if print_only.get('label') == 'True' else False
log.info(f'Start Finetuning...')
print_only_bool = True if print_only.get("label") == "True" else False
log.info(f"Start Finetuning...")
headless_bool = True if headless.get('label') == 'True' else False
headless_bool = True if headless.get("label") == "True" else False
if check_if_model_exist(
output_name, output_dir, save_model_as, headless_bool
):
if check_if_model_exist(output_name, output_dir, save_model_as, headless_bool):
return
# if float(noise_offset) > 0 and (
@ -461,52 +460,50 @@ def train_model(
if not os.path.exists(train_dir):
os.mkdir(train_dir)
run_cmd = f'{PYTHON} finetune/merge_captions_to_metadata.py'
if caption_extension == '':
run_cmd = f"{PYTHON} finetune/merge_captions_to_metadata.py"
if caption_extension == "":
run_cmd += f' --caption_extension=".caption"'
else:
run_cmd += f' --caption_extension={caption_extension}'
run_cmd += f" --caption_extension={caption_extension}"
run_cmd += f' "{image_folder}"'
run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
if full_path:
run_cmd += f' --full_path'
run_cmd += f" --full_path"
log.info(run_cmd)
if not print_only_bool:
# Run the command
if os.name == 'posix':
if os.name == "posix":
os.system(run_cmd)
else:
subprocess.run(run_cmd)
# create images buckets
if generate_image_buckets:
run_cmd = f'{PYTHON} finetune/prepare_buckets_latents.py'
run_cmd = f"{PYTHON} finetune/prepare_buckets_latents.py"
run_cmd += f' "{image_folder}"'
run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
run_cmd += f' "{train_dir}/{latent_metadata_filename}"'
run_cmd += f' "{pretrained_model_name_or_path}"'
run_cmd += f' --batch_size={batch_size}'
run_cmd += f' --max_resolution={max_resolution}'
run_cmd += f' --min_bucket_reso={min_bucket_reso}'
run_cmd += f' --max_bucket_reso={max_bucket_reso}'
run_cmd += f' --mixed_precision={mixed_precision}'
run_cmd += f" --batch_size={batch_size}"
run_cmd += f" --max_resolution={max_resolution}"
run_cmd += f" --min_bucket_reso={min_bucket_reso}"
run_cmd += f" --max_bucket_reso={max_bucket_reso}"
run_cmd += f" --mixed_precision={mixed_precision}"
# if flip_aug:
# run_cmd += f' --flip_aug'
if full_path:
run_cmd += f' --full_path'
run_cmd += f" --full_path"
if sdxl_checkbox and sdxl_no_half_vae:
log.info(
'Using mixed_precision = no because no half vae is selected...'
)
log.info("Using mixed_precision = no because no half vae is selected...")
run_cmd += f' --mixed_precision="no"'
log.info(run_cmd)
if not print_only_bool:
# Run the command
if os.name == 'posix':
if os.name == "posix":
os.system(run_cmd)
else:
subprocess.run(run_cmd)
@ -517,13 +514,13 @@ def train_model(
for f, lower_f in (
(file, file.lower()) for file in os.listdir(image_folder)
)
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp"))
]
)
log.info(f'image_num = {image_num}')
log.info(f"image_num = {image_num}")
repeats = int(image_num) * int(dataset_repeats)
log.info(f'repeats = {str(repeats)}')
log.info(f"repeats = {str(repeats)}")
# calculate max_train_steps
max_train_steps = int(
@ -539,26 +536,31 @@ def train_model(
if flip_aug:
max_train_steps = int(math.ceil(float(max_train_steps) / 2))
log.info(f'max_train_steps = {max_train_steps}')
log.info(f"max_train_steps = {max_train_steps}")
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
log.info(f'lr_warmup_steps = {lr_warmup_steps}')
log.info(f"lr_warmup_steps = {lr_warmup_steps}")
run_cmd = "accelerate launch"
run_cmd += run_cmd_advanced_training(
num_processes=num_processes,
num_machines=num_machines,
multi_gpu=multi_gpu,
gpu_ids=gpu_ids,
num_cpu_threads_per_process=num_cpu_threads_per_process)
num_cpu_threads_per_process=num_cpu_threads_per_process,
)
if sdxl_checkbox:
run_cmd += f' "./sdxl_train.py"'
else:
run_cmd += f' "./fine_tune.py"'
in_json = f'{train_dir}/{latent_metadata_filename}' if use_latent_files == 'Yes' else f'{train_dir}/{caption_metadata_filename}'
in_json = (
f"{train_dir}/{latent_metadata_filename}"
if use_latent_files == "Yes"
else f"{train_dir}/{caption_metadata_filename}"
)
cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs
no_half_vae = sdxl_checkbox and sdxl_no_half_vae
@ -570,7 +572,9 @@ def train_model(
bucket_reso_steps=bucket_reso_steps,
cache_latents=cache_latents,
cache_latents_to_disk=cache_latents_to_disk,
cache_text_encoder_outputs=cache_text_encoder_outputs if sdxl_checkbox else None,
cache_text_encoder_outputs=cache_text_encoder_outputs
if sdxl_checkbox
else None,
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
caption_dropout_rate=caption_dropout_rate,
caption_extension=caption_extension,
@ -651,7 +655,7 @@ def train_model(
if print_only_bool:
log.warning(
'Here is the trainer command as a reference. It will not be executed:\n'
"Here is the trainer command as a reference. It will not be executed:\n"
)
print(run_cmd)
@ -659,17 +663,15 @@ def train_model(
else:
# Saving config file for model
current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime('%Y%m%d-%H%M%S')
file_path = os.path.join(
output_dir, f'{output_name}_{formatted_datetime}.json'
)
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json")
log.info(f'Saving training config to {file_path}...')
log.info(f"Saving training config to {file_path}...")
SaveConfigFile(
parameters=parameters,
file_path=file_path,
exclusion=['file_path', 'save_as', 'headless', 'print_only'],
exclusion=["file_path", "save_as", "headless", "print_only"],
)
log.info(run_cmd)
@ -678,18 +680,16 @@ def train_model(
executor.execute_command(run_cmd=run_cmd)
# check if output_dir/last is a folder... therefore it is a diffuser model
last_dir = pathlib.Path(f'{output_dir}/{output_name}')
last_dir = pathlib.Path(f"{output_dir}/{output_name}")
if not last_dir.is_dir():
# Copy inference model for v2 if required
save_inference_file(
output_dir, v2, v_parameterization, output_name
)
save_inference_file(output_dir, v2, v_parameterization, output_name)
def remove_doublequote(file_path):
if file_path != None:
file_path = file_path.replace('"', '')
file_path = file_path.replace('"', "")
return file_path
@ -698,23 +698,23 @@ def finetune_tab(headless=False):
dummy_db_true = gr.Label(value=True, visible=False)
dummy_db_false = gr.Label(value=False, visible=False)
dummy_headless = gr.Label(value=headless, visible=False)
with gr.Tab('Training'):
gr.Markdown('Train a custom model using kohya finetune python code...')
with gr.Tab("Training"):
gr.Markdown("Train a custom model using kohya finetune python code...")
# Setup Configuration Files Gradio
config = ConfigurationFile(headless)
source_model = SourceModel(headless=headless)
with gr.Tab('Folders'):
with gr.Tab("Folders"):
with gr.Row():
train_dir = gr.Textbox(
label='Training config folder',
placeholder='folder where the training configuration files will be saved',
label="Training config folder",
placeholder="folder where the training configuration files will be saved",
)
train_dir_folder = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_id="open_folder_small",
visible=(not headless),
)
train_dir_folder.click(
@ -724,12 +724,12 @@ def finetune_tab(headless=False):
)
image_folder = gr.Textbox(
label='Training Image folder',
placeholder='folder where the training images are located',
label="Training Image folder",
placeholder="folder where the training images are located",
)
image_folder_input_folder = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_id="open_folder_small",
visible=(not headless),
)
image_folder_input_folder.click(
@ -739,12 +739,12 @@ def finetune_tab(headless=False):
)
with gr.Row():
output_dir = gr.Textbox(
label='Model output folder',
placeholder='folder where the model will be saved',
label="Model output folder",
placeholder="folder where the model will be saved",
)
output_dir_input_folder = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_id="open_folder_small",
visible=(not headless),
)
output_dir_input_folder.click(
@ -754,12 +754,12 @@ def finetune_tab(headless=False):
)
logging_dir = gr.Textbox(
label='Logging folder',
placeholder='Optional: enable logging and output TensorBoard log to this folder',
label="Logging folder",
placeholder="Optional: enable logging and output TensorBoard log to this folder",
)
logging_dir_input_folder = gr.Button(
folder_symbol,
elem_id='open_folder_small',
elem_id="open_folder_small",
visible=(not headless),
)
logging_dir_input_folder.click(
@ -769,9 +769,9 @@ def finetune_tab(headless=False):
)
with gr.Row():
output_name = gr.Textbox(
label='Model output name',
placeholder='Name of the model to output',
value='last',
label="Model output name",
placeholder="Name of the model to output",
value="last",
interactive=True,
)
train_dir.change(
@ -789,102 +789,96 @@ def finetune_tab(headless=False):
inputs=[output_dir],
outputs=[output_dir],
)
with gr.Tab('Dataset preparation'):
with gr.Tab("Dataset preparation"):
with gr.Row():
max_resolution = gr.Textbox(
label='Resolution (width,height)', value='512,512'
)
min_bucket_reso = gr.Textbox(
label='Min bucket resolution', value='256'
label="Resolution (width,height)", value="512,512"
)
min_bucket_reso = gr.Textbox(label="Min bucket resolution", value="256")
max_bucket_reso = gr.Textbox(
label='Max bucket resolution', value='1024'
label="Max bucket resolution", value="1024"
)
batch_size = gr.Textbox(label='Batch size', value='1')
batch_size = gr.Textbox(label="Batch size", value="1")
with gr.Row():
create_caption = gr.Checkbox(
label='Generate caption metadata', value=True
label="Generate caption metadata", value=True
)
create_buckets = gr.Checkbox(
label='Generate image buckets metadata', value=True
label="Generate image buckets metadata", value=True
)
use_latent_files = gr.Dropdown(
label='Use latent files',
label="Use latent files",
choices=[
'No',
'Yes',
"No",
"Yes",
],
value='Yes',
value="Yes",
)
with gr.Accordion('Advanced parameters', open=False):
with gr.Accordion("Advanced parameters", open=False):
with gr.Row():
caption_metadata_filename = gr.Textbox(
label='Caption metadata filename',
value='meta_cap.json',
label="Caption metadata filename",
value="meta_cap.json",
)
latent_metadata_filename = gr.Textbox(
label='Latent metadata filename', value='meta_lat.json'
label="Latent metadata filename", value="meta_lat.json"
)
with gr.Row():
full_path = gr.Checkbox(label='Use full path', value=True)
full_path = gr.Checkbox(label="Use full path", value=True)
weighted_captions = gr.Checkbox(
label='Weighted captions', value=False
label="Weighted captions", value=False
)
with gr.Tab('Parameters'):
with gr.Tab("Parameters"):
def list_presets(path):
json_files = []
for file in os.listdir(path):
if file.endswith('.json'):
if file.endswith(".json"):
json_files.append(os.path.splitext(file)[0])
user_presets_path = os.path.join(path, 'user_presets')
user_presets_path = os.path.join(path, "user_presets")
if os.path.isdir(user_presets_path):
for file in os.listdir(user_presets_path):
if file.endswith('.json'):
if file.endswith(".json"):
preset_name = os.path.splitext(file)[0]
json_files.append(
os.path.join('user_presets', preset_name)
)
json_files.append(os.path.join("user_presets", preset_name))
return json_files
training_preset = gr.Dropdown(
label='Presets',
choices=list_presets('./presets/finetune'),
elem_id='myDropdown',
label="Presets",
choices=list_presets("./presets/finetune"),
elem_id="myDropdown",
)
with gr.Tab('Basic', elem_id='basic_tab'):
with gr.Tab("Basic", elem_id="basic_tab"):
basic_training = BasicTraining(
learning_rate_value='1e-5', finetuning=True, sdxl_checkbox=source_model.sdxl_checkbox,
learning_rate_value="1e-5",
finetuning=True,
sdxl_checkbox=source_model.sdxl_checkbox,
)
# Add SDXL Parameters
sdxl_params = SDXLParameters(source_model.sdxl_checkbox)
with gr.Row():
dataset_repeats = gr.Textbox(
label='Dataset repeats', value=40
)
dataset_repeats = gr.Textbox(label="Dataset repeats", value=40)
train_text_encoder = gr.Checkbox(
label='Train text encoder', value=True
label="Train text encoder", value=True
)
with gr.Tab('Advanced', elem_id='advanced_tab'):
with gr.Tab("Advanced", elem_id="advanced_tab"):
with gr.Row():
gradient_accumulation_steps = gr.Number(
label='Gradient accumulate steps', value='1'
label="Gradient accumulate steps", value="1"
)
block_lr = gr.Textbox(
label='Block LR',
placeholder='(Optional)',
info='Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3',
label="Block LR",
placeholder="(Optional)",
info="Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3",
)
advanced_training = AdvancedTraining(
headless=headless, finetuning=True
)
advanced_training = AdvancedTraining(headless=headless, finetuning=True)
advanced_training.color_aug.change(
color_aug_changed,
inputs=[advanced_training.color_aug],
@ -893,15 +887,15 @@ def finetune_tab(headless=False):
], # Not applicable to fine_tune.py
)
with gr.Tab('Samples', elem_id='samples_tab'):
with gr.Tab("Samples", elem_id="samples_tab"):
sample = SampleImages()
with gr.Row():
button_run = gr.Button('Start training', variant='primary')
button_run = gr.Button("Start training", variant="primary")
button_stop_training = gr.Button('Stop training')
button_stop_training = gr.Button("Stop training")
button_print = gr.Button('Print training command')
button_print = gr.Button("Print training command")
# Setup gradio tensorboard buttons
(
@ -1020,9 +1014,7 @@ def finetune_tab(headless=False):
inputs=[dummy_db_true, dummy_db_false, config.config_file_name]
+ settings_list
+ [training_preset],
outputs=[config.config_file_name]
+ settings_list
+ [training_preset],
outputs=[config.config_file_name] + settings_list + [training_preset],
show_progress=False,
)
@ -1038,9 +1030,7 @@ def finetune_tab(headless=False):
inputs=[dummy_db_false, dummy_db_false, config.config_file_name]
+ settings_list
+ [training_preset],
outputs=[config.config_file_name]
+ settings_list
+ [training_preset],
outputs=[config.config_file_name] + settings_list + [training_preset],
show_progress=False,
)
@ -1056,9 +1046,7 @@ def finetune_tab(headless=False):
inputs=[dummy_db_false, dummy_db_true, config.config_file_name]
+ settings_list
+ [training_preset],
outputs=[gr.Textbox()]
+ settings_list
+ [training_preset],
outputs=[gr.Textbox()] + settings_list + [training_preset],
show_progress=False,
)
@ -1090,94 +1078,84 @@ def finetune_tab(headless=False):
show_progress=False,
)
with gr.Tab('Guides'):
gr.Markdown(
'This section provide Various Finetuning guides and information...'
)
top_level_path = './docs/Finetuning/top_level.md'
with gr.Tab("Guides"):
gr.Markdown("This section provide Various Finetuning guides and information...")
top_level_path = "./docs/Finetuning/top_level.md"
if os.path.exists(top_level_path):
with open(
os.path.join(top_level_path), 'r', encoding='utf8'
) as file:
guides_top_level = file.read() + '\n'
with open(os.path.join(top_level_path), "r", encoding="utf8") as file:
guides_top_level = file.read() + "\n"
gr.Markdown(guides_top_level)
def UI(**kwargs):
add_javascript(kwargs.get('language'))
css = ''
add_javascript(kwargs.get("language"))
css = ""
headless = kwargs.get('headless', False)
log.info(f'headless: {headless}')
headless = kwargs.get("headless", False)
log.info(f"headless: {headless}")
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
log.info('Load CSS...')
css += file.read() + '\n'
if os.path.exists("./style.css"):
with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
log.info("Load CSS...")
css += file.read() + "\n"
interface = gr.Blocks(
css=css, title='Kohya_ss GUI', theme=gr.themes.Default()
)
interface = gr.Blocks(css=css, title="Kohya_ss GUI", theme=gr.themes.Default())
with interface:
with gr.Tab('Finetune'):
with gr.Tab("Finetune"):
finetune_tab(headless=headless)
with gr.Tab('Utilities'):
with gr.Tab("Utilities"):
utilities_tab(enable_dreambooth_tab=False, headless=headless)
# Show the interface
launch_kwargs = {}
username = kwargs.get('username')
password = kwargs.get('password')
server_port = kwargs.get('server_port', 0)
inbrowser = kwargs.get('inbrowser', False)
share = kwargs.get('share', False)
server_name = kwargs.get('listen')
username = kwargs.get("username")
password = kwargs.get("password")
server_port = kwargs.get("server_port", 0)
inbrowser = kwargs.get("inbrowser", False)
share = kwargs.get("share", False)
server_name = kwargs.get("listen")
launch_kwargs['server_name'] = server_name
launch_kwargs["server_name"] = server_name
if username and password:
launch_kwargs['auth'] = (username, password)
launch_kwargs["auth"] = (username, password)
if server_port > 0:
launch_kwargs['server_port'] = server_port
launch_kwargs["server_port"] = server_port
if inbrowser:
launch_kwargs['inbrowser'] = inbrowser
launch_kwargs["inbrowser"] = inbrowser
if share:
launch_kwargs['share'] = share
launch_kwargs["share"] = share
interface.launch(**launch_kwargs)
if __name__ == '__main__':
if __name__ == "__main__":
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
parser.add_argument(
'--listen',
"--listen",
type=str,
default='127.0.0.1',
help='IP to listen on for connections to Gradio',
default="127.0.0.1",
help="IP to listen on for connections to Gradio",
)
parser.add_argument(
'--username', type=str, default='', help='Username for authentication'
"--username", type=str, default="", help="Username for authentication"
)
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
"--password", type=str, default="", help="Password for authentication"
)
parser.add_argument(
'--server_port',
"--server_port",
type=int,
default=0,
help='Port to run the server listener on',
help="Port to run the server listener on",
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
parser.add_argument("--share", action="store_true", help="Share the gradio UI")
parser.add_argument(
"--headless", action="store_true", help="Is the server headless"
)
parser.add_argument(
'--inbrowser', action='store_true', help='Open in browser'
)
parser.add_argument(
'--share', action='store_true', help='Share the gradio UI'
)
parser.add_argument(
'--headless', action='store_true', help='Is the server headless'
)
parser.add_argument(
'--language', type=str, default=None, help='Set custom language'
"--language", type=str, default=None, help="Set custom language"
)
args = parser.parse_args()

View File

@ -678,11 +678,11 @@ def get_str_or_default(kwargs, key, default_value=""):
def run_cmd_advanced_training(**kwargs):
run_cmd = ""
additional_parameters = kwargs.get("additional_parameters")
if additional_parameters:
run_cmd += f" {additional_parameters}"
block_lr = kwargs.get("block_lr")
if block_lr:
run_cmd += f' --block_lr="(block_lr)"'
@ -702,7 +702,7 @@ def run_cmd_advanced_training(**kwargs):
cache_latents_to_disk = kwargs.get("cache_latents_to_disk")
if cache_latents_to_disk:
run_cmd += " --cache_latents_to_disk"
cache_text_encoder_outputs = kwargs.get("cache_text_encoder_outputs")
if cache_text_encoder_outputs:
run_cmd += " --cache_text_encoder_outputs"
@ -728,18 +728,18 @@ def run_cmd_advanced_training(**kwargs):
color_aug = kwargs.get("color_aug")
if color_aug:
run_cmd += " --color_aug"
dataset_repeats = kwargs.get("dataset_repeats")
if dataset_repeats:
run_cmd += f' --dataset_repeats="{dataset_repeats}"'
enable_bucket = kwargs.get("enable_bucket")
if enable_bucket:
min_bucket_reso = kwargs.get("min_bucket_reso")
max_bucket_reso = kwargs.get("max_bucket_reso")
if min_bucket_reso and max_bucket_reso:
run_cmd += f" --enable_bucket --min_bucket_reso={min_bucket_reso} --max_bucket_reso={max_bucket_reso}"
in_json = kwargs.get("in_json")
if in_json:
run_cmd += f' --in_json="{in_json}"'
@ -751,7 +751,7 @@ def run_cmd_advanced_training(**kwargs):
fp8_base = kwargs.get("fp8_base")
if fp8_base:
run_cmd += " --fp8_base"
full_bf16 = kwargs.get("full_bf16")
if full_bf16:
run_cmd += " --full_bf16"
@ -759,7 +759,7 @@ def run_cmd_advanced_training(**kwargs):
full_fp16 = kwargs.get("full_fp16")
if full_fp16:
run_cmd += " --full_fp16"
gradient_accumulation_steps = kwargs.get("gradient_accumulation_steps")
if gradient_accumulation_steps and int(gradient_accumulation_steps) > 1:
run_cmd += f" --gradient_accumulation_steps={int(gradient_accumulation_steps)}"
@ -775,19 +775,19 @@ def run_cmd_advanced_training(**kwargs):
learning_rate = kwargs.get("learning_rate")
if learning_rate:
run_cmd += f' --learning_rate="{learning_rate}"'
learning_rate_te = kwargs.get("learning_rate_te")
if learning_rate_te:
run_cmd += f' --learning_rate_te="{learning_rate_te}"'
learning_rate_te1 = kwargs.get("learning_rate_te1")
if learning_rate_te1:
run_cmd += f' --learning_rate_te1="{learning_rate_te1}"'
learning_rate_te2 = kwargs.get("learning_rate_te2")
if learning_rate_te2:
run_cmd += f' --learning_rate_te2="{learning_rate_te2}"'
logging_dir = kwargs.get("logging_dir")
if logging_dir:
run_cmd += f' --logging_dir="{logging_dir}"'
@ -799,7 +799,7 @@ def run_cmd_advanced_training(**kwargs):
lr_scheduler_args = kwargs.get("lr_scheduler_args")
if lr_scheduler_args and lr_scheduler_args != "":
run_cmd += f" --lr_scheduler_args {lr_scheduler_args}"
lr_scheduler_num_cycles = kwargs.get("lr_scheduler_num_cycles")
if lr_scheduler_num_cycles and not lr_scheduler_num_cycles == "":
run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"'
@ -807,7 +807,7 @@ def run_cmd_advanced_training(**kwargs):
epoch = kwargs.get("epoch")
if epoch:
run_cmd += f' --lr_scheduler_num_cycles="{epoch}"'
lr_scheduler_power = kwargs.get("lr_scheduler_power")
if lr_scheduler_power and not lr_scheduler_power == "":
run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"'
@ -830,7 +830,7 @@ def run_cmd_advanced_training(**kwargs):
max_grad_norm = kwargs.get("max_grad_norm")
if max_grad_norm and max_grad_norm != "":
run_cmd += f' --max_grad_norm="{max_grad_norm}"'
max_resolution = kwargs.get("max_resolution")
if max_resolution:
run_cmd += f' --resolution="{max_resolution}"'
@ -844,7 +844,7 @@ def run_cmd_advanced_training(**kwargs):
run_cmd += f" --max_token_length={int(max_token_length)}"
max_train_epochs = kwargs.get("max_train_epochs")
if max_train_epochs and not max_train_epochs == '':
if max_train_epochs and not max_train_epochs == "":
run_cmd += f" --max_train_epochs={max_train_epochs}"
max_train_steps = kwargs.get("max_train_steps")
@ -870,15 +870,15 @@ def run_cmd_advanced_training(**kwargs):
multi_gpu = kwargs.get("multi_gpu")
if multi_gpu:
run_cmd += " --multi_gpu"
no_half_vae = kwargs.get("no_half_vae")
if no_half_vae:
run_cmd += " --no_half_vae"
no_token_padding = kwargs.get("no_token_padding")
if no_token_padding:
run_cmd += " --no_token_padding"
noise_offset_type = kwargs.get("noise_offset_type")
if noise_offset_type and noise_offset_type == "Original":
noise_offset = kwargs.get("noise_offset")
@ -886,17 +886,23 @@ def run_cmd_advanced_training(**kwargs):
run_cmd += f" --noise_offset={float(noise_offset)}"
adaptive_noise_scale = kwargs.get("adaptive_noise_scale")
if adaptive_noise_scale and float(adaptive_noise_scale) != 0 and float(noise_offset) > 0:
if (
adaptive_noise_scale
and float(adaptive_noise_scale) != 0
and float(noise_offset) > 0
):
run_cmd += f" --adaptive_noise_scale={float(adaptive_noise_scale)}"
elif noise_offset_type and noise_offset_type == "Multires":
multires_noise_iterations = kwargs.get("multires_noise_iterations")
if int(multires_noise_iterations) > 0:
run_cmd += f' --multires_noise_iterations="{int(multires_noise_iterations)}"'
run_cmd += (
f' --multires_noise_iterations="{int(multires_noise_iterations)}"'
)
multires_noise_discount = kwargs.get("multires_noise_discount")
if multires_noise_discount and float(multires_noise_discount) > 0:
run_cmd += f' --multires_noise_discount="{float(multires_noise_discount)}"'
num_machines = kwargs.get("num_machines")
if num_machines and int(num_machines) > 1:
run_cmd += f" --num_machines={int(num_machines)}"
@ -916,12 +922,12 @@ def run_cmd_advanced_training(**kwargs):
optimizer_type = kwargs.get("optimizer")
if optimizer_type:
run_cmd += f' --optimizer_type="{optimizer_type}"'
output_dir = kwargs.get("output_dir")
if output_dir:
run_cmd += f' --output_dir="{output_dir}"'
output_name = kwargs.get("output_name")
output_name = kwargs.get("output_name")
if output_name and not output_name == "":
run_cmd += f' --output_name="{output_name}"'
@ -932,7 +938,7 @@ def run_cmd_advanced_training(**kwargs):
pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path")
if pretrained_model_name_or_path:
run_cmd += f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
prior_loss_weight = kwargs.get("prior_loss_weight")
if prior_loss_weight and not float(prior_loss_weight) == 1.0:
run_cmd += f" --prior_loss_weight={prior_loss_weight}"
@ -964,7 +970,7 @@ def run_cmd_advanced_training(**kwargs):
save_last_n_steps_state = kwargs.get("save_last_n_steps_state")
if save_last_n_steps_state and int(save_last_n_steps_state) > 0:
run_cmd += f' --save_last_n_steps_state="{int(save_last_n_steps_state)}"'
save_model_as = kwargs.get("save_model_as")
if save_model_as and not save_model_as == "same as source model":
run_cmd += f" --save_model_as={save_model_as}"
@ -988,7 +994,7 @@ def run_cmd_advanced_training(**kwargs):
shuffle_caption = kwargs.get("shuffle_caption")
if shuffle_caption:
run_cmd += " --shuffle_caption"
stop_text_encoder_training = kwargs.get("stop_text_encoder_training")
if stop_text_encoder_training and stop_text_encoder_training > 0:
run_cmd += f' --stop_text_encoder_training="{stop_text_encoder_training}"'
@ -996,11 +1002,11 @@ def run_cmd_advanced_training(**kwargs):
train_batch_size = kwargs.get("train_batch_size")
if train_batch_size:
run_cmd += f' --train_batch_size="{train_batch_size}"'
train_data_dir = kwargs.get("train_data_dir")
if train_data_dir:
run_cmd += f' --train_data_dir="{train_data_dir}"'
train_text_encoder = kwargs.get("train_text_encoder")
if train_text_encoder:
run_cmd += " --train_text_encoder"
@ -1008,7 +1014,7 @@ def run_cmd_advanced_training(**kwargs):
use_wandb = kwargs.get("use_wandb")
if use_wandb:
run_cmd += " --log_with wandb"
v_parameterization = kwargs.get("v_parameterization")
if v_parameterization:
run_cmd += " --v_parameterization"
@ -1016,7 +1022,7 @@ def run_cmd_advanced_training(**kwargs):
v_pred_like_loss = kwargs.get("v_pred_like_loss")
if v_pred_like_loss and float(v_pred_like_loss) > 0:
run_cmd += f' --v_pred_like_loss="{float(v_pred_like_loss)}"'
v2 = kwargs.get("v2")
if v2:
run_cmd += " --v2"
@ -1032,7 +1038,7 @@ def run_cmd_advanced_training(**kwargs):
wandb_api_key = kwargs.get("wandb_api_key")
if wandb_api_key:
run_cmd += f' --wandb_api_key="{wandb_api_key}"'
weighted_captions = kwargs.get("weighted_captions")
if weighted_captions:
run_cmd += " --weighted_captions"

View File

@ -732,14 +732,15 @@ def train_model(
log.info(f"lr_warmup_steps = {lr_warmup_steps}")
run_cmd = "accelerate launch"
run_cmd += run_cmd_advanced_training(
num_processes=num_processes,
num_machines=num_machines,
multi_gpu=multi_gpu,
gpu_ids=gpu_ids,
num_cpu_threads_per_process=num_cpu_threads_per_process)
num_cpu_threads_per_process=num_cpu_threads_per_process,
)
if sdxl:
run_cmd += f' "./sdxl_train_network.py"'
else:
@ -1803,7 +1804,9 @@ def lora_tab(
placeholder="(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",
info="Specify the alpha of each block when expanding LoRA to Conv2d 3x3. Specify 25 numbers. If omitted, the value of conv_alpha is used.",
)
advanced_training = AdvancedTraining(headless=headless, training_type="lora")
advanced_training = AdvancedTraining(
headless=headless, training_type="lora"
)
advanced_training.color_aug.change(
color_aug_changed,
inputs=[advanced_training.color_aug],

View File

@ -153,19 +153,19 @@ def save_configuration(
original_file_path = file_path
save_as_bool = True if save_as.get('label') == 'True' else False
save_as_bool = True if save_as.get("label") == "True" else False
if save_as_bool:
log.info('Save as...')
log.info("Save as...")
file_path = get_saveasfile_path(file_path)
else:
log.info('Save...')
if file_path == None or file_path == '':
log.info("Save...")
if file_path == None or file_path == "":
file_path = get_saveasfile_path(file_path)
# log.info(file_path)
if file_path == None or file_path == '':
if file_path == None or file_path == "":
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
# Extract the destination directory from the file path
@ -178,7 +178,7 @@ def save_configuration(
SaveConfigFile(
parameters=parameters,
file_path=file_path,
exclusion=['file_path', 'save_as'],
exclusion=["file_path", "save_as"],
)
return file_path
@ -282,18 +282,18 @@ def open_configuration(
# Get list of function parameters and values
parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
ask_for_file = True if ask_for_file.get("label") == "True" else False
original_file_path = file_path
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None:
if not file_path == "" and not file_path == None:
# load variables from JSON file
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
my_data = json.load(f)
log.info('Loading config...')
log.info("Loading config...")
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
my_data = update_my_data(my_data)
else:
@ -303,7 +303,7 @@ def open_configuration(
values = [file_path]
for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['ask_for_file', 'file_path']:
if not key in ["ask_for_file", "file_path"]:
values.append(my_data.get(key, value))
return tuple(values)
@ -406,36 +406,32 @@ def train_model(
# Get list of function parameters and values
parameters = list(locals().items())
print_only_bool = True if print_only.get('label') == 'True' else False
log.info(f'Start training TI...')
print_only_bool = True if print_only.get("label") == "True" else False
log.info(f"Start training TI...")
headless_bool = True if headless.get('label') == 'True' else False
headless_bool = True if headless.get("label") == "True" else False
if pretrained_model_name_or_path == '':
if pretrained_model_name_or_path == "":
output_message(
msg='Source model information is missing', headless=headless_bool
msg="Source model information is missing", headless=headless_bool
)
return
if train_data_dir == '':
output_message(
msg='Image folder path is missing', headless=headless_bool
)
if train_data_dir == "":
output_message(msg="Image folder path is missing", headless=headless_bool)
return
if not os.path.exists(train_data_dir):
output_message(
msg='Image folder does not exist', headless=headless_bool
)
output_message(msg="Image folder does not exist", headless=headless_bool)
return
if not verify_image_folder_pattern(train_data_dir):
return
if reg_data_dir != '':
if reg_data_dir != "":
if not os.path.exists(reg_data_dir):
output_message(
msg='Regularisation folder does not exist',
msg="Regularisation folder does not exist",
headless=headless_bool,
)
return
@ -443,26 +439,22 @@ def train_model(
if not verify_image_folder_pattern(reg_data_dir):
return
if output_dir == '':
output_message(
msg='Output folder path is missing', headless=headless_bool
)
if output_dir == "":
output_message(msg="Output folder path is missing", headless=headless_bool)
return
if token_string == '':
output_message(msg='Token string is missing', headless=headless_bool)
if token_string == "":
output_message(msg="Token string is missing", headless=headless_bool)
return
if init_word == '':
output_message(msg='Init word is missing', headless=headless_bool)
if init_word == "":
output_message(msg="Init word is missing", headless=headless_bool)
return
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if check_if_model_exist(
output_name, output_dir, save_model_as, headless_bool
):
if check_if_model_exist(output_name, output_dir, save_model_as, headless_bool):
return
# if float(noise_offset) > 0 and (
@ -495,7 +487,7 @@ def train_model(
# Loop through each subfolder and extract the number of repeats
for folder in subfolders:
# Extract the number of repeats from the folder name
repeats = int(folder.split('_')[0])
repeats = int(folder.split("_")[0])
# Count the number of images in the folder
num_images = len(
@ -503,11 +495,9 @@ def train_model(
f
for f, lower_f in (
(file, file.lower())
for file in os.listdir(
os.path.join(train_data_dir, folder)
)
for file in os.listdir(os.path.join(train_data_dir, folder))
)
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp"))
]
)
@ -516,21 +506,21 @@ def train_model(
total_steps += steps
# Print the result
log.info(f'Folder {folder}: {steps} steps')
log.info(f"Folder {folder}: {steps} steps")
# Print the result
# log.info(f"{total_steps} total steps")
if reg_data_dir == '':
if reg_data_dir == "":
reg_factor = 1
else:
log.info(
'Regularisation images are used... Will double the number of steps required...'
"Regularisation images are used... Will double the number of steps required..."
)
reg_factor = 2
# calculate max_train_steps
if max_train_steps == '' or max_train_steps == '0':
if max_train_steps == "" or max_train_steps == "0":
max_train_steps = int(
math.ceil(
float(total_steps)
@ -543,7 +533,7 @@ def train_model(
else:
max_train_steps = int(max_train_steps)
log.info(f'max_train_steps = {max_train_steps}')
log.info(f"max_train_steps = {max_train_steps}")
# calculate stop encoder training
if stop_text_encoder_training_pct == None:
@ -552,20 +542,21 @@ def train_model(
stop_text_encoder_training = math.ceil(
float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
)
log.info(f'stop_text_encoder_training = {stop_text_encoder_training}')
log.info(f"stop_text_encoder_training = {stop_text_encoder_training}")
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
log.info(f'lr_warmup_steps = {lr_warmup_steps}')
log.info(f"lr_warmup_steps = {lr_warmup_steps}")
run_cmd = "accelerate launch"
run_cmd += run_cmd_advanced_training(
num_processes=num_processes,
num_machines=num_machines,
multi_gpu=multi_gpu,
gpu_ids=gpu_ids,
num_cpu_threads_per_process=num_cpu_threads_per_process)
num_cpu_threads_per_process=num_cpu_threads_per_process,
)
if sdxl:
run_cmd += f' "./sdxl_train_textual_inversion.py"'
else:
@ -649,13 +640,13 @@ def train_model(
)
run_cmd += f' --token_string="{token_string}"'
run_cmd += f' --init_word="{init_word}"'
run_cmd += f' --num_vectors_per_token={num_vectors_per_token}'
if not weights == '':
run_cmd += f" --num_vectors_per_token={num_vectors_per_token}"
if not weights == "":
run_cmd += f' --weights="{weights}"'
if template == 'object template':
run_cmd += f' --use_object_template'
elif template == 'style template':
run_cmd += f' --use_style_template'
if template == "object template":
run_cmd += f" --use_object_template"
elif template == "style template":
run_cmd += f" --use_style_template"
run_cmd += run_cmd_sample(
sample_every_n_steps,
@ -667,7 +658,7 @@ def train_model(
if print_only_bool:
log.warning(
'Here is the trainer command as a reference. It will not be executed:\n'
"Here is the trainer command as a reference. It will not be executed:\n"
)
print(run_cmd)
@ -675,17 +666,15 @@ def train_model(
else:
# Saving config file for model
current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime('%Y%m%d-%H%M%S')
file_path = os.path.join(
output_dir, f'{output_name}_{formatted_datetime}.json'
)
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json")
log.info(f'Saving training config to {file_path}...')
log.info(f"Saving training config to {file_path}...")
SaveConfigFile(
parameters=parameters,
file_path=file_path,
exclusion=['file_path', 'save_as', 'headless', 'print_only'],
exclusion=["file_path", "save_as", "headless", "print_only"],
)
log.info(run_cmd)
@ -695,13 +684,11 @@ def train_model(
executor.execute_command(run_cmd=run_cmd)
# check if output_dir/last is a folder... therefore it is a diffuser model
last_dir = pathlib.Path(f'{output_dir}/{output_name}')
last_dir = pathlib.Path(f"{output_dir}/{output_name}")
if not last_dir.is_dir():
# Copy inference model for v2 if required
save_inference_file(
output_dir, v2, v_parameterization, output_name
)
save_inference_file(output_dir, v2, v_parameterization, output_name)
def ti_tab(
@ -711,32 +698,32 @@ def ti_tab(
dummy_db_false = gr.Label(value=False, visible=False)
dummy_headless = gr.Label(value=headless, visible=False)
with gr.Tab('Training'):
gr.Markdown('Train a TI using kohya textual inversion python code...')
with gr.Tab("Training"):
gr.Markdown("Train a TI using kohya textual inversion python code...")
# Setup Configuration Files Gradio
config = ConfigurationFile(headless)
source_model = SourceModel(
save_model_as_choices=[
'ckpt',
'safetensors',
"ckpt",
"safetensors",
],
headless=headless,
)
with gr.Tab('Folders'):
with gr.Tab("Folders"):
folders = Folders(headless=headless)
with gr.Tab('Parameters'):
with gr.Tab('Basic', elem_id='basic_tab'):
with gr.Tab("Parameters"):
with gr.Tab("Basic", elem_id="basic_tab"):
with gr.Row():
weights = gr.Textbox(
label='Resume TI training',
placeholder='(Optional) Path to existing TI embeding file to keep training',
label="Resume TI training",
placeholder="(Optional) Path to existing TI embeding file to keep training",
)
weights_file_input = gr.Button(
'📂',
elem_id='open_folder_small',
"",
elem_id="open_folder_small",
visible=(not headless),
)
weights_file_input.click(
@ -746,37 +733,37 @@ def ti_tab(
)
with gr.Row():
token_string = gr.Textbox(
label='Token string',
placeholder='eg: cat',
label="Token string",
placeholder="eg: cat",
)
init_word = gr.Textbox(
label='Init word',
value='*',
label="Init word",
value="*",
)
num_vectors_per_token = gr.Slider(
minimum=1,
maximum=75,
value=1,
step=1,
label='Vectors',
label="Vectors",
)
# max_train_steps = gr.Textbox(
# label='Max train steps',
# placeholder='(Optional) Maximum number of steps',
# )
template = gr.Dropdown(
label='Template',
label="Template",
choices=[
'caption',
'object template',
'style template',
"caption",
"object template",
"style template",
],
value='caption',
value="caption",
)
basic_training = BasicTraining(
learning_rate_value='1e-5',
lr_scheduler_value='cosine',
lr_warmup_value='10',
learning_rate_value="1e-5",
lr_scheduler_value="cosine",
lr_warmup_value="10",
sdxl_checkbox=source_model.sdxl_checkbox,
)
@ -786,7 +773,7 @@ def ti_tab(
show_sdxl_cache_text_encoder_outputs=False,
)
with gr.Tab('Advanced', elem_id='advanced_tab'):
with gr.Tab("Advanced", elem_id="advanced_tab"):
advanced_training = AdvancedTraining(headless=headless)
advanced_training.color_aug.change(
color_aug_changed,
@ -794,12 +781,12 @@ def ti_tab(
outputs=[basic_training.cache_latents],
)
with gr.Tab('Samples', elem_id='samples_tab'):
with gr.Tab("Samples", elem_id="samples_tab"):
sample = SampleImages()
with gr.Tab('Dataset Preparation'):
with gr.Tab("Dataset Preparation"):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
"This section provide Dreambooth tools to help setup your dataset..."
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=folders.train_data_dir,
@ -811,11 +798,11 @@ def ti_tab(
gradio_dataset_balancing_tab(headless=headless)
with gr.Row():
button_run = gr.Button('Start training', variant='primary')
button_run = gr.Button("Start training", variant="primary")
button_stop_training = gr.Button('Stop training')
button_stop_training = gr.Button("Stop training")
button_print = gr.Button('Print training command')
button_print = gr.Button("Print training command")
# Setup gradio tensorboard buttons
(
@ -978,30 +965,28 @@ def ti_tab(
def UI(**kwargs):
add_javascript(kwargs.get('language'))
css = ''
add_javascript(kwargs.get("language"))
css = ""
headless = kwargs.get('headless', False)
log.info(f'headless: {headless}')
headless = kwargs.get("headless", False)
log.info(f"headless: {headless}")
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
log.info('Load CSS...')
css += file.read() + '\n'
if os.path.exists("./style.css"):
with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
log.info("Load CSS...")
css += file.read() + "\n"
interface = gr.Blocks(
css=css, title='Kohya_ss GUI', theme=gr.themes.Default()
)
interface = gr.Blocks(css=css, title="Kohya_ss GUI", theme=gr.themes.Default())
with interface:
with gr.Tab('Dreambooth TI'):
with gr.Tab("Dreambooth TI"):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = ti_tab(headless=headless)
with gr.Tab('Utilities'):
with gr.Tab("Utilities"):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
@ -1013,57 +998,53 @@ def UI(**kwargs):
# Show the interface
launch_kwargs = {}
username = kwargs.get('username')
password = kwargs.get('password')
server_port = kwargs.get('server_port', 0)
inbrowser = kwargs.get('inbrowser', False)
share = kwargs.get('share', False)
server_name = kwargs.get('listen')
username = kwargs.get("username")
password = kwargs.get("password")
server_port = kwargs.get("server_port", 0)
inbrowser = kwargs.get("inbrowser", False)
share = kwargs.get("share", False)
server_name = kwargs.get("listen")
launch_kwargs['server_name'] = server_name
launch_kwargs["server_name"] = server_name
if username and password:
launch_kwargs['auth'] = (username, password)
launch_kwargs["auth"] = (username, password)
if server_port > 0:
launch_kwargs['server_port'] = server_port
launch_kwargs["server_port"] = server_port
if inbrowser:
launch_kwargs['inbrowser'] = inbrowser
launch_kwargs["inbrowser"] = inbrowser
if share:
launch_kwargs['share'] = share
launch_kwargs["share"] = share
interface.launch(**launch_kwargs)
if __name__ == '__main__':
if __name__ == "__main__":
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
parser.add_argument(
'--listen',
"--listen",
type=str,
default='127.0.0.1',
help='IP to listen on for connections to Gradio',
default="127.0.0.1",
help="IP to listen on for connections to Gradio",
)
parser.add_argument(
'--username', type=str, default='', help='Username for authentication'
"--username", type=str, default="", help="Username for authentication"
)
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
"--password", type=str, default="", help="Password for authentication"
)
parser.add_argument(
'--server_port',
"--server_port",
type=int,
default=0,
help='Port to run the server listener on',
help="Port to run the server listener on",
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
parser.add_argument("--share", action="store_true", help="Share the gradio UI")
parser.add_argument(
"--headless", action="store_true", help="Is the server headless"
)
parser.add_argument(
'--inbrowser', action='store_true', help='Open in browser'
)
parser.add_argument(
'--share', action='store_true', help='Share the gradio UI'
)
parser.add_argument(
'--headless', action='store_true', help='Is the server headless'
)
parser.add_argument(
'--language', type=str, default=None, help='Set custom language'
"--language", type=str, default=None, help="Set custom language"
)
args = parser.parse_args()