From 49b6cd3adfe25ce711abcb01f89bc9a60215ce4e Mon Sep 17 00:00:00 2001 From: Tang Jie Date: Thu, 15 Jun 2023 10:43:42 +0000 Subject: [PATCH] fix: dreambooth config save. --- dreambooth_on_cloud/train.py | 15 ++++++++------- javascript/dreambooth_on_cloud.js | 21 +++++++++++++++++++-- scripts/main.py | 5 +++-- 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/dreambooth_on_cloud/train.py b/dreambooth_on_cloud/train.py index fe24b5d1..1961fd99 100644 --- a/dreambooth_on_cloud/train.py +++ b/dreambooth_on_cloud/train.py @@ -6,6 +6,7 @@ import gradio as gr import os import sys import logging +import shutil from utils import upload_file_to_s3_by_presign_url from utils import get_variable_from_json @@ -80,20 +81,18 @@ def get_cloud_db_model_name_list(): model_name_list = [model['model_name'] for model in model_list] return model_name_list -def hack_db_config(db_config, db_config_file_path, model_name, data_list, class_data_list): +def hack_db_config(db_config, db_config_file_path, model_name, data_list, class_data_list, local_model_name): for k in db_config: if k == "model_dir": db_config[k] = re.sub(".+/(models/dreambooth/).+$", f"\\1{model_name}", db_config[k]) elif k == "pretrained_model_name_or_path": db_config[k] = re.sub(".+/(models/dreambooth/).+(working)$", f"\\1{model_name}/\\2", db_config[k]) elif k == "model_name": - db_config[k] = db_config[k].replace("dummy_local_model", model_name) + db_config[k] = db_config[k].replace(local_model_name, model_name) elif k == "concepts_list": for concept, data, class_data in zip(db_config[k], data_list, class_data_list): concept["instance_data_dir"] = data concept["class_data_dir"] = class_data - # else: - # db_config[k] = db_config[k].replace("dummy_local_model", model_name) with open(db_config_file_path, "w") as db_config_file_w: json.dump(db_config, db_config_file_w) @@ -194,8 +193,9 @@ def wrap_save_config(model_name): setattr(dreambooth_shared, 'dreambooth_models_path', origin_model_path) def cloud_train( + local_model_name: str, train_model_name: str, - local_model_name=False, + db_use_txt2img=False, training_instance_type: str= "" ): integral_check = False @@ -210,7 +210,7 @@ def cloud_train( try: # Get data path and class data path. print(f"Start cloud training {train_model_name}") - db_config_path = os.path.join("models/dreambooth/dummy_local_model/db_config.json") + db_config_path = os.path.join(f"models/dreambooth/{local_model_name}/db_config.json") with open(db_config_path) as db_config_file: config = json.load(db_config_file) local_data_path_list = [] @@ -224,7 +224,8 @@ def cloud_train( class_data_path_list.append(concept["class_data_dir"].replace("s3://", "").replace("/", "-").strip("-")) model_list = get_cloud_db_models() new_db_config_path = os.path.join(base_model_folder, f"{train_model_name}/db_config_cloud.json") - hack_db_config(config, new_db_config_path, train_model_name, data_path_list, class_data_path_list) + print(f"hack config from local_model_name to new_db_config_path") + hack_db_config(config, new_db_config_path, train_model_name, data_path_list, class_data_path_list, local_model_name) if config["save_lora_for_extra_net"] == True: model_type = "Lora" else: diff --git a/javascript/dreambooth_on_cloud.js b/javascript/dreambooth_on_cloud.js index a2711e69..ea2341e6 100644 --- a/javascript/dreambooth_on_cloud.js +++ b/javascript/dreambooth_on_cloud.js @@ -4,9 +4,26 @@ function sleep(ms) { return new Promise(resolve => setTimeout(resolve, ms)); } +function getElementByXpath(path) { + console.log(path) + return document.evaluate(path, document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue; +} + +function set_dropdown_value(xpath, value) { + let dropdown = getElementByXpath(xpath) + console.log("Trying to click the dropdown " + dropdown) + dropdown.click() + let target_dropdown = getElementByXpath(`//ul[contains(.,'${value}')]`) + console.log("Trying to set the value of dropdown" + target_dropdown) + target_dropdown.click() +} + async function db_start_sagemaker_train() { console.log("Sagemaker training"); console.log(arguments); + // var xpath = "//*[@id='cloud_db_model_name']/label/div/div[1]/div" + // var value = "dummy_local_model" + // set_dropdown_value(xpath, value) // pop up confirmation for sagemaker training let do_save = confirm("Confirm to start Sagemaker training? This will take a while."); @@ -14,10 +31,10 @@ async function db_start_sagemaker_train() { return; } save_config(); - await sleep(5000); + await sleep(1000); // let sagemaker_train = gradioApp().getElementById("db_sagemaker_train"); // sagemaker_train.style.display = "block"; - return filterArgs(3, arguments) + return filterArgs(4, arguments) } function check_create_model_params() { diff --git a/scripts/main.py b/scripts/main.py index 76128f02..acd0dca1 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -117,6 +117,7 @@ def on_after_component_callback(component, **_kwargs): fn=async_cloud_train, _js="db_start_sagemaker_train", inputs=[ + db_model_name, cloud_db_model_name, db_use_txt2img, cloud_train_instance_type @@ -637,7 +638,7 @@ def ui_tabs_callback(): with gr.Row(): cloud_train_instance_type = gr.Dropdown( label="SageMaker Train Instance Type", - choices=['ml.g4dn.2xlarge'], + choices=['ml.g4dn.2xlarge', 'ml.g5.2xlarge'], elem_id="cloud_train_instance_type", info='select SageMaker Train Instance Type' ) @@ -697,7 +698,7 @@ def ui_tabs_callback(): cloud_db_new_model_name = gr.Textbox(label="Name", placeholder="Model names can only contain alphanumeric and -") with gr.Row(): cloud_db_create_from_hub = gr.Checkbox( - label="Create From Hub", value=False + label="Create From Hub", value=False, visible=False ) cloud_db_512_model = gr.Checkbox(label="512x Model", value=True) with gr.Column(visible=False) as hub_row: