fix: dreambooth config save.

pull/43/head
Tang Jie 2023-06-15 10:43:42 +00:00
parent d0c095dee5
commit 49b6cd3adf
3 changed files with 30 additions and 11 deletions

View File

@ -6,6 +6,7 @@ import gradio as gr
import os
import sys
import logging
import shutil
from utils import upload_file_to_s3_by_presign_url
from utils import get_variable_from_json
@ -80,20 +81,18 @@ def get_cloud_db_model_name_list():
model_name_list = [model['model_name'] for model in model_list]
return model_name_list
def hack_db_config(db_config, db_config_file_path, model_name, data_list, class_data_list):
def hack_db_config(db_config, db_config_file_path, model_name, data_list, class_data_list, local_model_name):
for k in db_config:
if k == "model_dir":
db_config[k] = re.sub(".+/(models/dreambooth/).+$", f"\\1{model_name}", db_config[k])
elif k == "pretrained_model_name_or_path":
db_config[k] = re.sub(".+/(models/dreambooth/).+(working)$", f"\\1{model_name}/\\2", db_config[k])
elif k == "model_name":
db_config[k] = db_config[k].replace("dummy_local_model", model_name)
db_config[k] = db_config[k].replace(local_model_name, model_name)
elif k == "concepts_list":
for concept, data, class_data in zip(db_config[k], data_list, class_data_list):
concept["instance_data_dir"] = data
concept["class_data_dir"] = class_data
# else:
# db_config[k] = db_config[k].replace("dummy_local_model", model_name)
with open(db_config_file_path, "w") as db_config_file_w:
json.dump(db_config, db_config_file_w)
@ -194,8 +193,9 @@ def wrap_save_config(model_name):
setattr(dreambooth_shared, 'dreambooth_models_path', origin_model_path)
def cloud_train(
local_model_name: str,
train_model_name: str,
local_model_name=False,
db_use_txt2img=False,
training_instance_type: str= ""
):
integral_check = False
@ -210,7 +210,7 @@ def cloud_train(
try:
# Get data path and class data path.
print(f"Start cloud training {train_model_name}")
db_config_path = os.path.join("models/dreambooth/dummy_local_model/db_config.json")
db_config_path = os.path.join(f"models/dreambooth/{local_model_name}/db_config.json")
with open(db_config_path) as db_config_file:
config = json.load(db_config_file)
local_data_path_list = []
@ -224,7 +224,8 @@ def cloud_train(
class_data_path_list.append(concept["class_data_dir"].replace("s3://", "").replace("/", "-").strip("-"))
model_list = get_cloud_db_models()
new_db_config_path = os.path.join(base_model_folder, f"{train_model_name}/db_config_cloud.json")
hack_db_config(config, new_db_config_path, train_model_name, data_path_list, class_data_path_list)
print(f"hack config from local_model_name to new_db_config_path")
hack_db_config(config, new_db_config_path, train_model_name, data_path_list, class_data_path_list, local_model_name)
if config["save_lora_for_extra_net"] == True:
model_type = "Lora"
else:

View File

@ -4,9 +4,26 @@ function sleep(ms) {
return new Promise(resolve => setTimeout(resolve, ms));
}
function getElementByXpath(path) {
console.log(path)
return document.evaluate(path, document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
}
function set_dropdown_value(xpath, value) {
let dropdown = getElementByXpath(xpath)
console.log("Trying to click the dropdown " + dropdown)
dropdown.click()
let target_dropdown = getElementByXpath(`//ul[contains(.,'${value}')]`)
console.log("Trying to set the value of dropdown" + target_dropdown)
target_dropdown.click()
}
async function db_start_sagemaker_train() {
console.log("Sagemaker training");
console.log(arguments);
// var xpath = "//*[@id='cloud_db_model_name']/label/div/div[1]/div"
// var value = "dummy_local_model"
// set_dropdown_value(xpath, value)
// pop up confirmation for sagemaker training
let do_save = confirm("Confirm to start Sagemaker training? This will take a while.");
@ -14,10 +31,10 @@ async function db_start_sagemaker_train() {
return;
}
save_config();
await sleep(5000);
await sleep(1000);
// let sagemaker_train = gradioApp().getElementById("db_sagemaker_train");
// sagemaker_train.style.display = "block";
return filterArgs(3, arguments)
return filterArgs(4, arguments)
}
function check_create_model_params() {

View File

@ -117,6 +117,7 @@ def on_after_component_callback(component, **_kwargs):
fn=async_cloud_train,
_js="db_start_sagemaker_train",
inputs=[
db_model_name,
cloud_db_model_name,
db_use_txt2img,
cloud_train_instance_type
@ -637,7 +638,7 @@ def ui_tabs_callback():
with gr.Row():
cloud_train_instance_type = gr.Dropdown(
label="SageMaker Train Instance Type",
choices=['ml.g4dn.2xlarge'],
choices=['ml.g4dn.2xlarge', 'ml.g5.2xlarge'],
elem_id="cloud_train_instance_type",
info='select SageMaker Train Instance Type'
)
@ -697,7 +698,7 @@ def ui_tabs_callback():
cloud_db_new_model_name = gr.Textbox(label="Name", placeholder="Model names can only contain alphanumeric and -")
with gr.Row():
cloud_db_create_from_hub = gr.Checkbox(
label="Create From Hub", value=False
label="Create From Hub", value=False, visible=False
)
cloud_db_512_model = gr.Checkbox(label="512x Model", value=True)
with gr.Column(visible=False) as hub_row: