809 lines
47 KiB
Python
809 lines
47 KiB
Python
import sys
|
|
import requests
|
|
import logging
|
|
import gradio as gr
|
|
import os
|
|
import modules.scripts as scripts
|
|
from modules import script_callbacks
|
|
from modules.ui import create_refresh_button
|
|
from modules.ui_components import FormRow
|
|
from utils import get_variable_from_json
|
|
from utils import save_variable_to_json
|
|
from PIL import Image
|
|
|
|
# sys.path.append("extensions/stable-diffusion-aws-extension/scripts")
|
|
# import sagemaker_ui
|
|
from aws_extension import sagemaker_ui
|
|
|
|
# Set to False below when the optional dreambooth_on_cloud package cannot be imported.
dreambooth_available = True


def dummy_function(*args, **kwargs):
    """Placeholder for any dreambooth_on_cloud callable.

    Accepts and ignores every argument and returns a new empty list.
    """
    return []


try:
    from dreambooth_on_cloud.train import (
        async_cloud_train,
        get_cloud_db_model_name_list,
        wrap_load_model_params,
        get_train_job_list,
        get_sorted_cloud_dataset,
    )
    from dreambooth_on_cloud.create_model import (
        get_sd_cloud_models,
        get_create_model_job_list,
        cloud_create_model,
    )
except Exception:
    logging.warning("[main]dreambooth_on_cloud is not installed or can not be imported, using dummy function to proceed.")
    dreambooth_available = False
    # The import above binds `async_cloud_train`, and that is the name the rest
    # of this file calls, so the fallback must bind the same name. Previously
    # only `cloud_train` was bound, leaving `async_cloud_train` undefined.
    async_cloud_train = dummy_function
    # Kept so anything that still reads the old fallback name keeps working.
    cloud_train = dummy_function
    get_cloud_db_model_name_list = dummy_function
    wrap_load_model_params = dummy_function
    get_train_job_list = dummy_function
    get_sorted_cloud_dataset = dummy_function
    get_sd_cloud_models = dummy_function
    get_create_model_job_list = dummy_function
    cloud_create_model = dummy_function
|
|
|
|
# Dataset descriptors last returned by get_sorted_cloud_dataset().
cloud_datasets = []

# Dreambooth components captured by on_after_component_callback.
training_job_dashboard = None
db_model_name = None
cloud_db_model_name = None
cloud_train_instance_type = None
db_use_txt2img = None
db_sagemaker_train = None
db_save_config = None

# txt2img components captured by on_after_component_callback.
# `txt2img_show_hook` becomes "finish" once the txt2img listeners are attached.
txt2img_show_hook = None
txt2img_gallery = None
txt2img_generation_info = None
txt2img_html_info = None
# Previously declared `global` in the callback but never initialised, so
# reading it before the matching component was seen raised NameError.
txt2img_prompt = None

# img2img components captured by on_after_component_callback.
# `img2img_show_hook` becomes "finish" once the img2img listeners are attached.
img2img_show_hook = None
img2img_gallery = None
img2img_generation_info = None
img2img_html_info = None
# Same NameError hazard as `txt2img_prompt`: these were only ever assigned
# inside the callback.
img2img_prompt = None
init_img = None
sketch = None
init_img_with_mask = None
inpaint_color_sketch = None
init_img_inpaint = None
init_mask_inpaint = None

modelmerger_merge_hook = None
modelmerger_merge_component = None

# Instance types offered in the "SageMaker Instance Type" dropdown.
async_inference_choices = [
    "ml.g4dn.xlarge", "ml.g4dn.2xlarge", "ml.g4dn.4xlarge", "ml.g4dn.8xlarge", "ml.g4dn.12xlarge",
    "ml.g5.xlarge", "ml.g5.2xlarge", "ml.g5.4xlarge", "ml.g5.8xlarge", "ml.g5.12xlarge",
]
|
|
|
|
class SageMakerUI(scripts.Script):
    """Script entry that mounts the controls built by sagemaker_ui.create_ui."""

    def title(self):
        """Return the display name of this script."""
        return "SageMaker embeddings"

    def show(self, is_img2img):
        """Return scripts.AlwaysVisible regardless of `is_img2img`."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the SageMaker controls and return them as a list.

        The result of sagemaker_ui.create_ui is unpacked into exactly
        thirteen names, so a different arity raises ValueError here.
        """
        (
            endpoint,
            checkpoint,
            checkpoint_refresh,
            textual_inversion,
            lora,
            hypernetwork,
            controlnet,
            inference_job,
            inference_job_refresh,
            primary_model,
            secondary_model,
            tertiary_model,
            merge_on_cloud,
        ) = sagemaker_ui.create_ui(is_img2img)
        return [
            endpoint,
            checkpoint,
            checkpoint_refresh,
            textual_inversion,
            lora,
            hypernetwork,
            controlnet,
            inference_job,
            inference_job_refresh,
            primary_model,
            secondary_model,
            tertiary_model,
            merge_on_cloud,
        ]

    def process(self, p, sagemaker_endpoint, sd_checkpoint, sd_checkpoint_refresh_button, textual_inversion_dropdown, lora_dropdown, hyperNetwork_dropdown, controlnet_dropdown, choose_txt2img_inference_job_id, txt2img_inference_job_ids_refresh_button, primary_model_name, secondary_model_name, tertiary_model_name, modelmerger_on_cloud):
        """Do nothing; the parameters mirror the components returned by ui()."""
        pass
|
|
|
|
def _nth_ancestor(component, levels):
    """Return the ancestor `levels` steps up the `.parent` chain, or None if the chain is shorter."""
    node = component
    for _ in range(levels):
        node = getattr(node, 'parent', None)
        if node is None:
            return None
    return node


def on_after_component_callback(component, **_kwargs):
    """Capture components of interest and attach the cloud listeners.

    The function inspects `component`, stores the ones it recognises (by
    exact type plus elem_id or label) in module-level globals, and wires
    listeners in three independent sections:

    * Dreambooth: when the 'db_train' button is seen, a "SageMaker Train"
      button is created; once every required component is known, its click
      is bound to `async_cloud_train`.
    * txt2img: once the gallery, info boxes and prompt are known, listeners
      are attached a single time (guarded by `txt2img_show_hook`).
    * img2img: same, guarded by `img2img_show_hook`, and additionally
      requiring the image inputs and the two cloud interrogate buttons.

    Args:
        component: the component that was just created.
        **_kwargs: ignored.
    """
    global db_model_name, db_use_txt2img, db_sagemaker_train, db_save_config, \
        cloud_db_model_name, cloud_train_instance_type, training_job_dashboard
    global txt2img_gallery, txt2img_generation_info, txt2img_html_info, txt2img_show_hook, txt2img_prompt
    global img2img_gallery, img2img_generation_info, img2img_html_info, img2img_show_hook, \
        img2img_prompt, init_img, sketch, init_img_with_mask, inpaint_color_sketch, \
        init_img_inpaint, init_mask_inpaint

    # Exact type comparison is kept on purpose: subclasses must not match.
    component_type = type(component)
    elem_id = getattr(component, 'elem_id', None)
    label = getattr(component, 'label', None)

    # ---------------------------------------------------------------- Dreambooth
    is_dreambooth_train = component_type is gr.Button and elem_id == 'db_train'
    # The local model dropdown is matched either by elem_id or by its label
    # together with the elem_id of its fourth ancestor. The ancestor walk is
    # None-safe: the former `component.parent.parent.parent.parent` raised
    # AttributeError for any 'Model' dropdown nested fewer than four levels.
    is_dreambooth_model_name = component_type is gr.Dropdown and (
        elem_id == 'model_name'
        or (label == 'Model'
            and getattr(_nth_ancestor(component, 4), 'elem_id', None) == 'ModelPanel')
    )
    is_cloud_dreambooth_model_name = component_type is gr.Dropdown and elem_id == 'cloud_db_model_name'
    is_machine_type_for_train = component_type is gr.Dropdown and elem_id == 'cloud_train_instance_type'
    is_dreambooth_use_txt2img = component_type is gr.Checkbox and label == 'Use txt2img'
    is_training_job_dashboard = component_type is gr.Dataframe and elem_id == 'training_job_dashboard'
    is_db_save_config = elem_id == 'db_save_config'

    if is_dreambooth_train:
        db_sagemaker_train = gr.Button(value="SageMaker Train", elem_id="db_sagemaker_train", variant='primary')
    if is_dreambooth_model_name:
        db_model_name = component
    if is_cloud_dreambooth_model_name:
        cloud_db_model_name = component
    if is_training_job_dashboard:
        training_job_dashboard = component
    if is_machine_type_for_train:
        cloud_train_instance_type = component
    if is_dreambooth_use_txt2img:
        db_use_txt2img = component
    if is_db_save_config:
        db_save_config = component

    # Bind the training button once every required component is known, and
    # only on a call that delivered one of them.
    dreambooth_ready = (
        training_job_dashboard is not None and cloud_train_instance_type is not None
        and cloud_db_model_name is not None and db_model_name is not None
        and db_use_txt2img is not None and db_sagemaker_train is not None
    )
    is_dreambooth_component = (
        is_dreambooth_train or is_dreambooth_model_name or is_dreambooth_use_txt2img
        or is_cloud_dreambooth_model_name or is_machine_type_for_train or is_training_job_dashboard
    )
    if dreambooth_ready and is_dreambooth_component:
        db_model_name.value = "dummy_local_model"
        db_sagemaker_train.click(
            fn=async_cloud_train,
            _js="db_start_sagemaker_train",
            inputs=[
                db_model_name,
                cloud_db_model_name,
                db_use_txt2img,
                cloud_train_instance_type,
            ],
            outputs=[training_job_dashboard],
        )

    # ------------------------------------------------------------------- txt2img
    is_txt2img_gallery = component_type is gr.Gallery and elem_id == 'txt2img_gallery'
    is_txt2img_generation_info = component_type is gr.Textbox and elem_id == 'generation_info_txt2img'
    is_txt2img_html_info = component_type is gr.HTML and elem_id == 'html_info_txt2img'
    is_txt2img_prompt = component_type is gr.Textbox and elem_id == 'txt2img_prompt'

    if is_txt2img_prompt:
        txt2img_prompt = component
    if is_txt2img_gallery:
        txt2img_gallery = component
    if is_txt2img_generation_info:
        txt2img_generation_info = component
    if is_txt2img_html_info:
        txt2img_html_info = component

    if (sagemaker_ui.inference_job_dropdown is not None
            and txt2img_gallery is not None
            and txt2img_generation_info is not None
            and txt2img_html_info is not None
            and txt2img_show_hook is None
            and txt2img_prompt is not None):
        # Mark as done first so later components do not attach the listeners again.
        txt2img_show_hook = "finish"
        sagemaker_ui.inference_job_dropdown.change(
            fn=lambda selected_value: sagemaker_ui.fake_gan(selected_value),
            inputs=[sagemaker_ui.inference_job_dropdown],
            outputs=[txt2img_gallery, txt2img_generation_info, txt2img_html_info, txt2img_prompt],
        )

        sagemaker_ui.sagemaker_endpoint.change(
            fn=lambda selected_value: sagemaker_ui.displayEndpointInfo(selected_value),
            inputs=[sagemaker_ui.sagemaker_endpoint],
            outputs=[txt2img_html_info],
        )
        sagemaker_ui.generate_on_cloud_button_with_js.click(
            fn=sagemaker_ui.call_txt2img_inference,
            _js="txt2img_config_save",
            inputs=[sagemaker_ui.sagemaker_endpoint],
            outputs=[txt2img_gallery, txt2img_generation_info, txt2img_html_info],
        )
        # Only the endpoint is passed as input here; the primary/secondary/
        # tertiary model name components are not wired in.
        sagemaker_ui.modelmerger_merge_on_cloud.click(
            fn=sagemaker_ui.modelmerger_on_cloud_func,
            _js="txt2img_config_save",
            inputs=[sagemaker_ui.sagemaker_endpoint],
            outputs=[],
        )

    # ------------------------------------------------------------------- img2img
    is_img2img_gallery = component_type is gr.Gallery and elem_id == 'img2img_gallery'
    is_img2img_generation_info = component_type is gr.Textbox and elem_id == 'generation_info_img2img'
    is_img2img_html_info = component_type is gr.HTML and elem_id == 'html_info_img2img'

    is_img2img_prompt = component_type is gr.Textbox and elem_id == 'img2img_prompt'
    is_init_img = component_type is gr.Image and elem_id == 'img2img_image'
    is_sketch = component_type is gr.Image and elem_id == 'img2img_sketch'
    is_init_img_with_mask = component_type is gr.Image and elem_id == 'img2maskimg'
    is_inpaint_color_sketch = component_type is gr.Image and elem_id == 'inpaint_sketch'
    is_init_img_inpaint = component_type is gr.Image and elem_id == 'img_inpaint_base'
    is_init_mask_inpaint = component_type is gr.Image and elem_id == 'img_inpaint_mask'

    if is_img2img_gallery:
        img2img_gallery = component
    if is_img2img_generation_info:
        img2img_generation_info = component
    if is_img2img_html_info:
        img2img_html_info = component

    if is_img2img_prompt:
        img2img_prompt = component
    if is_init_img:
        init_img = component
    if is_sketch:
        sketch = component
    if is_init_img_with_mask:
        init_img_with_mask = component
    if is_inpaint_color_sketch:
        inpaint_color_sketch = component
    if is_init_img_inpaint:
        init_img_inpaint = component
    if is_init_mask_inpaint:
        init_mask_inpaint = component

    if (sagemaker_ui.inference_job_dropdown is not None
            and img2img_gallery is not None
            and img2img_generation_info is not None
            and img2img_html_info is not None
            and img2img_show_hook is None
            and sagemaker_ui.interrogate_clip_on_cloud_button is not None
            and sagemaker_ui.interrogate_deep_booru_on_cloud_button is not None
            and img2img_prompt is not None
            and init_img is not None
            and sketch is not None
            and init_img_with_mask is not None
            and inpaint_color_sketch is not None
            and init_img_inpaint is not None
            and init_mask_inpaint is not None):
        # Mark as done first so later components do not attach the listeners again.
        img2img_show_hook = "finish"
        sagemaker_ui.inference_job_dropdown.change(
            fn=lambda selected_value: sagemaker_ui.fake_gan(selected_value),
            inputs=[sagemaker_ui.inference_job_dropdown],
            outputs=[img2img_gallery, img2img_generation_info, img2img_html_info, img2img_prompt],
        )

        # All three img2img actions take the endpoint plus the same image inputs.
        img2img_inputs = [sagemaker_ui.sagemaker_endpoint, init_img, sketch, init_img_with_mask, inpaint_color_sketch, init_img_inpaint, init_mask_inpaint]

        sagemaker_ui.interrogate_clip_on_cloud_button.click(
            fn=sagemaker_ui.call_interrogate_clip,
            _js="img2img_config_save",
            inputs=list(img2img_inputs),
            outputs=[img2img_gallery, img2img_generation_info, img2img_html_info],
        )

        sagemaker_ui.interrogate_deep_booru_on_cloud_button.click(
            fn=sagemaker_ui.call_interrogate_deepbooru,
            _js="img2img_config_save",
            inputs=list(img2img_inputs),
            outputs=[img2img_gallery, img2img_generation_info, img2img_html_info],
        )
        sagemaker_ui.generate_on_cloud_button_with_js_img2img.click(
            fn=sagemaker_ui.call_img2img_inference,
            _js="img2img_config_save",
            inputs=list(img2img_inputs),
            outputs=[img2img_gallery, img2img_generation_info, img2img_html_info],
        )
|
|
|
|
def update_connect_config(api_url, api_token):
    """Persist the API URL and token, then refresh the cloud resource lists.

    The URL is stripped of surrounding whitespace and given a trailing '/'
    before being saved. After saving, the module-level `api_gateway_url` and
    `api_key` are re-read from the JSON store and
    sagemaker_ui.init_refresh_resource_list_from_cloud() is called.

    Args:
        api_url: base URL entered by the user.
        api_token: API token entered by the user.

    Returns:
        A status message for the UI.
    """
    global api_gateway_url
    global api_key

    # A pasted URL often carries stray whitespace, which would otherwise end
    # up in front of the appended '/'.
    api_url = api_url.strip()
    if not api_url.endswith('/'):
        api_url += '/'

    save_variable_to_json('api_gateway_url', api_url)
    save_variable_to_json('api_token', api_token)
    api_gateway_url = get_variable_from_json('api_gateway_url')
    api_key = get_variable_from_json('api_token')
    # The token is a credential, so it is deliberately left out of the log line.
    print(f"update the api_url:{api_gateway_url}............")
    sagemaker_ui.init_refresh_resource_list_from_cloud()
    return "config updated to local config!"
|
|
|
|
def test_aws_connect_config(api_url, api_token):
    """Save the given settings and check that the backend answers.

    The settings are stored through update_connect_config, read back from the
    JSON store, and used for a GET on `<api_url>inference/test-connection`
    with the token in the `x-api-key` header.

    Args:
        api_url: base URL entered by the user.
        api_token: API token entered by the user.

    Returns:
        "Successfully Connected" when the request succeeds with a JSON body,
        otherwise a failure message.
    """
    update_connect_config(api_url, api_token)
    api_url = get_variable_from_json('api_gateway_url')
    api_token = get_variable_from_json('api_token')
    if not api_url.endswith('/'):
        api_url += '/'
    # The token is a credential, so it is deliberately left out of the log line.
    print(f"get the api_url:{api_url}............")
    target_url = f'{api_url}inference/test-connection'
    headers = {
        "x-api-key": api_token,
        "Content-Type": "application/json"
    }
    try:
        # Without a timeout an unreachable host would block this handler indefinitely.
        response = requests.get(target_url, headers=headers, timeout=30)
        response.raise_for_status()  # HTTP error status -> exception
        response.json()  # the body must be valid JSON; the value itself is not used
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers JSON decoding errors that are not RequestException subclasses.
        print(f"Error: Failed to get server request. Details: {e}")
        return "failed to connect to backend server, please check the url and token"
    return "Successfully Connected"
|
|
|
|
def on_ui_tabs():
    """Build the "Amazon SageMaker" tab.

    The tab has three columns: connection settings (API URL and token),
    cloud asset management (model upload, endpoint deploy and delete), and
    dataset management (create and browse).

    Side effects: assigns the module-level `api_gateway_url`, `api_key`,
    `test_connection_result` and `cloud_datasets`.

    Returns:
        A one-element tuple holding
        (blocks, "Amazon SageMaker", "sagemaker_interface").
    """
    global api_gateway_url, api_key, test_connection_result, cloud_datasets

    import modules.ui
    buildin_model_list = ['AWS JumpStart Model','AWS BedRock Model','Hugging Face Model']

    def api_endpoint(path):
        """Join `path` onto the stored API base URL with exactly one '/'.

        The stored URL normally ends with '/', so the former
        `base + '/dataset'` concatenation produced '//dataset'.
        """
        return f"{get_variable_from_json('api_gateway_url').rstrip('/')}/{path}"

    with gr.Blocks() as sagemaker_interface:
        with gr.Row(equal_height=True, elem_id="aws_sagemaker_ui_row", visible=False):
            sm_load_params = gr.Button(value="Load Settings", elem_id="aws_load_params", visible=False)
            sm_save_params = gr.Button(value="Save Settings", elem_id="aws_save_params", visible=False)
            sm_train_model = gr.Button(value="Train", variant="primary", elem_id="aws_train_model", visible=False)
            sm_generate_checkpoint = gr.Button(value="Generate Ckpt", elem_id="aws_gen_ckpt", visible=False)
        with gr.Row():
            gr.HTML(value="Enter your API URL & Token to start the connection.", elem_id="hint_row")
        with gr.Row():
            # ------------------------------------------------ connection settings
            with gr.Column(variant="panel", scale=1):
                gr.HTML(value="<u><b>AWS Connection Setting</b></u>")
                api_gateway_url = get_variable_from_json('api_gateway_url')
                api_key = get_variable_from_json('api_token')
                with gr.Row():
                    api_url_textbox = gr.Textbox(value=api_gateway_url, lines=1, placeholder="Please enter API Url of Middle", label="API Url", elem_id="aws_middleware_api")

                    def update_api_gateway_url():
                        """Reload the API URL from the JSON store into the module global."""
                        global api_gateway_url
                        api_gateway_url = get_variable_from_json('api_gateway_url')
                        return api_gateway_url

                    modules.ui.create_refresh_button(api_url_textbox, update_api_gateway_url, lambda: {"value": api_gateway_url}, "refresh_api_gateway_url")
                with gr.Row():
                    def update_api_key():
                        """Reload the API token from the JSON store into the module global."""
                        global api_key
                        api_key = get_variable_from_json('api_token')
                        return api_key

                    api_token_textbox = gr.Textbox(value=api_key, lines=1, placeholder="Please enter API Token", label="API Token", elem_id="aws_middleware_token")
                    modules.ui.create_refresh_button(api_token_textbox, update_api_key, lambda: {"value": api_key}, "refresh_api_token")

                test_connection_result = gr.Label(title="Output")
                aws_connect_button = gr.Button(value="Update Setting", variant='primary', elem_id="aws_config_save")
                aws_connect_button.click(_js="update_auth_settings",
                                         fn=update_connect_config,
                                         inputs=[api_url_textbox, api_token_textbox],
                                         outputs=[test_connection_result])
                aws_test_button = gr.Button(value="Test Connection", variant='primary', elem_id="aws_config_test")
                aws_test_button.click(test_aws_connect_config, inputs=[api_url_textbox, api_token_textbox], outputs=[test_connection_result])

            # --------------------------------------------- cloud asset management
            with gr.Column(variant="panel", scale=1.5):
                gr.HTML(value="<u><b>Cloud Assets Management</b></u>")
                sagemaker_html_log = gr.HTML(elem_id=f'html_log_sagemaker')
                with gr.Accordion("Upload Model to S3", open=False):
                    gr.HTML(value="Refresh to select the model to upload to S3")
                    exts = (".bin", ".pt", ".safetensors", ".ckpt")
                    root_path = os.getcwd()
                    model_folders = {
                        "ckpt": os.path.join(root_path, "models", "Stable-diffusion"),
                        "text": os.path.join(root_path, "embeddings"),
                        "lora": os.path.join(root_path, "models", "Lora"),
                        "control": os.path.join(root_path, "models", "ControlNet"),
                        "hyper": os.path.join(root_path, "models", "hypernetworks"),
                    }

                    def scan_model_folder(kind):
                        """List the model files (by extension) in the folder registered for `kind`.

                        Returns an empty list when the folder does not exist,
                        so a missing optional folder no longer aborts the tab.
                        """
                        folder = model_folders[kind]
                        if not os.path.isdir(folder):
                            return []
                        return [
                            os.path.join(folder, f)
                            for f in os.listdir(folder)
                            if os.path.splitext(f)[1] in exts
                        ]

                    def scan_sd_ckpt():
                        return scan_model_folder("ckpt")

                    def scan_textural_inversion_model():
                        return scan_model_folder("text")

                    def scan_lora_model():
                        return scan_model_folder("lora")

                    def scan_control_model():
                        return scan_model_folder("control")

                    def scan_hypernetwork_model():
                        return scan_model_folder("hyper")

                    with FormRow(elem_id="model_upload_form_row_01"):
                        sd_checkpoints_path = gr.Dropdown(label="SD Checkpoints", choices=sorted(scan_sd_ckpt()), elem_id="sd_ckpt_dropdown")
                        create_refresh_button(sd_checkpoints_path, scan_sd_ckpt, lambda: {"choices": sorted(scan_sd_ckpt())}, "refresh_sd_ckpt")

                        textual_inversion_path = gr.Dropdown(label="Textual Inversion", choices=sorted(scan_textural_inversion_model()), elem_id="textual_inversion_model_dropdown")
                        create_refresh_button(textual_inversion_path, scan_textural_inversion_model, lambda: {"choices": sorted(scan_textural_inversion_model())}, "refresh_textual_inversion_model")
                    with FormRow(elem_id="model_upload_form_row_02"):
                        lora_path = gr.Dropdown(label="LoRA model", choices=sorted(scan_lora_model()), elem_id="lora_model_dropdown")
                        create_refresh_button(lora_path, scan_lora_model, lambda: {"choices": sorted(scan_lora_model())}, "refresh_lora_model")

                        controlnet_model_path = gr.Dropdown(label="ControlNet model", choices=sorted(scan_control_model()), elem_id="controlnet_model_dropdown")
                        create_refresh_button(controlnet_model_path, scan_control_model, lambda: {"choices": sorted(scan_control_model())}, "refresh_controlnet_models")
                    with FormRow(elem_id="model_upload_form_row_03"):
                        hypernetwork_path = gr.Dropdown(label="Hypernetwork", choices=sorted(scan_hypernetwork_model()), elem_id="hyper_model_dropdown")
                        create_refresh_button(hypernetwork_path, scan_hypernetwork_model, lambda: {"choices": sorted(scan_hypernetwork_model())}, "refresh_hyper_models")

                    with gr.Row():
                        model_update_button = gr.Button(value="Upload Models to Cloud", variant="primary", elem_id="sagemaker_model_update_button", size=(200, 50))
                        model_update_button.click(_js="model_update",
                                                  fn=sagemaker_ui.sagemaker_upload_model_s3,
                                                  inputs=[sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path, controlnet_model_path],
                                                  outputs=[test_connection_result, sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path, controlnet_model_path])

                with gr.Blocks(title="Deploy New SageMaker Endpoint", variant='panel'):
                    gr.HTML(value="<u><b>Deploy New SageMaker Endpoint</b></u>")
                    with gr.Row():
                        instance_type_dropdown = gr.Dropdown(label="SageMaker Instance Type", choices=async_inference_choices, elem_id="sagemaker_inference_instance_type_textbox", value="ml.g4dn.xlarge")
                        instance_count_dropdown = gr.Dropdown(label="Please select Instance count", choices=["1","2","3","4"], elem_id="sagemaker_inference_instance_count_textbox", value="1")

                    with gr.Row():
                        sagemaker_deploy_button = gr.Button(value="Deploy", variant='primary', elem_id="sagemaker_deploy_endpoint_buttion")
                        sagemaker_deploy_button.click(sagemaker_ui.sagemaker_deploy,
                                                      _js="deploy_endpoint",
                                                      inputs=[instance_type_dropdown, instance_count_dropdown],
                                                      outputs=[test_connection_result])

                with gr.Blocks(title="Delete SageMaker Endpoint", variant='panel'):
                    gr.HTML(value="<u><b>Delete SageMaker Endpoint</b></u>")
                    with gr.Row():
                        sagemaker_endpoint_delete_dropdown = gr.Dropdown(choices=sagemaker_ui.sagemaker_endpoints, multiselect=True, label="Select Cloud SageMaker Endpoint")
                        modules.ui.create_refresh_button(sagemaker_endpoint_delete_dropdown, sagemaker_ui.update_sagemaker_endpoints, lambda: {"choices": sagemaker_ui.sagemaker_endpoints}, "refresh_sagemaker_endpoints_delete")
                    sagemaker_endpoint_delete_button = gr.Button(value="Delete", variant='primary', elem_id="sagemaker_endpoint_delete_button")
                    sagemaker_endpoint_delete_button.click(sagemaker_ui.sagemaker_endpoint_delete,
                                                           _js="delete_sagemaker_endpoint",
                                                           inputs=[sagemaker_endpoint_delete_dropdown],
                                                           outputs=[test_connection_result])

            # ------------------------------------------------- dataset management
            with gr.Column(variant="panel", scale=1):
                # TODO: uncomment if implemented, comment since the tab component do not has visible parameter
                # with gr.Blocks(title="Deploy New SageMaker Endpoint", variant='panel', visible=False):
                #     gr.HTML(value="<u><b>AWS Model Setting</b></u>", visible=False)
                #     with gr.Tab("Select"):
                #         gr.HTML(value="AWS Built-in Model", visible=False)
                #         model_select_dropdown = gr.Dropdown(buildin_model_list, label="Select Built-In Model", elem_id="aws_select_model", visible=False)
                #     with gr.Tab("Create"):
                #         gr.HTML(value="AWS Custom Model", visible=False)
                #         model_name_textbox = gr.Textbox(value="", lines=1, placeholder="Please enter model name", label="Model Name", visible=False)
                #         model_create_button = gr.Button(value="Create Model", variant='primary',elem_id="aws_create_model", visible=False)

                with gr.Blocks(title="Create AWS dataset", variant='panel'):
                    gr.HTML(value="<u><b>AWS Dataset Management</b></u>")
                    with gr.Tab("Create"):
                        def upload_file(files):
                            """Return the local paths of the uploaded files."""
                            file_paths = [file.name for file in files]
                            return file_paths

                        file_output = gr.File()
                        upload_button = gr.UploadButton("Click to Upload a File", file_types=["image", "video"], file_count="multiple")
                        upload_button.upload(fn=upload_file, inputs=[upload_button], outputs=[file_output])

                        def create_dataset(files, dataset_name, dataset_desc):
                            """Register a dataset, upload its files, then mark it enabled.

                            Steps: POST the file manifest to the dataset endpoint,
                            PUT each file to the pre-signed URL returned for it,
                            then PUT the dataset status "Enabled".

                            Returns:
                                A completion message followed by four None
                                values, one per remaining output component.

                            Raises:
                                requests.HTTPError: when any request returns an
                                    error status.
                            """
                            print(dataset_name)
                            dataset_content = []
                            file_path_lookup = {}
                            for file in files:
                                orig_name = file.name.split(os.sep)[-1]
                                file_path_lookup[orig_name] = file.name
                                dataset_content.append(
                                    {
                                        "filename": orig_name,
                                        "name": orig_name,
                                        "type": "image",
                                        "params": {}
                                    }
                                )

                            payload = {
                                "dataset_name": dataset_name,
                                "content": dataset_content,
                                "params": {
                                    "description": dataset_desc
                                }
                            }

                            url = api_endpoint('dataset')
                            api_key = get_variable_from_json('api_token')

                            raw_response = requests.post(url=url, json=payload, headers={'x-api-key': api_key})
                            raw_response.raise_for_status()
                            response = raw_response.json()

                            print(f"Start upload sample files response:\n{response}")
                            for filename, presign_url in response['s3PresignUrl'].items():
                                file_path = file_path_lookup[filename]
                                with open(file_path, 'rb') as f:
                                    # Separate name: `response` is still being iterated.
                                    upload_response = requests.put(presign_url, f)
                                    print(upload_response)
                                    upload_response.raise_for_status()

                            payload = {
                                "dataset_name": dataset_name,
                                "status": "Enabled"
                            }

                            raw_response = requests.put(url=url, json=payload, headers={'x-api-key': api_key})
                            raw_response.raise_for_status()
                            print(raw_response.json())
                            return f'Complete Dataset {dataset_name} creation', None, None, None, None

                        dataset_name_upload = gr.Textbox(value="", lines=1, placeholder="Please input dataset name", label="Dataset Name", elem_id="sd_dataset_name_textbox")
                        dataset_description_upload = gr.Textbox(value="", lines=1, placeholder="Please input dataset description", label="Dataset Description", elem_id="sd_dataset_description_textbox")
                        create_dataset_button = gr.Button("Create Dataset", variant="primary", elem_id="sagemaker_dataset_create_button")  # size=(200, 50)
                        dataset_create_result = gr.Textbox(value="", label="Create Result", interactive=False)
                        create_dataset_button.click(
                            fn=create_dataset,
                            inputs=[upload_button, dataset_name_upload, dataset_description_upload],
                            outputs=[
                                dataset_create_result,
                                dataset_name_upload,
                                dataset_description_upload,
                                file_output,
                                upload_button
                            ],
                            show_progress=True
                        )

                    with gr.Tab('Browse'):
                        with gr.Row():
                            cloud_datasets = get_sorted_cloud_dataset()

                            # type="index": the handler receives the position in
                            # `choices`, which indexes `cloud_datasets`.
                            cloud_dataset_name = gr.Dropdown(
                                label="Dataset From Cloud",
                                choices=[d['datasetName'] for d in cloud_datasets],
                                elem_id="cloud_dataset_dropdown",
                                type="index",
                                info='select datasets from cloud'
                            )

                            def refresh_datasets():
                                """Reload the dataset list into the module global and return it."""
                                global cloud_datasets
                                cloud_datasets = get_sorted_cloud_dataset()
                                return cloud_datasets

                            def refresh_datasets_dropdown():
                                """Reload the dataset list and return the new dropdown choices."""
                                global cloud_datasets
                                cloud_datasets = get_sorted_cloud_dataset()
                                return {"choices": [d['datasetName'] for d in cloud_datasets]}

                            create_refresh_button(
                                cloud_dataset_name,
                                refresh_datasets,
                                refresh_datasets_dropdown,
                                "refresh_cloud_dataset",
                            )
                        with gr.Row():
                            dataset_s3_output = gr.Textbox(label='dataset s3 location', show_label=True, type='text').style(show_copy_button=True)
                        with gr.Row():
                            dataset_des_output = gr.Textbox(label='dataset description', show_label=True, type='text')
                        with gr.Row():
                            dataset_gallery = gr.Gallery(
                                label="Dataset images", show_label=False, elem_id="gallery",
                            ).style(columns=[2], rows=[2], object_fit="contain", height="auto")

                        def get_results_from_datasets(dataset_idx):
                            """Fetch location, description and preview images of the selected dataset."""
                            ds = cloud_datasets[dataset_idx]

                            url = api_endpoint(f"dataset/{ds['datasetName']}/data")
                            api_key = get_variable_from_json('api_token')
                            raw_response = requests.get(url=url, headers={'x-api-key': api_key})
                            raw_response.raise_for_status()
                            dataset_items = [(Image.open(requests.get(item['preview_url'], stream=True).raw), item['key']) for item in raw_response.json()['data']]
                            return ds['s3'], ds['description'], dataset_items

                        cloud_dataset_name.select(fn=get_results_from_datasets, inputs=[cloud_dataset_name], outputs=[dataset_s3_output, dataset_des_output, dataset_gallery])

    return (sagemaker_interface, "Amazon SageMaker", "sagemaker_interface"),
|
|
|
|
|
|
# Register this file's hooks with script_callbacks.
script_callbacks.on_after_component(on_after_component_callback)
script_callbacks.on_ui_tabs(on_ui_tabs)

# Keep a reference to the current tab collector; ui_tabs_callback calls it
# and then adds the cloud tabs to the Dreambooth UI.
origin_callback = script_callbacks.ui_tabs_callback
|
|
|
|
def avoid_duplicate_from_restart_ui(res):
    """Tell whether the 'Select From Cloud' tab already exists in a Dreambooth UI.

    Args:
        res: iterable of tab entries; index 0 is the blocks object and
            index 1 the tab title.

    Returns:
        True if any entry titled 'Dreambooth' holds a gr.Tab labelled
        'Select From Cloud', otherwise False.
    """
    for entry in res:
        if entry[1] != 'Dreambooth':
            continue
        registry = entry[0].blocks
        # Iterate over a snapshot of the keys, as the lookup is done per key.
        for block_id in list(registry):
            candidate = registry[block_id]
            if type(candidate) is gr.Tab and candidate.label == 'Select From Cloud':
                return True
    return False
|
|
|
|
|
|
|
|
def ui_tabs_callback():
    """Collect the UI tabs and add two cloud tabs to the Dreambooth UI.

    Calls the saved `origin_callback`. Unless the cloud tabs are already
    present, it looks in each entry titled 'Dreambooth' for the gr.Tab
    labelled 'Select' and, inside that tab's parent, adds the
    'Select From Cloud' and 'Create From Cloud' tabs with their listeners.

    Returns:
        The list produced by `origin_callback`.
    """
    res = origin_callback()
    if avoid_duplicate_from_restart_ui(res):
        return res

    for extension_ui in res:
        if extension_ui[1] != 'Dreambooth':
            continue
        # Snapshot the keys: new blocks are registered while iterating.
        for key in list(extension_ui[0].blocks):
            val = extension_ui[0].blocks[key]
            if type(val) is not gr.Tab or val.label != 'Select':
                continue

            with extension_ui[0], val.parent:
                with gr.Tab('Select From Cloud'):
                    with gr.Row():
                        cloud_db_model_name = gr.Dropdown(
                            label="Model",
                            choices=sorted(get_cloud_db_model_name_list()),
                            elem_id="cloud_db_model_name",
                        )
                        create_refresh_button(
                            cloud_db_model_name,
                            get_cloud_db_model_name_list,
                            lambda: {"choices": sorted(get_cloud_db_model_name_list())},
                            "refresh_db_models",
                        )
                    with gr.Row():
                        cloud_db_snapshot = gr.Dropdown(
                            label="Cloud Snapshot to Resume",
                            choices=sorted(get_cloud_model_snapshots()),
                            elem_id="cloud_snapshot_to_resume_dropdown",
                        )
                        create_refresh_button(
                            cloud_db_snapshot,
                            get_cloud_model_snapshots,
                            lambda: {"choices": sorted(get_cloud_model_snapshots())},
                            "refresh_db_snapshots",
                        )

                    with gr.Row():
                        cloud_train_instance_type = gr.Dropdown(
                            label="SageMaker Train Instance Type",
                            choices=['ml.g4dn.2xlarge', 'ml.g5.2xlarge'],
                            elem_id="cloud_train_instance_type",
                            info='select SageMaker Train Instance Type',
                        )
                    with gr.Row(visible=False) as lora_model_row:
                        cloud_db_lora_model_name = gr.Dropdown(
                            label="Lora Model",
                            choices=get_sorted_lora_cloud_models(),
                            elem_id="cloud_lora_model_dropdown",
                        )
                        create_refresh_button(
                            cloud_db_lora_model_name,
                            get_sorted_lora_cloud_models,
                            lambda: {"choices": get_sorted_lora_cloud_models()},
                            "refresh_lora_models",
                        )
                    # Caption / value pairs filled in by the model-name change listener.
                    with gr.Row():
                        gr.HTML(value="Loaded Model from Cloud:")
                        cloud_db_model_path = gr.HTML()
                    with gr.Row():
                        gr.HTML(value="Cloud Model Revision:")
                        cloud_db_revision = gr.HTML(elem_id="cloud_db_revision")
                    with gr.Row():
                        gr.HTML(value="Cloud Model Epoch:")
                        cloud_db_epochs = gr.HTML(elem_id="cloud_db_epochs")
                    with gr.Row():
                        gr.HTML(value="V2 Model From Cloud:")
                        cloud_db_v2 = gr.HTML(elem_id="cloud_db_v2")
                    with gr.Row():
                        gr.HTML(value="Has EMA:")
                        cloud_db_has_ema = gr.HTML(elem_id="cloud_db_has_ema")
                    with gr.Row():
                        gr.HTML(value="Source Checkpoint From Cloud:")
                        cloud_db_src = gr.HTML()
                    with gr.Row():
                        gr.HTML(value="Cloud DB Status:")
                        cloud_db_status = gr.HTML(elem_id="db_status", value="")
                    with gr.Row():
                        gr.HTML(value="Experimental Shared Source:")
                        cloud_db_shared_diffusers_path = gr.HTML()
                    with gr.Row():
                        gr.HTML(value="<b>Training Jobs Details:<b/>")
                    with gr.Row():
                        # `value` is a callable and `every=3` is set.
                        training_job_dashboard = gr.Dataframe(
                            headers=["id", "model name", "status", "SageMaker train name"],
                            datatype=["str", "str", "str", "str"],
                            col_count=(4, "fixed"),
                            value=get_train_job_list,
                            interactive=False,
                            every=3,
                            elem_id='training_job_dashboard',
                        )

                with gr.Tab('Create From Cloud'):
                    with gr.Column():
                        cloud_db_create_model = gr.Button(
                            value="Create Model From Cloud", variant="primary"
                        )
                    cloud_db_new_model_name = gr.Textbox(label="Name", placeholder="Model names can only contain alphanumeric and -")
                    with gr.Row():
                        cloud_db_create_from_hub = gr.Checkbox(
                            label="Create From Hub", value=False, visible=False
                        )
                        cloud_db_512_model = gr.Checkbox(label="512x Model", value=True)
                    with gr.Column(visible=False) as hub_row:
                        cloud_db_new_model_url = gr.Textbox(
                            label="Model Path",
                            placeholder="runwayml/stable-diffusion-v1-5",
                            elem_id="cloud_db_model_path_text_box",
                        )
                        cloud_db_new_model_token = gr.Textbox(
                            label="HuggingFace Token", value=""
                        )
                    with gr.Column(visible=True) as local_row:
                        with gr.Row():
                            cloud_db_new_model_src = gr.Dropdown(
                                label="Source Checkpoint",
                                choices=sorted(get_sd_cloud_models()),
                                elem_id="cloud_db_source_checkpoint_dropdown",
                            )
                            create_refresh_button(
                                cloud_db_new_model_src,
                                get_sd_cloud_models,
                                lambda: {"choices": sorted(get_sd_cloud_models())},
                                "refresh_sd_models",
                            )
                    with gr.Column(visible=False) as shared_row:
                        with gr.Row():
                            cloud_db_new_model_shared_src = gr.Dropdown(
                                label="EXPERIMENTAL: LoRA Shared Diffusers Source",
                                choices=[],
                                value="",
                            )
                    cloud_db_new_model_extract_ema = gr.Checkbox(
                        label="Extract EMA Weights", value=False
                    )
                    cloud_db_train_unfrozen = gr.Checkbox(label="Unfreeze Model", value=False, elem_id="cloud_db_unfreeze_model_checkbox")
                    with gr.Row():
                        gr.HTML(value="<b>Model Creation Jobs Details:<b/>")
                    with gr.Row():
                        # `value` is a callable and `every=3` is set.
                        createmodel_dashboard = gr.Dataframe(
                            headers=["id", "model name", "status"],
                            datatype=["str", "str", "str"],
                            col_count=(3, "fixed"),
                            value=get_create_model_job_list,
                            interactive=False,
                            every=3,
                        )

                def toggle_new_rows(create_from):
                    """Show the hub column and hide the local one when `create_from` is truthy, and vice versa."""
                    return gr.update(visible=create_from), gr.update(visible=not create_from)

                cloud_db_create_from_hub.change(
                    fn=toggle_new_rows,
                    inputs=[cloud_db_create_from_hub],
                    outputs=[hub_row, local_row],
                )

                cloud_db_model_name.change(
                    _js="clear_loaded",
                    fn=wrap_load_model_params,
                    inputs=[cloud_db_model_name],
                    outputs=[
                        cloud_db_model_path,
                        cloud_db_revision,
                        cloud_db_epochs,
                        cloud_db_v2,
                        cloud_db_has_ema,
                        cloud_db_src,
                        cloud_db_shared_diffusers_path,
                        cloud_db_snapshot,
                        cloud_db_lora_model_name,
                        cloud_db_status,
                    ],
                )
                cloud_db_create_model.click(
                    fn=cloud_create_model,
                    _js="check_create_model_params",
                    inputs=[
                        cloud_db_new_model_name,
                        cloud_db_new_model_src,
                        cloud_db_new_model_shared_src,
                        cloud_db_create_from_hub,
                        cloud_db_new_model_url,
                        cloud_db_new_model_token,
                        cloud_db_new_model_extract_ema,
                        cloud_db_train_unfrozen,
                        cloud_db_512_model,
                    ],
                    outputs=[
                        createmodel_dashboard
                    ],
                )
            # Stop scanning this Dreambooth UI once the tabs have been added.
            break
    return res
|
|
|
|
# Replace script_callbacks.ui_tabs_callback with the wrapper defined above,
# which calls the saved origin_callback and then adds the cloud tabs.
script_callbacks.ui_tabs_callback = ui_tabs_callback
|
|
|
|
def get_sorted_lora_cloud_models():
    """Return the LoRA model choices; currently always a new empty list."""
    return list()
|
|
|
|
def get_cloud_model_snapshots():
    """Return the snapshot choices; currently always a new empty list."""
    return list()