fix: dreambooth config save.
parent
d0c095dee5
commit
49b6cd3adf
|
|
@ -6,6 +6,7 @@ import gradio as gr
|
|||
import os
|
||||
import sys
|
||||
import logging
|
||||
import shutil
|
||||
from utils import upload_file_to_s3_by_presign_url
|
||||
from utils import get_variable_from_json
|
||||
|
||||
|
|
@ -80,20 +81,18 @@ def get_cloud_db_model_name_list():
|
|||
model_name_list = [model['model_name'] for model in model_list]
|
||||
return model_name_list
|
||||
|
||||
def hack_db_config(db_config, db_config_file_path, model_name, data_list, class_data_list):
|
||||
def hack_db_config(db_config, db_config_file_path, model_name, data_list, class_data_list, local_model_name):
|
||||
for k in db_config:
|
||||
if k == "model_dir":
|
||||
db_config[k] = re.sub(".+/(models/dreambooth/).+$", f"\\1{model_name}", db_config[k])
|
||||
elif k == "pretrained_model_name_or_path":
|
||||
db_config[k] = re.sub(".+/(models/dreambooth/).+(working)$", f"\\1{model_name}/\\2", db_config[k])
|
||||
elif k == "model_name":
|
||||
db_config[k] = db_config[k].replace("dummy_local_model", model_name)
|
||||
db_config[k] = db_config[k].replace(local_model_name, model_name)
|
||||
elif k == "concepts_list":
|
||||
for concept, data, class_data in zip(db_config[k], data_list, class_data_list):
|
||||
concept["instance_data_dir"] = data
|
||||
concept["class_data_dir"] = class_data
|
||||
# else:
|
||||
# db_config[k] = db_config[k].replace("dummy_local_model", model_name)
|
||||
with open(db_config_file_path, "w") as db_config_file_w:
|
||||
json.dump(db_config, db_config_file_w)
|
||||
|
||||
|
|
@ -194,8 +193,9 @@ def wrap_save_config(model_name):
|
|||
setattr(dreambooth_shared, 'dreambooth_models_path', origin_model_path)
|
||||
|
||||
def cloud_train(
|
||||
local_model_name: str,
|
||||
train_model_name: str,
|
||||
local_model_name=False,
|
||||
db_use_txt2img=False,
|
||||
training_instance_type: str= ""
|
||||
):
|
||||
integral_check = False
|
||||
|
|
@ -210,7 +210,7 @@ def cloud_train(
|
|||
try:
|
||||
# Get data path and class data path.
|
||||
print(f"Start cloud training {train_model_name}")
|
||||
db_config_path = os.path.join("models/dreambooth/dummy_local_model/db_config.json")
|
||||
db_config_path = os.path.join(f"models/dreambooth/{local_model_name}/db_config.json")
|
||||
with open(db_config_path) as db_config_file:
|
||||
config = json.load(db_config_file)
|
||||
local_data_path_list = []
|
||||
|
|
@ -224,7 +224,8 @@ def cloud_train(
|
|||
class_data_path_list.append(concept["class_data_dir"].replace("s3://", "").replace("/", "-").strip("-"))
|
||||
model_list = get_cloud_db_models()
|
||||
new_db_config_path = os.path.join(base_model_folder, f"{train_model_name}/db_config_cloud.json")
|
||||
hack_db_config(config, new_db_config_path, train_model_name, data_path_list, class_data_path_list)
|
||||
print(f"hack config from local_model_name to new_db_config_path")
|
||||
hack_db_config(config, new_db_config_path, train_model_name, data_path_list, class_data_path_list, local_model_name)
|
||||
if config["save_lora_for_extra_net"] == True:
|
||||
model_type = "Lora"
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -4,9 +4,26 @@ function sleep(ms) {
|
|||
return new Promise(resolve => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
function getElementByXpath(path) {
|
||||
console.log(path)
|
||||
return document.evaluate(path, document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
|
||||
}
|
||||
|
||||
function set_dropdown_value(xpath, value) {
|
||||
let dropdown = getElementByXpath(xpath)
|
||||
console.log("Trying to click the dropdown " + dropdown)
|
||||
dropdown.click()
|
||||
let target_dropdown = getElementByXpath(`//ul[contains(.,'${value}')]`)
|
||||
console.log("Trying to set the value of dropdown" + target_dropdown)
|
||||
target_dropdown.click()
|
||||
}
|
||||
|
||||
async function db_start_sagemaker_train() {
|
||||
console.log("Sagemaker training");
|
||||
console.log(arguments);
|
||||
// var xpath = "//*[@id='cloud_db_model_name']/label/div/div[1]/div"
|
||||
// var value = "dummy_local_model"
|
||||
// set_dropdown_value(xpath, value)
|
||||
|
||||
// pop up confirmation for sagemaker training
|
||||
let do_save = confirm("Confirm to start Sagemaker training? This will take a while.");
|
||||
|
|
@ -14,10 +31,10 @@ async function db_start_sagemaker_train() {
|
|||
return;
|
||||
}
|
||||
save_config();
|
||||
await sleep(5000);
|
||||
await sleep(1000);
|
||||
// let sagemaker_train = gradioApp().getElementById("db_sagemaker_train");
|
||||
// sagemaker_train.style.display = "block";
|
||||
return filterArgs(3, arguments)
|
||||
return filterArgs(4, arguments)
|
||||
}
|
||||
|
||||
function check_create_model_params() {
|
||||
|
|
|
|||
|
|
@ -117,6 +117,7 @@ def on_after_component_callback(component, **_kwargs):
|
|||
fn=async_cloud_train,
|
||||
_js="db_start_sagemaker_train",
|
||||
inputs=[
|
||||
db_model_name,
|
||||
cloud_db_model_name,
|
||||
db_use_txt2img,
|
||||
cloud_train_instance_type
|
||||
|
|
@ -637,7 +638,7 @@ def ui_tabs_callback():
|
|||
with gr.Row():
|
||||
cloud_train_instance_type = gr.Dropdown(
|
||||
label="SageMaker Train Instance Type",
|
||||
choices=['ml.g4dn.2xlarge'],
|
||||
choices=['ml.g4dn.2xlarge', 'ml.g5.2xlarge'],
|
||||
elem_id="cloud_train_instance_type",
|
||||
info='select SageMaker Train Instance Type'
|
||||
)
|
||||
|
|
@ -697,7 +698,7 @@ def ui_tabs_callback():
|
|||
cloud_db_new_model_name = gr.Textbox(label="Name", placeholder="Model names can only contain alphanumeric and -")
|
||||
with gr.Row():
|
||||
cloud_db_create_from_hub = gr.Checkbox(
|
||||
label="Create From Hub", value=False
|
||||
label="Create From Hub", value=False, visible=False
|
||||
)
|
||||
cloud_db_512_model = gr.Checkbox(label="512x Model", value=True)
|
||||
with gr.Column(visible=False) as hub_row:
|
||||
|
|
|
|||
Loading…
Reference in New Issue