"""SageMaker integration for the stable-diffusion-webui.

Registers an always-visible script that adds SageMaker inference controls to
txt2img/img2img, an "Amazon SageMaker" tab for connection settings, model upload,
endpoint and dataset management, and (when the Dreambooth extension is present)
extra "From Cloud" tabs inside the Dreambooth UI.
"""
import sys
import requests
import logging
import gradio as gr
import os

import modules.scripts as scripts
from modules import script_callbacks
from modules.ui import create_refresh_button
from modules.ui_components import FormRow

from utils import get_variable_from_json
from utils import save_variable_to_json
from PIL import Image

# sys.path.append("extensions/stable-diffusion-aws-extension/scripts")
# import sagemaker_ui
from aws_extension import sagemaker_ui

dreambooth_available = True


def dummy_function(*args, **kwargs):
    """Stand-in for dreambooth_on_cloud helpers when that package is unavailable; returns []."""
    return []


try:
    from dreambooth_on_cloud.train import (
        async_cloud_train,
        get_cloud_db_model_name_list,
        wrap_load_model_params,
        get_train_job_list,
        get_sorted_cloud_dataset
    )
    from dreambooth_on_cloud.create_model import (
        get_sd_cloud_models,
        get_create_model_job_list,
        cloud_create_model,
    )
except Exception as e:
    logging.warning("[main]dreambooth_on_cloud is not installed or can not be imported, using dummy function to proceed.")
    logging.debug("[main]dreambooth_on_cloud import error: %s", e)
    dreambooth_available = False
    # `cloud_train` is kept as a legacy alias; `async_cloud_train` is the name the
    # SageMaker Train hook actually uses, so it must be bound here as well or the
    # hook raises NameError.
    cloud_train = dummy_function
    async_cloud_train = dummy_function
    get_cloud_db_model_name_list = dummy_function
    wrap_load_model_params = dummy_function
    get_train_job_list = dummy_function
    get_sorted_cloud_dataset = dummy_function
    get_sd_cloud_models = dummy_function
    get_create_model_job_list = dummy_function
    cloud_create_model = dummy_function

cloud_datasets = []

# Dreambooth components captured by on_after_component_callback.
training_job_dashboard = None
db_model_name = None
cloud_db_model_name = None
cloud_train_instance_type = None
db_use_txt2img = None
db_sagemaker_train = None
db_save_config = None

# txt2img components captured by on_after_component_callback; the hook is
# set to "finish" once the handlers have been wired.
txt2img_show_hook = None
txt2img_gallery = None
txt2img_generation_info = None
txt2img_html_info = None
txt2img_prompt = None

# img2img components captured by on_after_component_callback.
img2img_show_hook = None
img2img_gallery = None
img2img_generation_info = None
img2img_html_info = None
img2img_prompt = None
init_img = None
sketch = None
init_img_with_mask = None
inpaint_color_sketch = None
init_img_inpaint = None
init_mask_inpaint = None

modelmerger_merge_hook = None
modelmerger_merge_component = None

# Connection settings, loaded from the local JSON config by on_ui_tabs and
# refreshed by update_connect_config.
api_gateway_url = None
api_key = None
test_connection_result = None

async_inference_choices = [
    "ml.g4dn.xlarge", "ml.g4dn.2xlarge", "ml.g4dn.4xlarge", "ml.g4dn.8xlarge", "ml.g4dn.12xlarge",
    "ml.g5.xlarge", "ml.g5.2xlarge", "ml.g5.4xlarge", "ml.g5.8xlarge", "ml.g5.12xlarge",
]


class SageMakerUI(scripts.Script):
    """Always-visible webui script that injects the SageMaker controls into txt2img/img2img."""

    def title(self):
        return "SageMaker embeddings"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the SageMaker controls via sagemaker_ui.create_ui and return them to webui."""
        (sagemaker_endpoint, sd_checkpoint, sd_checkpoint_refresh_button, textual_inversion_dropdown,
         lora_dropdown, hyperNetwork_dropdown, controlnet_dropdown, inference_job_dropdown,
         txt2img_inference_job_ids_refresh_button, primary_model_name, secondary_model_name,
         tertiary_model_name, modelmerger_merge_on_cloud) = sagemaker_ui.create_ui(is_img2img)
        return [sagemaker_endpoint, sd_checkpoint, sd_checkpoint_refresh_button, textual_inversion_dropdown,
                lora_dropdown, hyperNetwork_dropdown, controlnet_dropdown, inference_job_dropdown,
                txt2img_inference_job_ids_refresh_button, primary_model_name, secondary_model_name,
                tertiary_model_name, modelmerger_merge_on_cloud]

    def process(self, p, sagemaker_endpoint, sd_checkpoint, sd_checkpoint_refresh_button,
                textual_inversion_dropdown, lora_dropdown, hyperNetwork_dropdown, controlnet_dropdown,
                choose_txt2img_inference_job_id, txt2img_inference_job_ids_refresh_button,
                primary_model_name, secondary_model_name, tertiary_model_name, modelmerger_on_cloud):
        # Cloud inference is triggered by dedicated buttons; local processing is untouched.
        pass


def _elem_id(component):
    """Return the component's elem_id, or None when it has none (or component is None)."""
    return getattr(component, 'elem_id', None)


def _ancestor(component, depth):
    """Walk `depth` levels up the `.parent` chain; return None if the chain ends early."""
    node = component
    for _ in range(depth):
        node = getattr(node, 'parent', None)
        if node is None:
            return None
    return node


def _mask_secret(secret):
    """Return a redacted form of *secret* that is safe to print to the console."""
    if not secret:
        return "<empty>"
    text = str(secret)
    return "****" + text[-4:] if len(text) > 8 else "****"


def _api_endpoint(path):
    """Join the configured API gateway URL and *path* with exactly one '/' between them.

    The stored URL always ends in '/' (see update_connect_config), so naive
    concatenation with a leading-slash path would produce '//' in the URL.
    """
    base = get_variable_from_json('api_gateway_url') or ''
    return f"{base.rstrip('/')}/{path.lstrip('/')}"


def _hook_dreambooth(component):
    """Capture dreambooth components; bind 'SageMaker Train' once all required ones exist."""
    global db_model_name, db_use_txt2img, db_sagemaker_train, db_save_config, \
        cloud_db_model_name, cloud_train_instance_type, training_job_dashboard

    elem_id = _elem_id(component)
    component_type = type(component)
    is_dreambooth_train = component_type is gr.Button and elem_id == 'db_train'
    # The dreambooth model dropdown is either identified by elem_id or is the
    # "Model" dropdown nested four levels below the 'ModelPanel' container.
    is_dreambooth_model_name = component_type is gr.Dropdown and (
        elem_id == 'model_name'
        or (getattr(component, 'label', None) == 'Model'
            and _elem_id(_ancestor(component, 4)) == 'ModelPanel'))
    is_cloud_dreambooth_model_name = component_type is gr.Dropdown and elem_id == 'cloud_db_model_name'
    is_machine_type_for_train = component_type is gr.Dropdown and elem_id == 'cloud_train_instance_type'
    is_dreambooth_use_txt2img = component_type is gr.Checkbox and getattr(component, 'label', None) == 'Use txt2img'
    is_training_job_dashboard = component_type is gr.Dataframe and elem_id == 'training_job_dashboard'
    is_db_save_config = elem_id == 'db_save_config'

    if is_dreambooth_train:
        # Create our button right after the dreambooth "Train" button.
        db_sagemaker_train = gr.Button(value="SageMaker Train", elem_id="db_sagemaker_train", variant='primary')
    if is_dreambooth_model_name:
        db_model_name = component
    if is_cloud_dreambooth_model_name:
        cloud_db_model_name = component
    if is_training_job_dashboard:
        training_job_dashboard = component
    if is_machine_type_for_train:
        cloud_train_instance_type = component
    if is_dreambooth_use_txt2img:
        db_use_txt2img = component
    if is_db_save_config:
        db_save_config = component

    required = (training_job_dashboard, cloud_train_instance_type, cloud_db_model_name,
                db_model_name, db_use_txt2img, db_sagemaker_train)
    just_captured = (is_dreambooth_train or is_dreambooth_model_name or is_dreambooth_use_txt2img
                     or is_cloud_dreambooth_model_name or is_machine_type_for_train
                     or is_training_job_dashboard)
    # After all required components are loaded, add the SageMaker training button click callback.
    if all(c is not None for c in required) and just_captured:
        db_model_name.value = "dummy_local_model"
        db_sagemaker_train.click(
            fn=async_cloud_train,
            _js="db_start_sagemaker_train",
            inputs=[
                db_model_name,
                cloud_db_model_name,
                db_use_txt2img,
                cloud_train_instance_type
            ],
            outputs=[training_job_dashboard]
        )


def _hook_txt2img(component):
    """Capture txt2img output components and wire the SageMaker handlers exactly once."""
    global txt2img_gallery, txt2img_generation_info, txt2img_html_info, txt2img_show_hook, txt2img_prompt

    elem_id = _elem_id(component)
    component_type = type(component)
    if component_type is gr.Textbox and elem_id == 'txt2img_prompt':
        txt2img_prompt = component
    if component_type is gr.Gallery and elem_id == 'txt2img_gallery':
        txt2img_gallery = component
    if component_type is gr.Textbox and elem_id == 'generation_info_txt2img':
        txt2img_generation_info = component
    if component_type is gr.HTML and elem_id == 'html_info_txt2img':
        txt2img_html_info = component

    ready = (sagemaker_ui.inference_job_dropdown is not None
             and txt2img_gallery is not None
             and txt2img_generation_info is not None
             and txt2img_html_info is not None
             and txt2img_show_hook is None
             and txt2img_prompt is not None)
    if not ready:
        return

    txt2img_show_hook = "finish"
    # Lambdas resolve the sagemaker_ui functions at call time, not hook time.
    sagemaker_ui.inference_job_dropdown.change(
        fn=lambda selected_value: sagemaker_ui.fake_gan(selected_value),
        inputs=[sagemaker_ui.inference_job_dropdown],
        outputs=[txt2img_gallery, txt2img_generation_info, txt2img_html_info, txt2img_prompt]
    )
    sagemaker_ui.sagemaker_endpoint.change(
        fn=lambda selected_value: sagemaker_ui.displayEndpointInfo(selected_value),
        inputs=[sagemaker_ui.sagemaker_endpoint],
        outputs=[txt2img_html_info]
    )
    sagemaker_ui.generate_on_cloud_button_with_js.click(
        fn=sagemaker_ui.call_txt2img_inference,
        _js="txt2img_config_save",
        inputs=[sagemaker_ui.sagemaker_endpoint],
        outputs=[txt2img_gallery, txt2img_generation_info, txt2img_html_info]
    )
    sagemaker_ui.modelmerger_merge_on_cloud.click(
        fn=sagemaker_ui.modelmerger_on_cloud_func,
        _js="txt2img_config_save",
        inputs=[sagemaker_ui.sagemaker_endpoint],
        outputs=[]
    )


def _hook_img2img(component):
    """Capture img2img input/output components and wire the SageMaker handlers exactly once."""
    global img2img_gallery, img2img_generation_info, img2img_html_info, img2img_show_hook, \
        img2img_prompt, init_img, sketch, init_img_with_mask, inpaint_color_sketch, \
        init_img_inpaint, init_mask_inpaint

    elem_id = _elem_id(component)
    component_type = type(component)
    if component_type is gr.Gallery and elem_id == 'img2img_gallery':
        img2img_gallery = component
    if component_type is gr.Textbox and elem_id == 'generation_info_img2img':
        img2img_generation_info = component
    if component_type is gr.HTML and elem_id == 'html_info_img2img':
        img2img_html_info = component
    if component_type is gr.Textbox and elem_id == 'img2img_prompt':
        img2img_prompt = component
    if component_type is gr.Image:
        if elem_id == 'img2img_image':
            init_img = component
        elif elem_id == 'img2img_sketch':
            sketch = component
        elif elem_id == 'img2maskimg':
            init_img_with_mask = component
        elif elem_id == 'inpaint_sketch':
            inpaint_color_sketch = component
        elif elem_id == 'img_inpaint_base':
            init_img_inpaint = component
        elif elem_id == 'img_inpaint_mask':
            init_mask_inpaint = component

    image_inputs = [init_img, sketch, init_img_with_mask, inpaint_color_sketch,
                    init_img_inpaint, init_mask_inpaint]
    ready = (sagemaker_ui.inference_job_dropdown is not None
             and img2img_gallery is not None
             and img2img_generation_info is not None
             and img2img_html_info is not None
             and img2img_show_hook is None
             and sagemaker_ui.interrogate_clip_on_cloud_button is not None
             and sagemaker_ui.interrogate_deep_booru_on_cloud_button is not None
             and img2img_prompt is not None
             and all(image is not None for image in image_inputs))
    if not ready:
        return

    img2img_show_hook = "finish"
    cloud_inputs = [sagemaker_ui.sagemaker_endpoint, *image_inputs]
    cloud_outputs = [img2img_gallery, img2img_generation_info, img2img_html_info]
    sagemaker_ui.inference_job_dropdown.change(
        fn=lambda selected_value: sagemaker_ui.fake_gan(selected_value),
        inputs=[sagemaker_ui.inference_job_dropdown],
        outputs=[img2img_gallery, img2img_generation_info, img2img_html_info, img2img_prompt]
    )
    sagemaker_ui.interrogate_clip_on_cloud_button.click(
        fn=sagemaker_ui.call_interrogate_clip,
        _js="img2img_config_save",
        inputs=list(cloud_inputs),
        outputs=list(cloud_outputs)
    )
    sagemaker_ui.interrogate_deep_booru_on_cloud_button.click(
        fn=sagemaker_ui.call_interrogate_deepbooru,
        _js="img2img_config_save",
        inputs=list(cloud_inputs),
        outputs=list(cloud_outputs)
    )
    sagemaker_ui.generate_on_cloud_button_with_js_img2img.click(
        fn=sagemaker_ui.call_img2img_inference,
        _js="img2img_config_save",
        inputs=list(cloud_inputs),
        outputs=list(cloud_outputs)
    )


def on_after_component_callback(component, **_kwargs):
    """webui hook called after every component is created.

    Stashes the components this extension needs in module globals and, once a
    feature's full set has been seen, attaches the SageMaker event handlers.
    """
    _hook_dreambooth(component)
    _hook_txt2img(component)
    _hook_img2img(component)


def update_connect_config(api_url, api_token):
    """Save the API URL (normalised to end with '/') and token locally, then refresh cloud lists.

    Returns a status message for the UI.
    """
    global api_gateway_url, api_key
    if not api_url.endswith('/'):
        api_url += '/'
    save_variable_to_json('api_gateway_url', api_url)
    save_variable_to_json('api_token', api_token)
    api_gateway_url = get_variable_from_json('api_gateway_url')
    api_key = get_variable_from_json('api_token')
    # Never echo the raw API token to the console.
    print(f"update the api_url:{api_gateway_url} and token: {_mask_secret(api_key)}............")
    sagemaker_ui.init_refresh_resource_list_from_cloud()
    return "config updated to local config!"


def test_aws_connect_config(api_url, api_token):
    """Save the settings, then call the backend's test-connection endpoint.

    Returns a human-readable success/failure message for the UI.
    """
    update_connect_config(api_url, api_token)
    api_url = get_variable_from_json('api_gateway_url')
    api_token = get_variable_from_json('api_token')
    if not api_url.endswith('/'):
        api_url += '/'
    print(f"get the api_url:{api_url} and token: {_mask_secret(api_token)}............")
    target_url = f'{api_url}inference/test-connection'
    headers = {
        "x-api-key": api_token,
        "Content-Type": "application/json"
    }
    try:
        response = requests.get(target_url, headers=headers)
        response.raise_for_status()  # HTTP error status -> RequestException
        response.json()  # a non-JSON body means we did not reach the real backend
        return "Successfully Connected"
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers JSON decode errors on requests < 2.27, where they do
        # not derive from RequestException.
        print(f"Error: Failed to get server request. Details: {e}")
        return "failed to connect to backend server, please check the url and token"


def on_ui_tabs():
    """Build the "Amazon SageMaker" tab and return it in the format webui expects."""
    global api_gateway_url, api_key, test_connection_result, cloud_datasets
    import modules.ui

    buildin_model_list = ['AWS JumpStart Model', 'AWS BedRock Model', 'Hugging Face Model']
    with gr.Blocks() as sagemaker_interface:
        with gr.Row(equal_height=True, elem_id="aws_sagemaker_ui_row", visible=False):
            sm_load_params = gr.Button(value="Load Settings", elem_id="aws_load_params", visible=False)
            sm_save_params = gr.Button(value="Save Settings", elem_id="aws_save_params", visible=False)
            sm_train_model = gr.Button(value="Train", variant="primary", elem_id="aws_train_model", visible=False)
            sm_generate_checkpoint = gr.Button(value="Generate Ckpt", elem_id="aws_gen_ckpt", visible=False)
        with gr.Row():
            gr.HTML(value="Enter your API URL & Token to start the connection.", elem_id="hint_row")
        with gr.Row():
            # --- Column 1: connection settings ---
            with gr.Column(variant="panel", scale=1):
                gr.HTML(value="AWS Connection Setting")
                api_gateway_url = get_variable_from_json('api_gateway_url')
                api_key = get_variable_from_json('api_token')
                with gr.Row():
                    api_url_textbox = gr.Textbox(value=api_gateway_url, lines=1,
                                                 placeholder="Please enter API Url of Middle",
                                                 label="API Url", elem_id="aws_middleware_api")

                    def update_api_gateway_url():
                        global api_gateway_url
                        api_gateway_url = get_variable_from_json('api_gateway_url')
                        return api_gateway_url

                    modules.ui.create_refresh_button(api_url_textbox, update_api_gateway_url,
                                                     lambda: {"value": api_gateway_url},
                                                     "refresh_api_gateway_url")
                with gr.Row():
                    def update_api_key():
                        global api_key
                        api_key = get_variable_from_json('api_token')
                        return api_key

                    api_token_textbox = gr.Textbox(value=api_key, lines=1, placeholder="Please enter API Token",
                                                   label="API Token", elem_id="aws_middleware_token")
                    modules.ui.create_refresh_button(api_token_textbox, update_api_key,
                                                     lambda: {"value": api_key}, "refresh_api_token")

                test_connection_result = gr.Label(title="Output")
                aws_connect_button = gr.Button(value="Update Setting", variant='primary', elem_id="aws_config_save")
                aws_connect_button.click(_js="update_auth_settings",
                                         fn=update_connect_config,
                                         inputs=[api_url_textbox, api_token_textbox],
                                         outputs=[test_connection_result])
                aws_test_button = gr.Button(value="Test Connection", variant='primary', elem_id="aws_config_test")
                aws_test_button.click(test_aws_connect_config,
                                      inputs=[api_url_textbox, api_token_textbox],
                                      outputs=[test_connection_result])

            # --- Column 2: model upload and endpoint management ---
            with gr.Column(variant="panel", scale=1.5):
                gr.HTML(value="Cloud Assets Management")
                sagemaker_html_log = gr.HTML(elem_id='html_log_sagemaker')
                with gr.Accordion("Upload Model to S3", open=False):
                    gr.HTML(value="Refresh to select the model to upload to S3")
                    exts = (".bin", ".pt", ".safetensors", ".ckpt")
                    root_path = os.getcwd()
                    model_folders = {
                        "ckpt": os.path.join(root_path, "models", "Stable-diffusion"),
                        "text": os.path.join(root_path, "embeddings"),
                        "lora": os.path.join(root_path, "models", "Lora"),
                        "control": os.path.join(root_path, "models", "ControlNet"),
                        "hyper": os.path.join(root_path, "models", "hypernetworks"),
                    }

                    def _scan_model_folder(kind):
                        """Return full paths of model files (by extension) in the folder for *kind*.

                        A missing folder (e.g. ControlNet not installed) yields [] instead
                        of crashing the whole tab.
                        """
                        folder = model_folders[kind]
                        if not os.path.isdir(folder):
                            return []
                        return [os.path.join(folder, name) for name in os.listdir(folder)
                                if os.path.splitext(name)[1] in exts]

                    def scan_sd_ckpt():
                        return _scan_model_folder("ckpt")

                    def scan_textural_inversion_model():
                        return _scan_model_folder("text")

                    def scan_lora_model():
                        return _scan_model_folder("lora")

                    def scan_control_model():
                        return _scan_model_folder("control")

                    def scan_hypernetwork_model():
                        return _scan_model_folder("hyper")

                    with FormRow(elem_id="model_upload_form_row_01"):
                        sd_checkpoints_path = gr.Dropdown(label="SD Checkpoints", choices=sorted(scan_sd_ckpt()),
                                                          elem_id="sd_ckpt_dropdown")
                        create_refresh_button(sd_checkpoints_path, scan_sd_ckpt,
                                              lambda: {"choices": sorted(scan_sd_ckpt())}, "refresh_sd_ckpt")
                        textual_inversion_path = gr.Dropdown(label="Textual Inversion",
                                                             choices=sorted(scan_textural_inversion_model()),
                                                             elem_id="textual_inversion_model_dropdown")
                        create_refresh_button(textual_inversion_path, scan_textural_inversion_model,
                                              lambda: {"choices": sorted(scan_textural_inversion_model())},
                                              "refresh_textual_inversion_model")
                    with FormRow(elem_id="model_upload_form_row_02"):
                        lora_path = gr.Dropdown(label="LoRA model", choices=sorted(scan_lora_model()),
                                                elem_id="lora_model_dropdown")
                        create_refresh_button(lora_path, scan_lora_model,
                                              lambda: {"choices": sorted(scan_lora_model())}, "refresh_lora_model")
                        controlnet_model_path = gr.Dropdown(label="ControlNet model",
                                                            choices=sorted(scan_control_model()),
                                                            elem_id="controlnet_model_dropdown")
                        create_refresh_button(controlnet_model_path, scan_control_model,
                                              lambda: {"choices": sorted(scan_control_model())},
                                              "refresh_controlnet_models")
                    with FormRow(elem_id="model_upload_form_row_03"):
                        hypernetwork_path = gr.Dropdown(label="Hypernetwork",
                                                        choices=sorted(scan_hypernetwork_model()),
                                                        elem_id="hyper_model_dropdown")
                        create_refresh_button(hypernetwork_path, scan_hypernetwork_model,
                                              lambda: {"choices": sorted(scan_hypernetwork_model())},
                                              "refresh_hyper_models")
                    with gr.Row():
                        model_update_button = gr.Button(value="Upload Models to Cloud", variant="primary",
                                                        elem_id="sagemaker_model_update_button", size=(200, 50))
                        model_update_button.click(_js="model_update",
                                                  fn=sagemaker_ui.sagemaker_upload_model_s3,
                                                  inputs=[sd_checkpoints_path, textual_inversion_path, lora_path,
                                                          hypernetwork_path, controlnet_model_path],
                                                  outputs=[test_connection_result, sd_checkpoints_path,
                                                           textual_inversion_path, lora_path, hypernetwork_path,
                                                           controlnet_model_path])

                with gr.Blocks(title="Deploy New SageMaker Endpoint", variant='panel'):
                    gr.HTML(value="Deploy New SageMaker Endpoint")
                    with gr.Row():
                        instance_type_dropdown = gr.Dropdown(label="SageMaker Instance Type",
                                                             choices=async_inference_choices,
                                                             elem_id="sagemaker_inference_instance_type_textbox",
                                                             value="ml.g4dn.xlarge")
                        instance_count_dropdown = gr.Dropdown(label="Please select Instance count",
                                                              choices=["1", "2", "3", "4"],
                                                              elem_id="sagemaker_inference_instance_count_textbox",
                                                              value="1")
                    with gr.Row():
                        sagemaker_deploy_button = gr.Button(value="Deploy", variant='primary',
                                                            elem_id="sagemaker_deploy_endpoint_buttion")
                        sagemaker_deploy_button.click(sagemaker_ui.sagemaker_deploy,
                                                      _js="deploy_endpoint",
                                                      inputs=[instance_type_dropdown, instance_count_dropdown],
                                                      outputs=[test_connection_result])

                with gr.Blocks(title="Delete SageMaker Endpoint", variant='panel'):
                    gr.HTML(value="Delete SageMaker Endpoint")
                    with gr.Row():
                        sagemaker_endpoint_delete_dropdown = gr.Dropdown(choices=sagemaker_ui.sagemaker_endpoints,
                                                                         multiselect=True,
                                                                         label="Select Cloud SageMaker Endpoint")
                        modules.ui.create_refresh_button(sagemaker_endpoint_delete_dropdown,
                                                         sagemaker_ui.update_sagemaker_endpoints,
                                                         lambda: {"choices": sagemaker_ui.sagemaker_endpoints},
                                                         "refresh_sagemaker_endpoints_delete")
                    sagemaker_endpoint_delete_button = gr.Button(value="Delete", variant='primary',
                                                                 elem_id="sagemaker_endpoint_delete_button")
                    sagemaker_endpoint_delete_button.click(sagemaker_ui.sagemaker_endpoint_delete,
                                                           _js="delete_sagemaker_endpoint",
                                                           inputs=[sagemaker_endpoint_delete_dropdown],
                                                           outputs=[test_connection_result])

            # --- Column 3: dataset management ---
            with gr.Column(variant="panel", scale=1):
                # TODO: uncomment if implemented, comment since the tab component do not has visible parameter
                # with gr.Blocks(title="Deploy New SageMaker Endpoint", variant='panel', visible=False):
                #     gr.HTML(value="AWS Model Setting", visible=False)
                #     with gr.Tab("Select"):
                #         gr.HTML(value="AWS Built-in Model", visible=False)
                #         model_select_dropdown = gr.Dropdown(buildin_model_list, label="Select Built-In Model", elem_id="aws_select_model", visible=False)
                #     with gr.Tab("Create"):
                #         gr.HTML(value="AWS Custom Model", visible=False)
                #         model_name_textbox = gr.Textbox(value="", lines=1, placeholder="Please enter model name", label="Model Name", visible=False)
                #         model_create_button = gr.Button(value="Create Model", variant='primary', elem_id="aws_create_model", visible=False)
                with gr.Blocks(title="Create AWS dataset", variant='panel'):
                    gr.HTML(value="AWS Dataset Management")
                    with gr.Tab("Create"):
                        def upload_file(files):
                            """Show the uploaded temp files in the file list."""
                            return [file.name for file in files]

                        file_output = gr.File()
                        upload_button = gr.UploadButton("Click to Upload a File", file_types=["image", "video"],
                                                        file_count="multiple")
                        upload_button.upload(fn=upload_file, inputs=[upload_button], outputs=[file_output])

                        def create_dataset(files, dataset_name, dataset_desc):
                            """Register a dataset, upload its files to presigned S3 URLs, then enable it.

                            Returns (status message, and cleared values for the name,
                            description, file list and upload button).
                            Raises requests.HTTPError if any API/S3 call fails.
                            """
                            print(dataset_name)
                            if not files:
                                # Nothing uploaded yet: report instead of crashing on None.
                                return ('Please upload at least one file before creating a dataset',
                                        gr.update(), gr.update(), gr.update(), gr.update())
                            dataset_content = []
                            file_path_lookup = {}
                            for file in files:
                                orig_name = os.path.basename(file.name)
                                file_path_lookup[orig_name] = file.name
                                dataset_content.append(
                                    {
                                        "filename": orig_name,
                                        "name": orig_name,
                                        "type": "image",
                                        "params": {}
                                    }
                                )

                            payload = {
                                "dataset_name": dataset_name,
                                "content": dataset_content,
                                "params": {
                                    "description": dataset_desc
                                }
                            }

                            url = _api_endpoint('dataset')
                            token = get_variable_from_json('api_token')

                            raw_response = requests.post(url=url, json=payload, headers={'x-api-key': token})
                            raw_response.raise_for_status()
                            response = raw_response.json()

                            print(f"Start upload sample files response:\n{response}")
                            for filename, presign_url in response['s3PresignUrl'].items():
                                file_path = file_path_lookup[filename]
                                with open(file_path, 'rb') as f:
                                    upload_response = requests.put(presign_url, data=f)
                                    print(upload_response)
                                    upload_response.raise_for_status()

                            # All files uploaded: flip the dataset to Enabled.
                            payload = {
                                "dataset_name": dataset_name,
                                "status": "Enabled"
                            }

                            raw_response = requests.put(url=url, json=payload, headers={'x-api-key': token})
                            raw_response.raise_for_status()
                            print(raw_response.json())
                            return f'Complete Dataset {dataset_name} creation', None, None, None, None

                        dataset_name_upload = gr.Textbox(value="", lines=1, placeholder="Please input dataset name",
                                                         label="Dataset Name", elem_id="sd_dataset_name_textbox")
                        dataset_description_upload = gr.Textbox(value="", lines=1,
                                                                placeholder="Please input dataset description",
                                                                label="Dataset Description",
                                                                elem_id="sd_dataset_description_textbox")
                        create_dataset_button = gr.Button("Create Dataset", variant="primary",
                                                          elem_id="sagemaker_dataset_create_button")  # size=(200, 50)
                        dataset_create_result = gr.Textbox(value="", label="Create Result", interactive=False)
                        create_dataset_button.click(
                            fn=create_dataset,
                            inputs=[upload_button, dataset_name_upload, dataset_description_upload],
                            outputs=[
                                dataset_create_result,
                                dataset_name_upload,
                                dataset_description_upload,
                                file_output,
                                upload_button
                            ],
                            show_progress=True
                        )

                    with gr.Tab('Browse'):
                        with gr.Row():
                            cloud_datasets = get_sorted_cloud_dataset()
                            cloud_dataset_name = gr.Dropdown(
                                label="Dataset From Cloud",
                                choices=[d['datasetName'] for d in cloud_datasets],
                                elem_id="cloud_dataset_dropdown",
                                type="index",
                                info='select datasets from cloud'
                            )

                            def refresh_datasets():
                                global cloud_datasets
                                cloud_datasets = get_sorted_cloud_dataset()
                                return cloud_datasets

                            def refresh_datasets_dropdown():
                                global cloud_datasets
                                cloud_datasets = get_sorted_cloud_dataset()
                                return {"choices": [d['datasetName'] for d in cloud_datasets]}

                            create_refresh_button(
                                cloud_dataset_name,
                                refresh_datasets,
                                refresh_datasets_dropdown,
                                "refresh_cloud_dataset",
                            )
                        with gr.Row():
                            dataset_s3_output = gr.Textbox(label='dataset s3 location', show_label=True,
                                                           type='text').style(show_copy_button=True)
                        with gr.Row():
                            dataset_des_output = gr.Textbox(label='dataset description', show_label=True, type='text')
                        with gr.Row():
                            dataset_gallery = gr.Gallery(
                                label="Dataset images", show_label=False, elem_id="gallery",
                            ).style(columns=[2], rows=[2], object_fit="contain", height="auto")

                        def get_results_from_datasets(dataset_idx):
                            """Return (s3 location, description, [(preview image, key)]) for the selected dataset."""
                            ds = cloud_datasets[dataset_idx]
                            url = _api_endpoint(f"dataset/{ds['datasetName']}/data")
                            token = get_variable_from_json('api_token')
                            raw_response = requests.get(url=url, headers={'x-api-key': token})
                            raw_response.raise_for_status()
                            dataset_items = [
                                (Image.open(requests.get(item['preview_url'], stream=True).raw), item['key'])
                                for item in raw_response.json()['data']
                            ]
                            return ds['s3'], ds['description'], dataset_items

                        cloud_dataset_name.select(fn=get_results_from_datasets, inputs=[cloud_dataset_name],
                                                  outputs=[dataset_s3_output, dataset_des_output, dataset_gallery])

    # webui expects an iterable of (blocks, title, elem_id) tuples.
    return (sagemaker_interface, "Amazon SageMaker", "sagemaker_interface"),


script_callbacks.on_after_component(on_after_component_callback)
script_callbacks.on_ui_tabs(on_ui_tabs)

# create new tabs for create Model
origin_callback = script_callbacks.ui_tabs_callback


def avoid_duplicate_from_restart_ui(res):
    """Return True if the Dreambooth UI in *res* already contains our 'Select From Cloud' tab."""
    for extension_ui in res:
        if extension_ui[1] != 'Dreambooth':
            continue
        blocks = extension_ui[0].blocks
        for key in list(blocks):
            val = blocks[key]
            if type(val) is gr.Tab and val.label == 'Select From Cloud':
                return True
    return False


def _labeled_html(caption, **html_kwargs):
    """Create a row with a static caption followed by an HTML value field; return the field."""
    with gr.Row():
        gr.HTML(value=caption)
        return gr.HTML(**html_kwargs)


def _build_cloud_dreambooth_tabs():
    """Build the 'Select From Cloud' / 'Create From Cloud' tabs and wire their events.

    Must be called inside the Dreambooth Blocks and Tabs contexts.
    """
    with gr.Tab('Select From Cloud'):
        with gr.Row():
            cloud_db_model_name = gr.Dropdown(
                label="Model", choices=sorted(get_cloud_db_model_name_list()),
                elem_id="cloud_db_model_name"
            )
            create_refresh_button(
                cloud_db_model_name,
                get_cloud_db_model_name_list,
                lambda: {"choices": sorted(get_cloud_db_model_name_list())},
                "refresh_db_models",
            )
        with gr.Row():
            cloud_db_snapshot = gr.Dropdown(
                label="Cloud Snapshot to Resume",
                choices=sorted(get_cloud_model_snapshots()),
                elem_id="cloud_snapshot_to_resume_dropdown"
            )
            create_refresh_button(
                cloud_db_snapshot,
                get_cloud_model_snapshots,
                lambda: {"choices": sorted(get_cloud_model_snapshots())},
                "refresh_db_snapshots",
            )
        with gr.Row():
            cloud_train_instance_type = gr.Dropdown(
                label="SageMaker Train Instance Type",
                choices=['ml.g4dn.2xlarge', 'ml.g5.2xlarge'],
                elem_id="cloud_train_instance_type",
                info='select SageMaker Train Instance Type'
            )
        with gr.Row(visible=False) as lora_model_row:
            cloud_db_lora_model_name = gr.Dropdown(
                label="Lora Model", choices=get_sorted_lora_cloud_models(),
                elem_id="cloud_lora_model_dropdown"
            )
            create_refresh_button(
                cloud_db_lora_model_name,
                get_sorted_lora_cloud_models,
                lambda: {"choices": get_sorted_lora_cloud_models()},
                "refresh_lora_models",
            )
        cloud_db_model_path = _labeled_html("Loaded Model from Cloud:")
        cloud_db_revision = _labeled_html("Cloud Model Revision:", elem_id="cloud_db_revision")
        cloud_db_epochs = _labeled_html("Cloud Model Epoch:", elem_id="cloud_db_epochs")
        cloud_db_v2 = _labeled_html("V2 Model From Cloud:", elem_id="cloud_db_v2")
        cloud_db_has_ema = _labeled_html("Has EMA:", elem_id="cloud_db_has_ema")
        cloud_db_src = _labeled_html("Source Checkpoint From Cloud:")
        cloud_db_status = _labeled_html("Cloud DB Status:", elem_id="db_status", value="")
        cloud_db_shared_diffusers_path = _labeled_html("Experimental Shared Source:")
        with gr.Row():
            gr.HTML(value="Training Jobs Details:")
        with gr.Row():
            training_job_dashboard = gr.Dataframe(
                headers=["id", "model name", "status", "SageMaker train name"],
                datatype=["str", "str", "str", "str"],
                col_count=(4, "fixed"),
                value=get_train_job_list,
                interactive=False,
                every=3,  # poll the job list every 3 seconds
                elem_id='training_job_dashboard'
            )

    with gr.Tab('Create From Cloud'):
        with gr.Column():
            cloud_db_create_model = gr.Button(
                value="Create Model From Cloud", variant="primary"
            )
            cloud_db_new_model_name = gr.Textbox(label="Name",
                                                 placeholder="Model names can only contain alphanumeric and -")
            with gr.Row():
                cloud_db_create_from_hub = gr.Checkbox(
                    label="Create From Hub", value=False, visible=False
                )
                cloud_db_512_model = gr.Checkbox(label="512x Model", value=True)
            with gr.Column(visible=False) as hub_row:
                cloud_db_new_model_url = gr.Textbox(
                    label="Model Path",
                    placeholder="runwayml/stable-diffusion-v1-5",
                    elem_id="cloud_db_model_path_text_box"
                )
                cloud_db_new_model_token = gr.Textbox(
                    label="HuggingFace Token", value=""
                )
            with gr.Column(visible=True) as local_row:
                with gr.Row():
                    cloud_db_new_model_src = gr.Dropdown(
                        label="Source Checkpoint",
                        choices=sorted(get_sd_cloud_models()),
                        elem_id="cloud_db_source_checkpoint_dropdown"
                    )
                    create_refresh_button(
                        cloud_db_new_model_src,
                        get_sd_cloud_models,
                        lambda: {"choices": sorted(get_sd_cloud_models())},
                        "refresh_sd_models",
                    )
            with gr.Column(visible=False) as shared_row:
                with gr.Row():
                    cloud_db_new_model_shared_src = gr.Dropdown(
                        label="EXPERIMENTAL: LoRA Shared Diffusers Source",
                        choices=[],
                        value=""
                    )
            cloud_db_new_model_extract_ema = gr.Checkbox(
                label="Extract EMA Weights", value=False
            )
            cloud_db_train_unfrozen = gr.Checkbox(label="Unfreeze Model", value=False,
                                                  elem_id="cloud_db_unfreeze_model_checkbox")
        with gr.Row():
            gr.HTML(value="Model Creation Jobs Details:")
        with gr.Row():
            createmodel_dashboard = gr.Dataframe(
                headers=["id", "model name", "status"],
                datatype=["str", "str", "str"],
                col_count=(3, "fixed"),
                value=get_create_model_job_list,
                interactive=False,
                every=3
            )

    def toggle_new_rows(create_from):
        """Show the hub inputs and hide the local-checkpoint inputs when creating from hub (and vice versa)."""
        return gr.update(visible=create_from), gr.update(visible=not create_from)

    cloud_db_create_from_hub.change(
        fn=toggle_new_rows,
        inputs=[cloud_db_create_from_hub],
        outputs=[hub_row, local_row],
    )

    cloud_db_model_name.change(
        _js="clear_loaded",
        fn=wrap_load_model_params,
        inputs=[cloud_db_model_name],
        outputs=[
            cloud_db_model_path,
            cloud_db_revision,
            cloud_db_epochs,
            cloud_db_v2,
            cloud_db_has_ema,
            cloud_db_src,
            cloud_db_shared_diffusers_path,
            cloud_db_snapshot,
            cloud_db_lora_model_name,
            cloud_db_status,
        ],
    )

    cloud_db_create_model.click(
        fn=cloud_create_model,
        _js="check_create_model_params",
        inputs=[
            cloud_db_new_model_name,
            cloud_db_new_model_src,
            cloud_db_new_model_shared_src,
            cloud_db_create_from_hub,
            cloud_db_new_model_url,
            cloud_db_new_model_token,
            cloud_db_new_model_extract_ema,
            cloud_db_train_unfrozen,
            cloud_db_512_model,
        ],
        outputs=[createmodel_dashboard]
    )


def ui_tabs_callback():
    """Wrap webui's ui_tabs_callback to inject our cloud tabs next to Dreambooth's 'Select' tab.

    Skips injection when the tabs already exist (e.g. after a UI reload).
    """
    res = origin_callback()
    if avoid_duplicate_from_restart_ui(res):
        return res
    for extension_ui in res:
        if extension_ui[1] != 'Dreambooth':
            continue
        # Snapshot the keys: building tabs below adds new blocks to the dict.
        for key in list(extension_ui[0].blocks):
            val = extension_ui[0].blocks[key]
            if type(val) is gr.Tab and val.label == 'Select':
                with extension_ui[0]:
                    with val.parent:
                        _build_cloud_dreambooth_tabs()
                break
    return res


script_callbacks.ui_tabs_callback = ui_tabs_callback


def get_sorted_lora_cloud_models():
    """Placeholder: cloud LoRA model listing is not implemented yet; returns []."""
    return []


def get_cloud_model_snapshots():
    """Placeholder: cloud snapshot listing is not implemented yet; returns []."""
    return []