245 lines
14 KiB
Python
245 lines
14 KiB
Python
import gradio as gr
|
|
|
|
from aws_extension.sagemaker_ui_utils import create_refresh_button_by_user
|
|
from dreambooth_on_cloud.create_model import get_sd_cloud_models, get_create_model_job_list, cloud_create_model
|
|
from dreambooth_on_cloud.train import get_cloud_db_model_name_list, get_train_job_list, wrap_load_model_params
|
|
from modules import script_callbacks
|
|
|
|
# create new tabs for create Model
|
|
from modules.ui_common import create_refresh_button
|
|
|
|
origin_callback = script_callbacks.ui_tabs_callback
|
|
|
|
|
|
def avoid_duplicate_from_restart_ui(res):
|
|
for extension_ui in res:
|
|
if extension_ui[1] == 'Dreambooth':
|
|
for key in list(extension_ui[0].blocks):
|
|
val = extension_ui[0].blocks[key]
|
|
if type(val) is gr.Tab:
|
|
if val.label == 'Select From Cloud':
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def get_sorted_lora_cloud_models():
|
|
return []
|
|
|
|
|
|
def get_cloud_model_snapshots():
|
|
return []
|
|
|
|
|
|
def ui_tabs_callback():
|
|
res = origin_callback()
|
|
if avoid_duplicate_from_restart_ui(res):
|
|
return res
|
|
for extension_ui in res:
|
|
if extension_ui[1] == 'Dreambooth':
|
|
for key in list(extension_ui[0].blocks):
|
|
val = extension_ui[0].blocks[key]
|
|
if type(val) is gr.Tab:
|
|
if val.label == 'Select':
|
|
with extension_ui[0]:
|
|
with val.parent:
|
|
with gr.Tab('Select From Cloud'):
|
|
with gr.Row():
|
|
cloud_db_model_name = gr.Dropdown(
|
|
label="Model",
|
|
elem_id="cloud_db_model_name"
|
|
)
|
|
create_refresh_button_by_user(
|
|
cloud_db_model_name,
|
|
lambda *args: None,
|
|
lambda username: {"choices": sorted(get_cloud_db_model_name_list(username))},
|
|
"refresh_db_models",
|
|
)
|
|
with gr.Row():
|
|
cloud_db_snapshot = gr.Dropdown(
|
|
label="Cloud Snapshot to Resume",
|
|
choices=sorted(get_cloud_model_snapshots()),
|
|
elem_id="cloud_snapshot_to_resume_dropdown"
|
|
)
|
|
create_refresh_button(
|
|
cloud_db_snapshot,
|
|
get_cloud_model_snapshots,
|
|
lambda: {"choices": sorted(get_cloud_model_snapshots())},
|
|
"refresh_db_snapshots",
|
|
)
|
|
|
|
with gr.Row():
|
|
cloud_train_instance_type = gr.Dropdown(
|
|
label="SageMaker Train Instance Type",
|
|
choices=['ml.g4dn.2xlarge', 'ml.g5.2xlarge'],
|
|
elem_id="cloud_train_instance_type",
|
|
info='select SageMaker Train Instance Type'
|
|
)
|
|
with gr.Row(visible=False) as lora_model_row:
|
|
cloud_db_lora_model_name = gr.Dropdown(
|
|
label="Lora Model", choices=get_sorted_lora_cloud_models(),
|
|
elem_id="cloud_lora_model_dropdown"
|
|
)
|
|
create_refresh_button(
|
|
cloud_db_lora_model_name,
|
|
get_sorted_lora_cloud_models,
|
|
lambda: {"choices": get_sorted_lora_cloud_models()},
|
|
"refresh_lora_models",
|
|
)
|
|
with gr.Row():
|
|
gr.HTML(value="Loaded Model from Cloud:")
|
|
cloud_db_model_path = gr.HTML()
|
|
with gr.Row():
|
|
gr.HTML(value="Cloud Model Revision:")
|
|
cloud_db_revision = gr.HTML(elem_id="cloud_db_revision")
|
|
with gr.Row():
|
|
gr.HTML(value="Cloud Model Epoch:")
|
|
cloud_db_epochs = gr.HTML(elem_id="cloud_db_epochs")
|
|
with gr.Row():
|
|
gr.HTML(value="V2 Model From Cloud:")
|
|
cloud_db_v2 = gr.HTML(elem_id="cloud_db_v2")
|
|
with gr.Row():
|
|
gr.HTML(value="Has EMA:")
|
|
cloud_db_has_ema = gr.HTML(elem_id="cloud_db_has_ema")
|
|
with gr.Row():
|
|
gr.HTML(value="Source Checkpoint From Cloud:")
|
|
cloud_db_src = gr.HTML()
|
|
with gr.Row():
|
|
gr.HTML(value="Cloud DB Status:")
|
|
cloud_db_status = gr.HTML(elem_id="db_status", value="")
|
|
with gr.Row():
|
|
gr.HTML(value="Experimental Shared Source:")
|
|
cloud_db_shared_diffusers_path = gr.HTML()
|
|
with gr.Row():
|
|
gr.HTML(value="<b>Training Jobs Details:<b/>")
|
|
with gr.Column():
|
|
training_job_dashboard = gr.Dataframe(
|
|
headers=["id", "model name", "status", "SageMaker train name"],
|
|
datatype=["str", "str", "str", "str"],
|
|
col_count=(4, "fixed"),
|
|
interactive=False,
|
|
elem_id='training_job_dashboard'
|
|
)
|
|
train_refresh_btn = gr.Button(
|
|
value="Refresh Train List From Cloud", variant="primary"
|
|
)
|
|
train_refresh_btn.click(fn=get_train_job_list,
|
|
inputs=[],
|
|
outputs=[training_job_dashboard])
|
|
with gr.Tab('Create From Cloud'):
|
|
with gr.Column():
|
|
cloud_db_create_model = gr.Button(
|
|
value="Create Model From Cloud", variant="primary"
|
|
)
|
|
cloud_db_new_model_name = gr.Textbox(label="Name",
|
|
placeholder="Model names can only contain"
|
|
" alphanumeric and -")
|
|
with gr.Row():
|
|
cloud_db_create_from_hub = gr.Checkbox(
|
|
label="Create From Hub", value=False, visible=False
|
|
)
|
|
cloud_db_512_model = gr.Checkbox(label="512x Model", value=True)
|
|
with gr.Column(visible=False) as hub_row:
|
|
cloud_db_new_model_url = gr.Textbox(
|
|
label="Model Path",
|
|
placeholder="runwayml/stable-diffusion-v1-5",
|
|
elem_id="cloud_db_model_path_text_box"
|
|
)
|
|
cloud_db_new_model_token = gr.Textbox(
|
|
label="HuggingFace Token", value=""
|
|
)
|
|
with gr.Column(visible=True) as local_row:
|
|
with gr.Row():
|
|
cloud_db_new_model_src = gr.Dropdown(
|
|
label="Source Checkpoint",
|
|
choices=sorted(get_sd_cloud_models()),
|
|
elem_id="cloud_db_source_checkpoint_dropdown"
|
|
)
|
|
create_refresh_button(
|
|
cloud_db_new_model_src,
|
|
get_sd_cloud_models,
|
|
lambda: {"choices": sorted(get_sd_cloud_models())},
|
|
"refresh_sd_models",
|
|
)
|
|
with gr.Column(visible=False) as shared_row:
|
|
with gr.Row():
|
|
cloud_db_new_model_shared_src = gr.Dropdown(
|
|
label="EXPERIMENTAL: LoRA Shared Diffusers Source",
|
|
choices=[],
|
|
value=""
|
|
)
|
|
cloud_db_new_model_extract_ema = gr.Checkbox(
|
|
label="Extract EMA Weights", value=False
|
|
)
|
|
cloud_db_train_unfrozen = gr.Checkbox(label="Unfreeze Model", value=False,
|
|
elem_id="cloud_db_unfreeze_model_checkbox")
|
|
with gr.Row():
|
|
gr.HTML(value="<b>Model Creation Jobs Details:<b/>")
|
|
with gr.Column():
|
|
createmodel_dashboard = gr.Dataframe(
|
|
headers=["id", "model name", "status"],
|
|
datatype=["str", "str", "str"],
|
|
col_count=(3, "fixed"),
|
|
interactive=False
|
|
)
|
|
|
|
model_refresh_btn = gr.Button(
|
|
value="Refresh Model List From Cloud", variant="primary"
|
|
)
|
|
model_refresh_btn.click(fn=get_create_model_job_list,
|
|
inputs=[],
|
|
outputs=[createmodel_dashboard])
|
|
|
|
def toggle_new_rows(create_from):
|
|
return gr.update(visible=create_from), gr.update(visible=not create_from)
|
|
|
|
cloud_db_create_from_hub.change(
|
|
fn=toggle_new_rows,
|
|
inputs=[cloud_db_create_from_hub],
|
|
outputs=[hub_row, local_row],
|
|
)
|
|
|
|
cloud_db_model_name.change(
|
|
_js="clear_loaded",
|
|
fn=wrap_load_model_params,
|
|
inputs=[cloud_db_model_name],
|
|
outputs=[
|
|
cloud_db_model_path,
|
|
cloud_db_revision,
|
|
cloud_db_epochs,
|
|
cloud_db_v2,
|
|
cloud_db_has_ema,
|
|
cloud_db_src,
|
|
cloud_db_shared_diffusers_path,
|
|
cloud_db_snapshot,
|
|
cloud_db_lora_model_name,
|
|
cloud_db_status,
|
|
],
|
|
)
|
|
cloud_db_create_model.click(
|
|
fn=cloud_create_model,
|
|
_js="check_create_model_params",
|
|
inputs=[
|
|
cloud_db_new_model_name,
|
|
cloud_db_new_model_src,
|
|
cloud_db_new_model_shared_src,
|
|
cloud_db_create_from_hub,
|
|
cloud_db_new_model_url,
|
|
cloud_db_new_model_token,
|
|
cloud_db_new_model_extract_ema,
|
|
cloud_db_train_unfrozen,
|
|
cloud_db_512_model,
|
|
],
|
|
outputs=[
|
|
createmodel_dashboard
|
|
# cloud_db_new_model_name
|
|
# cloud_db_create_from_hub
|
|
# cloud_db_512_model
|
|
# cloud_db_new_model_url
|
|
# cloud_db_new_model_token
|
|
# cloud_db_new_model_src
|
|
]
|
|
)
|
|
break
|
|
return res
|