import copy
import itertools
import os
from pathlib import Path
import html
import boto3
import time
import json
import requests
import base64
from urllib.parse import urljoin
import gradio as gr
from modules import shared, scripts
from modules.ui import create_refresh_button
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
from utils import get_variable_from_json
from utils import upload_file_to_s3_by_presign_url, upload_multipart_files_to_s3_by_signed_url
from requests.exceptions import JSONDecodeError
from datetime import datetime
import math
# Gradio components; create_ui() assigns these through `global` declarations.
inference_job_dropdown = None
sagemaker_endpoint = None
primary_model_name = None
secondary_model_name = None
tertiary_model_name = None

# TODO: convert to dynamically init the following variables
sagemaker_endpoints = []
txt2img_inference_job_ids = []
sd_checkpoints = []
textual_inversion_list = []
lora_list = []
hyperNetwork_list = []
ControlNet_model_list = []

# Checkpoint type names as sent to the cloud API, paired index-by-index with
# lower-case local identifiers (the two lists are zipped together).
checkpoint_type = ["Stable-diffusion", "embeddings", "Lora", "hypernetworks", "ControlNet"]
checkpoint_name = ["stable_diffusion", "embeddings", "lora", "hypernetworks", "controlnet"]
stable_diffusion_list = []
embeddings_list = []
hypernetworks_list = []
controlnet_list = []

# Initial checkpoints information: {checkpoint type: {model name: S3 location}}.
checkpoint_info = {ckpt_type: {} for ckpt_type in checkpoint_type}
# Gateway URL and API token read once at import time via utils.get_variable_from_json.
# Most functions below re-read both values on every call instead of using these.
api_gateway_url, api_key = (
    get_variable_from_json(config_key) for config_key in ('api_gateway_url', 'api_token')
)
def plaintext_to_html(text):
    """Render plain text as one HTML paragraph for the Gradio HTML outputs.

    Each line is HTML-escaped and lines are joined with ``<br>`` so line
    breaks are preserved when rendered.

    Args:
        text: Plain text, possibly containing newlines.

    Returns:
        str: ``"<p>...</p>"`` markup.
    """
    escaped_lines = [html.escape(line) for line in text.split('\n')]
    return "<p>" + "<br>\n".join(escaped_lines) + "</p>"
def get_s3_file_names(bucket, folder):
    """Return the keys of all objects in S3 ``bucket`` whose key starts with ``folder``.

    Uses boto3's default session and credentials.
    """
    matching_objects = boto3.resource('s3').Bucket(bucket).objects.filter(Prefix=folder)
    return [s3_object.key for s3_object in matching_objects]
def get_current_date():
    """Return today's local date as a ``YYYY-MM-DD`` string."""
    return datetime.today().strftime('%Y-%m-%d')
def server_request(path):
    """Send an authenticated GET to ``<api_gateway_url>/<path>``.

    The gateway URL and API token are re-read from the local config on every
    call rather than taken from the module-level copies.

    Args:
        path: Path (and optional query string) relative to the gateway root.

    Returns:
        requests.Response: the raw response; callers decode it themselves.
    """
    base_url = get_variable_from_json('api_gateway_url')
    if not base_url.endswith('/'):
        base_url = base_url + '/'
    request_headers = {
        "x-api-key": get_variable_from_json('api_token'),
        "Content-Type": "application/json",
    }
    return requests.get(f'{base_url}{path}', headers=request_headers)
def datetime_to_short_form(datetime_str):
    """Turn ``'YYYY-MM-DD HH:MM:SS.ffffff'`` into ``'YYYY-MM-DD-HH-MM-SS'``.

    Fractional seconds are discarded. Raises ValueError on any other format.
    """
    return datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S.%f").strftime("%Y-%m-%d-%H-%M-%S")
def update_sagemaker_endpoints():
    """Refresh the module-level ``sagemaker_endpoints`` list from the cloud API.

    Each entry is ``"<name>+<status>+<end time>"``; the end time is ``"N/A"``
    when the job has no ``endTime``. Endpoints whose status is ``'deleted'``
    are skipped. Entries are sorted by the end-time field, in descending
    string order.

    The list is updated in place, so code that holds a reference to it sees
    the new values. Any error is logged and leaves the list empty.
    """
    sagemaker_endpoints.clear()
    try:
        response = server_request('inference/list-endpoint-deployment-jobs')
        r = response.json()
        if not r:
            print("The API response is empty for update_sagemaker_endpoints().")
            return
        sagemaker_raw_endpoints = []
        for obj in r:
            if "EndpointDeploymentJobId" not in obj:
                continue
            if "endpoint_name" in obj:
                endpoint_name = obj["endpoint_name"]
                endpoint_status = obj.get("endpoint_status")
            else:
                endpoint_name = obj["EndpointDeploymentJobId"]
                endpoint_status = obj.get("status")
            # Skip if status is 'deleted'
            if endpoint_status == 'deleted':
                continue
            endpoint_time = obj.get("endTime", "N/A")
            sagemaker_raw_endpoints.append(f"{endpoint_name}+{endpoint_status}+{endpoint_time}")
        # Sort on the last '+'-separated field (end time), newest first.
        sagemaker_raw_endpoints.sort(key=lambda x: x.split('+')[-1], reverse=True)
        sagemaker_endpoints.extend(sagemaker_raw_endpoints)
    except Exception as e:
        print(f"An error occurred while updating SageMaker endpoints: {e}")
def update_txt2img_inference_job_ids():
    """Refresh callback for the inference-job dropdown.

    Delegates to get_inference_job_list(), which repopulates
    ``txt2img_inference_job_ids`` in place.
    """
    get_inference_job_list()
def origin_update_txt2img_inference_job_ids():
    """Unused placeholder: performs no work and returns None."""
    return None
def get_inference_job_list():
    """Repopulate ``txt2img_inference_job_ids`` from the cloud API.

    Each entry is ``"<time>--><status>--><InferenceJobId>"``. ``<time>`` is
    the job's ``completeTime``, or ``startTime`` when ``completeTime`` is
    missing. Entries are sorted by that time, newest first; jobs with neither
    timestamp go last. The list is updated in place. On error it is left
    empty and the exception is logged.
    """
    txt2img_inference_job_ids.clear()
    try:
        response = server_request('inference/list-inference-jobs')
        r = response.json()
        if not r:
            print("The API response is empty.")
            return
        temp_list = []
        for obj in r:
            complete_time = obj.get('completeTime')
            if complete_time is None:
                complete_time = obj.get('startTime')
            status = obj.get('status')
            inference_job_id = obj.get('InferenceJobId')
            temp_list.append((complete_time, f"{complete_time}-->{status}-->{inference_job_id}"))
        # Comparing None with a timestamp would raise TypeError and lose the
        # whole list. Rank missing timestamps below all real ones instead.
        temp_list.sort(
            key=lambda item: (item[0] is not None, item[0] if item[0] is not None else ''),
            reverse=True,
        )
        txt2img_inference_job_ids.extend(combined for _, combined in temp_list)
    except Exception as e:
        print(f"Exception occurred when fetching inference_job_ids: {e}")
def get_inference_job(inference_job_id):
    """Fetch one inference job's record and return the decoded JSON body."""
    job_response = server_request(f'inference/get-inference-job?jobID={inference_job_id}')
    print("response of get_inference_job is " + str(job_response))
    return job_response.json()
def get_inference_job_image_output(inference_job_id):
    """Return the job's image outputs as strings, or [] on any error.

    Callers pass the result to download_images(), i.e. treat entries as URLs.
    """
    try:
        payload = server_request(f'inference/get-inference-job-image-output?jobID={inference_job_id}').json()
        return [str(entry) for entry in payload]
    except Exception as e:
        print(f"An error occurred while getting inference job image output: {e}")
        return []
def get_inference_job_param_output(inference_job_id):
    """Return the job's parameter outputs as strings, or [] on any error.

    Callers pass the result to download_images(), i.e. treat entries as URLs.
    """
    try:
        payload = server_request(f'inference/get-inference-job-param-output?jobID={inference_job_id}').json()
        return [str(entry) for entry in payload]
    except Exception as e:
        print(f"An error occurred while getting inference job param output: {e}")
        return []
def download_images(image_urls: list, local_directory: str):
    """Download each URL into ``local_directory`` and return the saved file paths.

    The local file name is the last segment of the URL's path; the query
    string and fragment are ignored. URLs that fail to download are logged
    and skipped, so the result may be shorter than ``image_urls``.

    Args:
        image_urls: URLs to fetch.
        local_directory: Destination directory; created if missing.

    Returns:
        list[str]: Local paths of the files written, in input order.
    """
    from urllib.parse import urlsplit

    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(local_directory, exist_ok=True)
    image_list = []
    for url in image_urls:
        try:
            response = requests.get(url)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Error downloading image {url}: {e}")
            continue
        # Take the name from the URL path only: a '/' inside the query string
        # (e.g. a presigned-URL parameter) must not affect the file name.
        image_name = os.path.basename(urlsplit(url).path)
        local_path = os.path.join(local_directory, image_name)
        with open(local_path, 'wb') as f:
            f.write(response.content)
        image_list.append(local_path)
    return image_list
def get_model_list_by_type(model_type):
    """Fetch the names of active cloud checkpoints of the given type(s).

    As a side effect, stores each checkpoint's S3 location in
    ``checkpoint_info[<type>][<name>]``.

    Args:
        model_type: One checkpoint type string (e.g. ``"Lora"``) or a list of them.

    Returns:
        list[str]: Checkpoint names. Returns ``[]`` if the API URL/key is not
        configured or the request fails.
    """
    api_gateway_url = get_variable_from_json('api_gateway_url')
    api_key = get_variable_from_json('api_token')
    # check if api_gateway_url and api_key are set
    if api_gateway_url is None or api_key is None:
        print("api_gateway_url or api_key is not set")
        return []
    if not api_gateway_url.endswith('/'):
        api_gateway_url += '/'

    model_types = model_type if isinstance(model_type, list) else [model_type]
    url = api_gateway_url + "checkpoints?status=Active" + "".join(f"&types={t}" for t in model_types)
    try:
        response = requests.get(url=url, headers={'x-api-key': api_key})
        response.raise_for_status()
        json_response = response.json()
        checkpoint_list = []
        for ckpt in json_response.get("checkpoints", []):
            names = ckpt.get("name")
            if names is None:
                continue
            # setdefault: a type missing from checkpoint_info would otherwise
            # raise KeyError and discard the entire result.
            type_info = checkpoint_info.setdefault(ckpt["type"], {})
            for ckpt_name in names:
                type_info[ckpt_name] = f"{ckpt['s3Location']}/{ckpt_name}"
                checkpoint_list.append(ckpt_name)
        return checkpoint_list
    except Exception as e:
        print(f"Error fetching model list: {e}")
        return []
def update_sd_checkpoints():
    """Return active Stable Diffusion checkpoint names from the cloud."""
    return get_model_list_by_type(["Stable-diffusion"])
def get_texual_inversion_list():
    """Return active textual-inversion (embeddings) names from the cloud."""
    return get_model_list_by_type("embeddings")
def get_lora_list():
    """Return active LoRA model names from the cloud."""
    return get_model_list_by_type("Lora")
def get_hypernetwork_list():
    """Return active hypernetwork names from the cloud."""
    return get_model_list_by_type("hypernetworks")
def get_controlnet_model_list():
    """Return active ControlNet model names from the cloud."""
    return get_model_list_by_type("ControlNet")
def refresh_all_models():
    """Rebuild ``checkpoint_info`` for every type listed in ``checkpoint_type``.

    For each type, queries the active checkpoints and maps every checkpoint
    name to ``<s3Location>/<last path segment of the name>``. Errors are
    logged, not raised. Types processed before the error keep their
    refreshed entries.
    """
    api_gateway_url = get_variable_from_json('api_gateway_url')
    api_key = get_variable_from_json('api_token')
    if api_gateway_url is None or api_key is None:
        print("api_gateway_url or api_key is not set")
        return
    # Same normalisation as the other request helpers in this module.
    if not api_gateway_url.endswith('/'):
        api_gateway_url += '/'
    try:
        for rp in checkpoint_type:
            url = api_gateway_url + f"checkpoints?status=Active&types={rp}"
            response = requests.get(url=url, headers={'x-api-key': api_key})
            json_response = response.json()
            checkpoint_info[rp] = {}
            if "checkpoints" not in json_response:
                continue
            for ckpt in json_response["checkpoints"]:
                if ckpt.get("name") is None:
                    continue
                # Accumulate into the type's dict (already reset above).
                # Resetting it per checkpoint would keep only the last
                # checkpoint's names.
                type_info = checkpoint_info.setdefault(ckpt["type"], {})
                for ckpt_name in ckpt["name"]:
                    type_info[ckpt_name] = f"{ckpt['s3Location']}/{ckpt_name.split('/')[-1]}"
    except Exception as e:
        print(f"Error refresh all models: {e}")
def _pack_model_tar(lp, rp, model_name, local_tar_path):
    """Copy a local model (plus its .yaml config for SD checkpoints) into the
    repo's model folder and pack the copy into ``local_tar_path``."""
    import subprocess

    if rp != "embeddings":
        local_model_path_in_repo = f'models/{rp}/{model_name}'
    else:
        local_model_path_in_repo = f'{rp}/{model_name}'
    # Argument lists (no shell) pass paths with spaces or shell
    # metacharacters through verbatim; check=True stops the upload if cp or
    # tar fails instead of uploading a missing/partial archive.
    subprocess.run(["cp", "-f", lp, local_model_path_in_repo], check=True)
    tar_members = [local_model_path_in_repo]
    if rp == "Stable-diffusion":
        # Strip only the final extension, so "a.b.ckpt" pairs with "a.b.yaml".
        model_yaml_name = os.path.splitext(model_name)[0] + ".yaml"
        # os.path.join handles a bare filename (no directory part) correctly.
        local_model_yaml_path = os.path.join(os.path.dirname(lp), model_yaml_name)
        local_model_yaml_path_in_repo = f"models/{rp}/{model_yaml_name}"
        if os.path.isfile(local_model_yaml_path):
            subprocess.run(["cp", "-f", local_model_yaml_path, local_model_yaml_path_in_repo], check=True)
            tar_members.append(local_model_yaml_path_in_repo)
    subprocess.run(["tar", "cvf", local_tar_path, *tar_members], check=True)


def _upload_one_model(lp, rp, model_name, part_size):
    """Register, pack, multipart-upload and activate one model; return a log line."""
    parts_number = math.ceil(os.stat(lp).st_size / part_size)
    # The archive file name is also the key the API uses for the presigned URLs.
    local_tar_path = model_name
    payload = {
        "checkpoint_type": rp,
        "filenames": [{
            "filename": local_tar_path,
            "parts_number": parts_number,
        }],
        "params": {"message": "placeholder for chkpts upload test"},
    }
    api_gateway_url = get_variable_from_json('api_gateway_url')
    # Check if api_url ends with '/', if not append it
    if not api_gateway_url.endswith('/'):
        api_gateway_url += '/'
    api_key = get_variable_from_json('api_token')
    url = str(api_gateway_url) + "checkpoint"
    print(f"Post request for upload s3 presign url: {url}")
    response = requests.post(url=url, json=payload, headers={'x-api-key': api_key})
    json_response = response.json()
    s3_base = json_response["checkpoint"]["s3_location"]
    checkpoint_id = json_response["checkpoint"]["id"]
    print(f"Upload to S3 {s3_base}")
    print(f"Checkpoint ID: {checkpoint_id}")
    s3_signed_urls_resp = json_response["s3PresignUrl"][local_tar_path]

    packed = False
    try:
        print("Pack the model file.")
        _pack_model_tar(lp, rp, model_name, local_tar_path)
        packed = True
        multiparts_tags = upload_multipart_files_to_s3_by_signed_url(
            local_tar_path,
            s3_signed_urls_resp,
            part_size,
        )
    finally:
        # Remove the temporary archive whether or not the upload succeeded.
        if packed and os.path.exists(local_tar_path):
            os.remove(local_tar_path)

    # Mark the checkpoint Active so the cloud side starts using it.
    payload = {
        "checkpoint_id": checkpoint_id,
        "status": "Active",
        "multi_parts_tags": {local_tar_path: multiparts_tags},
    }
    response = requests.put(url=url, json=payload, headers={'x-api-key': api_key})
    print(response)
    response.raise_for_status()
    return f"finish upload {local_tar_path} to {s3_base}"


def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path, controlnet_model_path):
    """Upload local model files to S3 through the API's multipart presigned URLs.

    Each non-empty path is matched positionally to ``checkpoint_type``. A model
    whose file name is already registered for that type is skipped. Each
    model is uploaded independently, and a failure in one does not stop the
    others.

    Returns:
        tuple: (HTML log of every model's outcome, None, None, None, None, None).
        Presumably the Nones clear the five path inputs in the UI; confirm
        with the caller.
    """
    log_lines = ["start upload model to s3..."]
    local_paths = [sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path, controlnet_model_path]
    print(f"Refresh checkpionts before upload to get rid of duplicate uploads...")
    refresh_all_models()
    part_size = 1000 * 1024 * 1024
    for lp, rp in zip(local_paths, checkpoint_type):
        if not lp:
            continue
        print(f"lp is {lp}")
        model_name = os.path.basename(lp)
        if model_name in checkpoint_info.get(rp, {}):
            print(f"!!!skip to upload duplicate model {model_name}")
            continue
        try:
            log_lines.append(_upload_one_model(lp, rp, model_name, part_size))
        except Exception as e:
            message = f"fail to upload model {lp}, error: {e}"
            print(message)
            log_lines.append(message)
    print(f"Refresh checkpionts after upload...")
    refresh_all_models()
    return plaintext_to_html("\n".join(log_lines)), None, None, None, None, None
def generate_on_cloud(sagemaker_endpoint):
    """Placeholder handler: logs current state and always reports a failure message."""
    for message in (f"checkpiont_info {checkpoint_info}", f"sagemaker endpoint {sagemaker_endpoint}"):
        print(message)
    return plaintext_to_html("failed to check endpoint")
def call_txt2img_inference(sagemaker_endpoint):
    """Submit a txt2img inference request to ``sagemaker_endpoint``."""
    return call_remote_inference(sagemaker_endpoint, type='txt2img')
def call_img2img_inference(sagemaker_endpoint, **args):
    """Submit an img2img inference request; extra keyword args are ignored."""
    return call_remote_inference(sagemaker_endpoint, type='img2img')
def call_interrogate_clip(sagemaker_endpoint, **args):
    """Submit a CLIP interrogation request; extra keyword args are ignored."""
    return call_remote_inference(sagemaker_endpoint, type='interrogate_clip')
def call_interrogate_deepbooru(sagemaker_endpoint, **args):
    """Submit a DeepBooru interrogation request; extra keyword args are ignored."""
    return call_remote_inference(sagemaker_endpoint, type='interrogate_deepbooru')
def call_remote_inference(sagemaker_endpoint, type):
    """Submit an asynchronous inference request to the selected SageMaker endpoint.

    Args:
        sagemaker_endpoint: Dropdown value ``"<name>+<status>+<time>"`` (as
            built by update_sagemaker_endpoints). ``''`` or ``None`` means
            nothing is selected; ``'FAILURE'`` means the config upload failed.
        type: Task type sent to the server, e.g. ``'txt2img'``.

    Returns:
        tuple: (image list, info text, HTML message). The image list is
        always empty here: the job is only submitted, and its id is reported
        so the user can check it in the inference-job dropdown.

    Side effect: writes ``'sagemaker_endpoint'`` and ``'task_type'`` into the
    module-level ``checkpoint_info`` dict, which also serves as the request
    payload.
    """
    print(f"chosen ep {sagemaker_endpoint}")
    print(f"inference type is {type}")

    def _message_only(infotexts):
        # Every return path yields the same 3-tuple shape.
        return [], '', plaintext_to_html(infotexts)

    if not sagemaker_endpoint:
        return _message_only("Failed! Please choose the endpoint in 'InService' states ")
    if sagemaker_endpoint == 'FAILURE':
        return _message_only("Failed upload the config to cloud ")
    endpoint_parts = sagemaker_endpoint.split("+")
    if len(endpoint_parts) < 2 or endpoint_parts[1] != "InService":
        return _message_only("Failed! Please choose the endpoint in 'InService' states ")

    api_gateway_url = get_variable_from_json('api_gateway_url')
    # Check if api_url ends with '/', if not append it
    if not api_gateway_url.endswith('/'):
        api_gateway_url += '/'
    api_key = get_variable_from_json('api_token')
    headers = {
        "x-api-key": api_key,
        "Content-Type": "application/json",
    }

    # NOTE: payload is the global checkpoint_info itself (not a copy), so these
    # keys remain on the module-level dict after the call.
    checkpoint_info['sagemaker_endpoint'] = endpoint_parts[0]
    payload = checkpoint_info
    payload['task_type'] = type
    print(f"checkpointinfo is {payload}")

    inference_url = f"{api_gateway_url}inference/run-sagemaker-inference"
    response = requests.post(inference_url, json=payload, headers=headers)
    print(f"Raw server response: {response.text}")
    try:
        r = response.json()
    except JSONDecodeError as e:
        print(f"Failed to decode JSON response: {e}")
        print(f"Raw server response: {response.text}")
        return _message_only(f"Failed to decode the inference response: {e}")

    print(f"response for rest api {r}")
    inference_id = r.get('inference_id')  # Assuming the response contains 'inference_id' field
    print(f"inference_id is {inference_id}")
    # TODO: polling get_inference_job() until the job succeeds/fails was
    # disabled because it blocks the user from clicking inference again.
    return _message_only(f"Inference id is {inference_id}, please go to inference job Id dropdown to check the status")
def sagemaker_endpoint_delete(delete_endpoint_list):
    """Ask the API to delete the selected SageMaker endpoints.

    Args:
        delete_endpoint_list: Dropdown entries ``"<name>+<status>+<time>"``;
            only ``<name>`` is sent. ``None`` is treated as an empty selection.

    Returns:
        str: A status message (also returned when the API is not configured).
    """
    print(f"start delete sagemaker endpoint delete function")
    print(f"delete endpoint list: {delete_endpoint_list}")
    api_gateway_url = get_variable_from_json('api_gateway_url')
    api_key = get_variable_from_json('api_token')
    delete_endpoint_list = [item.split('+')[0] for item in (delete_endpoint_list or [])]
    print(f"delete endpoint list: {delete_endpoint_list}")
    # check if api_gateway_url and api_key are set
    if api_gateway_url is None or api_key is None:
        print("api_gateway_url and api_key are not set")
        # Return a message, like every other path, instead of None.
        return "api_gateway_url and api_key are not set"
    # Check if api_url ends with '/', if not append it
    if not api_gateway_url.endswith('/'):
        api_gateway_url += '/'
    payload = {
        "delete_endpoint_list": delete_endpoint_list,
    }
    deployment_url = f"{api_gateway_url}inference/delete-sagemaker-endpoint"
    headers = {
        "x-api-key": api_key,
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(deployment_url, json=payload, headers=headers)
        # An HTTP error status is a failure, not a completed delete.
        response.raise_for_status()
        r = response.json()
        print(f"response for rest api {r}")
        return "Endpoint delete completed"
    except Exception as e:
        return f"Failed to delete sagemaker endpoint with exception: {e}"
def sagemaker_deploy(instance_type, initial_instance_count=1):
    """ Create SageMaker endpoint for GPU inference.

    Args:
        instance_type (string): the ML compute instance type.
        initial_instance_count (integer): Number of instances to launch initially.

    Returns:
        str: A status message. Also returned when the API gateway URL or
        token is not configured.
    """
    # function code to call sagemaker deploy api
    print(f"start deploying instance type: {instance_type} with count {initial_instance_count}............")
    api_gateway_url = get_variable_from_json('api_gateway_url')
    api_key = get_variable_from_json('api_token')
    # check if api_gateway_url and api_key are set
    if api_gateway_url is None or api_key is None:
        print("api_gateway_url and api_key are not set")
        # Return a message, like every other path, instead of None.
        return "api_gateway_url and api_key are not set"
    # Check if api_url ends with '/', if not append it
    if not api_gateway_url.endswith('/'):
        api_gateway_url += '/'
    payload = {
        "instance_type": instance_type,
        "initial_instance_count": initial_instance_count,
    }
    deployment_url = f"{api_gateway_url}inference/deploy-sagemaker-endpoint"
    headers = {
        "x-api-key": api_key,
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(deployment_url, json=payload, headers=headers)
        # An HTTP error status is a failure, not a started deployment.
        response.raise_for_status()
        r = response.json()
        print(f"response for rest api {r}")
        return "Endpoint deployment started"
    except Exception as e:
        return f"Failed to start endpoint deployment with exception: {e}"
def modelmerger_on_cloud_func(primary_model_name, secondary_model_name, teritary_model_name):
    """Submit a checkpoint-merge request for up to three cloud models (under development).

    Args:
        primary_model_name: Model A name.
        secondary_model_name: Model B name.
        teritary_model_name: Model C name.

    Returns:
        ``[]`` if the API gateway URL is not configured; otherwise None. The
        server response is only logged.
    """
    print(f"function under development, current checkpoint_info is {checkpoint_info}")
    api_gateway_url = get_variable_from_json('api_gateway_url')
    # Check for a missing URL before calling string methods on it.
    if api_gateway_url is None:
        print(f"modelmerger: failed to get the api-gateway url, can not fetch remote data")
        return []
    # Check if api_url ends with '/', if not append it
    if not api_gateway_url.endswith('/'):
        api_gateway_url += '/'
    api_key = get_variable_from_json('api_token')
    modelmerge_url = f"{api_gateway_url}inference/run-model-merge"
    payload = {
        "primary_model_name": primary_model_name,
        "secondary_model_name": secondary_model_name,
        "tertiary_model_name": teritary_model_name,
    }
    headers = {
        "x-api-key": api_key,
        "Content-Type": "application/json",
    }
    response = requests.post(modelmerge_url, json=payload, headers=headers)
    try:
        r = response.json()
    except JSONDecodeError as e:
        print(f"Failed to decode JSON response: {e}")
        print(f"Raw server response: {response.text}")
    else:
        print(f"response for rest api {r}")
def txt2img_config_save():
    """Placeholder for persisting the txt2img config; currently a no-op returning None."""
def displayEndpointInfo(input_string: str):
    """Show the error of a failed endpoint-deployment job as HTML.

    Args:
        input_string: Dropdown value ``"<job id>+<status>+<time>"``, or None
            when the selection is cleared.

    Returns:
        str: HTML containing the job's ``error`` if the status is
        ``'failed'``; otherwise an empty paragraph.
    """
    print(f"selected value is {input_string}")
    parts = (input_string or '').split('+')
    if len(parts) < 2:
        return plaintext_to_html("")
    endpoint_job_id, status = parts[0], parts[1]
    if status != 'failed':
        return plaintext_to_html("")
    response = server_request(f'inference/get-endpoint-deployment-job?jobID={endpoint_job_id}')
    r = response.json()
    if "error" in r:
        # str(): plaintext_to_html needs text, and the error may be structured.
        return plaintext_to_html(str(r["error"]))
    return plaintext_to_html(str(r.get("EndpointDeploymentJobId", endpoint_job_id)))
def fake_gan(selected_value: str ):
    """Download and display the outputs of the inference job chosen in the dropdown.

    Args:
        selected_value: Dropdown entry ``"<time>--><status>--><InferenceJobId>"``
            (as built by get_inference_job_list), or None.

    Returns:
        tuple: (local image paths, the ``info`` JSON string from the job's
        ``<id>_param.json``, HTML of the first entry of its ``infotexts``).
        Empty results are returned when nothing usable is selected.
    """
    print(f"selected value is {selected_value}")
    if selected_value is None:
        return [], '', plaintext_to_html('')
    parts = selected_value.split("-->")
    if len(parts) < 3:
        return [], '', plaintext_to_html('')
    # Extract the InferenceJobId value
    inference_job_id = parts[2].strip()
    inference_job_status = parts[1].strip()
    if inference_job_status == 'inprogress':
        return [], [], plaintext_to_html('inference still in progress')

    # Compute the directory once so the downloads and the param-JSON read
    # agree on the date even if the call straddles midnight.
    output_dir = f"outputs/txt2img-images/{get_current_date()}/{inference_job_id}/"
    images = get_inference_job_image_output(inference_job_id)
    image_list = download_images(images, output_dir)
    inference_pram_json_list = get_inference_job_param_output(inference_job_id)
    # Downloads <id>_param.json, which is read below.
    download_images(inference_pram_json_list, output_dir)
    print(f"{str(images)}")
    print(f"{str(inference_pram_json_list)}")
    json_file = f"{output_dir}{inference_job_id}_param.json"
    with open(json_file) as f:
        log_file = json.load(f)
    info_text = log_file["info"]
    infotexts = json.loads(info_text)["infotexts"][0]
    return image_list, info_text, plaintext_to_html(infotexts)
def display_inference_result(inference_id: str ):
    """Download and display the outputs of the given inference job.

    Args:
        inference_id: The InferenceJobId, or None.

    Returns:
        tuple: (local image paths, the ``info`` JSON string from the job's
        ``<id>_param.json``, HTML of the first entry of its ``infotexts``).
        Empty results are returned for None.
    """
    print(f"selected value is {inference_id}")
    if inference_id is None:
        return [], '', plaintext_to_html('')
    inference_job_id = inference_id
    # Compute the directory once so the downloads and the param-JSON read
    # agree on the date even if the call straddles midnight.
    output_dir = f"outputs/txt2img-images/{get_current_date()}/{inference_job_id}/"
    images = get_inference_job_image_output(inference_job_id)
    image_list = download_images(images, output_dir)
    inference_pram_json_list = get_inference_job_param_output(inference_job_id)
    # Downloads <id>_param.json, which is read below.
    download_images(inference_pram_json_list, output_dir)
    print(f"{str(images)}")
    print(f"{str(inference_pram_json_list)}")
    json_file = f"{output_dir}{inference_job_id}_param.json"
    with open(json_file) as f:
        log_file = json.load(f)
    info_text = log_file["info"]
    infotexts = json.loads(info_text)["infotexts"][0]
    return image_list, info_text, plaintext_to_html(infotexts)
def init_refresh_resource_list_from_cloud():
    """Populate the endpoint, model and inference-job caches from the cloud API.

    Only logs a message when no API gateway URL is configured.
    """
    print("start refreshing resource list from cloud")
    if get_variable_from_json('api_gateway_url') is None:
        print("there is no api-gateway url and token in local file,")
        return
    # Called for their side effects on the module-level caches; results are discarded.
    refreshers = (
        update_sagemaker_endpoints,
        refresh_all_models,
        get_texual_inversion_list,
        get_lora_list,
        get_hypernetwork_list,
        get_controlnet_model_list,
        get_inference_job_list,
    )
    for refresh in refreshers:
        refresh()
def create_ui(is_img2img):
    """Build the SageMaker inference panel and the (hidden) cloud model-merge panel.

    Refreshes the cloud resource caches first so the dropdowns start
    populated. Several components are also published as module-level
    globals via ``global`` declarations.

    Args:
        is_img2img: If true, build the img2img buttons (interrogate CLIP,
            interrogate DeepBooru, generate img2img); otherwise build the
            txt2img "Generate on Cloud" button.

    Returns:
        tuple: (sagemaker_endpoint, sd_checkpoint, sd_checkpoint_refresh_button,
        textual_inversion_dropdown, lora_dropdown, hyperNetwork_dropdown,
        controlnet_dropdown, inference_job_dropdown,
        txt2img_inference_job_ids_refresh_button, primary_model_name,
        secondary_model_name, tertiary_model_name, modelmerger_merge_on_cloud).
    """
    global txt2img_gallery, txt2img_generation_info
    import modules.ui

    init_refresh_resource_list_from_cloud()

    with gr.Group():
        with gr.Accordion("Amazon SageMaker Inference", open=False):
            sagemaker_html_log = gr.HTML(elem_id=f'html_log_sagemaker')
            with gr.Column(variant='panel'):
                with gr.Row():
                    global sagemaker_endpoint
                    sagemaker_endpoint = gr.Dropdown(sagemaker_endpoints,
                                                     label="Select Cloud SageMaker Endpoint",
                                                     elem_id="sagemaker_endpoint_dropdown"
                                                     )
                    modules.ui.create_refresh_button(sagemaker_endpoint, update_sagemaker_endpoints, lambda: {"choices": sagemaker_endpoints}, "refresh_sagemaker_endpoints")
                with gr.Row():
                    sd_checkpoint = gr.Dropdown(multiselect=True, label="Stable Diffusion Checkpoint", choices=sorted(update_sd_checkpoints()), elem_id="stable_diffusion_checkpoint_dropdown")
                    sd_checkpoint_refresh_button = modules.ui.create_refresh_button(sd_checkpoint, update_sd_checkpoints, lambda: {"choices": sorted(update_sd_checkpoints())}, "refresh_sd_checkpoints")
            with gr.Column():
                global generate_on_cloud_button_with_js
                if not is_img2img:
                    generate_on_cloud_button_with_js = gr.Button(value="Generate on Cloud", variant='primary', elem_id="generate_on_cloud_with_cloud_config_button", queue=True, show_progress=True)
                global generate_on_cloud_button_with_js_img2img
                global interrogate_clip_on_cloud_button
                global interrogate_deep_booru_on_cloud_button
                if is_img2img:
                    with gr.Row():
                        with gr.Column():
                            interrogate_clip_on_cloud_button = gr.Button(value="Interrogate CLIP", elem_id="interrogate_clip_on_cloud_button")
                        with gr.Column():
                            interrogate_deep_booru_on_cloud_button = gr.Button(value="Interrogte DeepBooru", elem_id="interrogate_deep_booru_on_cloud_button")
                        with gr.Column():
                            generate_on_cloud_button_with_js_img2img = gr.Button(value="Generate on Cloud img2img", variant='primary', elem_id="generate_on_cloud_with_cloud_config_button_img2img", queue=True, show_progress=True)
            with gr.Row():
                global inference_job_dropdown
                global txt2img_inference_job_ids
                inference_job_dropdown = gr.Dropdown(txt2img_inference_job_ids,
                                                     label="Inference Job IDs",
                                                     elem_id="txt2img_inference_job_ids_dropdown"
                                                     )
                txt2img_inference_job_ids_refresh_button = modules.ui.create_refresh_button(inference_job_dropdown, update_txt2img_inference_job_ids, lambda: {"choices": txt2img_inference_job_ids}, "refresh_txt2img_inference_job_ids")
            with gr.Row():
                gr.HTML(value="Extra Networks for Cloud Inference")
            with gr.Row():
                textual_inversion_dropdown = gr.Dropdown(multiselect=True, label="Textual Inversion", choices=sorted(get_texual_inversion_list()), elem_id="sagemaker_texual_inversion_dropdown")
                create_refresh_button(
                    textual_inversion_dropdown,
                    get_texual_inversion_list,
                    lambda: {"choices": sorted(get_texual_inversion_list())},
                    "refresh_textual_inversion",
                )
                # Seed from the cloud like the sibling dropdowns; the
                # module-level lora_list is never populated.
                lora_dropdown = gr.Dropdown(multiselect=True, label="LoRA", choices=sorted(get_lora_list()), elem_id="sagemaker_lora_list_dropdown")
                create_refresh_button(
                    lora_dropdown,
                    get_lora_list,
                    lambda: {"choices": sorted(get_lora_list())},
                    "refresh_lora",
                )
            with gr.Row():
                hyperNetwork_dropdown = gr.Dropdown(multiselect=True, label="HyperNetwork", choices=sorted(get_hypernetwork_list()), elem_id="sagemaker_hypernetwork_dropdown")
                create_refresh_button(
                    hyperNetwork_dropdown,
                    get_hypernetwork_list,
                    lambda: {"choices": sorted(get_hypernetwork_list())},
                    "refresh_hypernetworks",
                )
                controlnet_dropdown = gr.Dropdown(multiselect=True, label="ControlNet-Model", choices=sorted(get_controlnet_model_list()), elem_id="sagemaker_controlnet_model_dropdown")
                create_refresh_button(
                    controlnet_dropdown,
                    get_controlnet_model_list,
                    lambda: {"choices": sorted(get_controlnet_model_list())},
                    "refresh_controlnet",
                )

    with gr.Group():
        with gr.Accordion("Open for Checkpoint Merge in the Cloud!", visible=False, open=False):
            sagemaker_html_log = gr.HTML(elem_id=f'html_log_sagemaker')
            with FormRow(elem_id="modelmerger_models_in_the_cloud"):
                global primary_model_name
                primary_model_name = gr.Dropdown(choices=sorted(update_sd_checkpoints()), elem_id="modelmerger_primary_model_name_in_the_cloud", label="Primary model (A) in the cloud")
                create_refresh_button(primary_model_name, update_sd_checkpoints, lambda: {"choices": sorted(update_sd_checkpoints())}, "refresh_checkpoint_A_in_the_cloud")
                global secondary_model_name
                secondary_model_name = gr.Dropdown(choices=sorted(update_sd_checkpoints()), elem_id="modelmerger_secondary_model_name_in_the_cloud", label="Secondary model (B) in the cloud")
                create_refresh_button(secondary_model_name, update_sd_checkpoints, lambda: {"choices": sorted(update_sd_checkpoints())}, "refresh_checkpoint_B_in_the_cloud")
                global tertiary_model_name
                tertiary_model_name = gr.Dropdown(choices=sorted(update_sd_checkpoints()), elem_id="modelmerger_tertiary_model_name_in_the_cloud", label="Tertiary model (C) in the cloud")
                create_refresh_button(tertiary_model_name, update_sd_checkpoints, lambda: {"choices": sorted(update_sd_checkpoints())}, "refresh_checkpoint_C_in_the_cloud")
            with gr.Row():
                global modelmerger_merge_on_cloud
                modelmerger_merge_on_cloud = gr.Button(elem_id="modelmerger_merge_in_the_cloud", value="Merge on Cloud", variant='primary')

    return sagemaker_endpoint, sd_checkpoint, sd_checkpoint_refresh_button, textual_inversion_dropdown, lora_dropdown, hyperNetwork_dropdown, controlnet_dropdown, inference_job_dropdown, txt2img_inference_job_ids_refresh_button, primary_model_name, secondary_model_name, tertiary_model_name, modelmerger_merge_on_cloud