import logging
import os
import html
import boto3
import time
import json
import gradio
import requests
import base64
import gradio as gr
import utils
from aws_extension.auth_service.simple_cloud_auth import cloud_auth_manager
from aws_extension.cloud_api_manager.api_manager import api_manager
from aws_extension.sagemaker_ui_utils import create_refresh_button_by_user
from modules.shared import opts
from modules.ui_components import FormRow
from utils import get_variable_from_json, upload_multipart_files_to_s3_by_signed_url
from requests.exceptions import JSONDecodeError
from datetime import datetime
import math
import re
import asyncio
import nest_asyncio
from utils import cp, tar, rm
# Module-level logger; its level is taken from utils.LOGGING_LEVEL.
logger = logging.getLogger(__name__)
logger.setLevel(utils.LOGGING_LEVEL)
# Placeholder entry placed at the head of on-cloud dropdown choice lists
# (see load_inference_job_list / load_model_list).
None_Option_For_On_Cloud_Model = "don't use on cloud inference"
# Placeholders initialised to None here; nothing in this part of the file
# assigns them (presumably UI component handles set elsewhere - TODO confirm).
inference_job_dropdown = None
textual_inversion_dropdown = None
hyperNetwork_dropdown = None
lora_dropdown = None
# sagemaker_endpoint = None
modelmerger_merge_on_cloud = None
interrogate_clip_on_cloud_button = None
interrogate_deep_booru_on_cloud_button = None
primary_model_name = None
secondary_model_name = None
tertiary_model_name = None
# TODO: convert to dynamically init the following variables
# txt2img_inference_job_ids is refilled in place by query_inference_job_list
# with "time-->type-->status-->id" labels.
txt2img_inference_job_ids = []
sd_checkpoints = []
textual_inversion_list = []
lora_list = []
hyperNetwork_list = []
ControlNet_model_list = []
# Initial checkpoints information
# checkpoint_info maps a checkpoint type to {checkpoint name: S3 location};
# call_remote_inference also stores 'sagemaker_endpoint' and 'task_type' in it.
checkpoint_info = {}
checkpoint_type = ["Stable-diffusion", "embeddings", "Lora", "hypernetworks", "ControlNet", "VAE"]
checkpoint_name = ["stable_diffusion", "embeddings", "lora", "hypernetworks", "controlnet", "VAE"]
stable_diffusion_list = []
embeddings_list = []
hypernetworks_list = []
controlnet_list = []
# Start every known checkpoint type with an empty name -> location mapping.
for ckpt_type, ckpt_name in zip(checkpoint_type, checkpoint_name):
    checkpoint_info[ckpt_type] = {}
# get api_gateway_url
# Read once at import time; the request helpers below re-read both values on
# every call instead of using these module-level copies.
api_gateway_url = get_variable_from_json('api_gateway_url')
api_key = get_variable_from_json('api_token')
# Time-range filters, written by on_img_time_change / on_txt_time_change and
# read by query_inference_job_list.
start_time_picker_img_value = None
end_time_picker_img_value = None
start_time_picker_txt_value = None
end_time_picker_txt_value = None
# Last filter values chosen for the txt2img job list, reused by
# on_txt_time_change.
txt_task_type = None
txt_status = None
txt_endpoint = None
txt_checkpoint = None
# Last filter values chosen for the img2img job list, reused by
# on_img_time_change.
img_task_type = None
img_status = None
img_endpoint = None
img_checkpoint = None
# When True, query_inference_job_list sends limit=-1 instead of limit=10.
show_all_inference_job = False
# Maps the model-type label received by sagemaker_upload_model_s3_url to the
# checkpoint type sent to the backend.
modelTypeMap = {
    'SD Checkpoints': 'Stable-diffusion',
    'Textual Inversion': 'embeddings',
    'LoRA model': 'Lora',
    'ControlNet model': 'ControlNet',
    'Hypernetwork': 'hypernetworks',
    'VAE': 'VAE'
}
def plaintext_to_html(text):
    """Render plain text as a single HTML paragraph.

    Every line of ``text`` is HTML-escaped so that user-supplied content
    cannot inject markup, and the lines are joined with ``<br>`` so the
    original line breaks stay visible once rendered.

    Args:
        text: The plain text to convert.

    Returns:
        The escaped text wrapped in one ``<p>`` element.
    """
    escaped_lines = [html.escape(line) for line in text.split('\n')]
    return "<p>" + "<br>\n".join(escaped_lines) + "</p>"
def server_request_post(path, params):
    """POST ``params`` as a JSON body to ``path`` on the configured API gateway.

    The gateway base URL and the API key are looked up through
    ``get_variable_from_json`` on every call. The ``requests`` response is
    returned as is; no status check is performed here.
    """
    base_url = get_variable_from_json('api_gateway_url')
    # The path is appended verbatim, so the base must end with a slash.
    if not base_url.endswith('/'):
        base_url = base_url + '/'
    token = get_variable_from_json('api_token')
    return requests.post(
        f'{base_url}{path}',
        json=params,
        headers={"x-api-key": token, "Content-Type": "application/json"},
    )
def get_s3_file_names(bucket, folder):
    """Return the key of every object in ``bucket`` whose key starts with ``folder``."""
    s3_bucket = boto3.resource('s3').Bucket(bucket)
    names = []
    for s3_object in s3_bucket.objects.filter(Prefix=folder):
        names.append(s3_object.key)
    return names
def get_current_date():
    """Return today's local date formatted as ``YYYY-MM-DD``."""
    return datetime.today().strftime('%Y-%m-%d')
def server_request(path):
    """Issue a GET request for ``path`` on the configured API gateway.

    The gateway base URL and the API key are looked up through
    ``get_variable_from_json`` on every call. The ``requests`` response is
    returned as is; no status check is performed here.
    """
    base_url = get_variable_from_json('api_gateway_url')
    # The path is appended verbatim, so the base must end with a slash.
    if not base_url.endswith('/'):
        base_url = base_url + '/'
    token = get_variable_from_json('api_token')
    return requests.get(
        f'{base_url}{path}',
        headers={"x-api-key": token, "Content-Type": "application/json"},
    )
def datetime_to_short_form(datetime_str):
    """Turn ``YYYY-MM-DD HH:MM:SS.ffffff`` into ``YYYY-MM-DD-HH-MM-SS``.

    The fractional seconds are dropped. ``datetime.strptime`` raises
    ``ValueError`` when the input does not follow the expected layout.
    """
    parsed = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S.%f")
    return parsed.strftime("%Y-%m-%d-%H-%M-%S")
def query_page_inference_job_list(task_type: str, status: str, endpoint: str, checkpoint: str, is_img2img: bool,
                                  show_all: bool):
    """Store the chosen job-list filters and query the matching inference jobs.

    The filters are kept in the module-level ``img_*`` or ``txt_*`` variables
    (selected by ``is_img2img``) so the time-picker callbacks can reuse them.
    ``show_all`` is stored in ``show_all_inference_job``, which
    ``query_inference_job_list`` reads to pick the result limit.

    Returns:
        Whatever ``query_inference_job_list`` returns for the given filters.
    """
    global show_all_inference_job
    global img_task_type, img_status, img_endpoint, img_checkpoint
    global txt_task_type, txt_status, txt_endpoint, txt_checkpoint

    show_all_inference_job = show_all
    if is_img2img:
        opt_type = 'img2img'
        img_task_type, img_status, img_endpoint, img_checkpoint = task_type, status, endpoint, checkpoint
    else:
        opt_type = 'txt2img'
        # The status filter must come from the ``status`` argument; assigning
        # the global to itself would leave the stored filter stale.
        txt_task_type, txt_status, txt_endpoint, txt_checkpoint = task_type, status, endpoint, checkpoint
    logger.debug(f"{opt_type} {task_type} {status} {endpoint} {checkpoint} {show_all}")
    return query_inference_job_list(task_type, status, endpoint, checkpoint, opt_type)
def query_img_inference_job_list(task_type: str, status: str, endpoint: str, checkpoint: str):
    """Store the img2img job-list filters in the module globals and run the query."""
    global img_task_type, img_status, img_endpoint, img_checkpoint
    img_task_type, img_status, img_endpoint, img_checkpoint = task_type, status, endpoint, checkpoint
    return query_inference_job_list(task_type, status, endpoint, checkpoint, 'img2img')
def query_txt_inference_job_list(task_type: str, status: str, endpoint: str, checkpoint: str):
    """Store the txt2img job-list filters in the module globals and run the query."""
    global txt_task_type, txt_status, txt_endpoint, txt_checkpoint
    txt_task_type, txt_status, txt_endpoint, txt_checkpoint = task_type, status, endpoint, checkpoint
    return query_inference_job_list(task_type, status, endpoint, checkpoint, 'txt2img')
def query_inference_job_list(task_type: str = '', status: str = '',
                             endpoint: str = '', checkpoint: str = '', opt_type: str = ''):
    """Query inference jobs from the backend and return a dropdown update.

    Only non-empty filters are sent. The time window comes from the txt2img
    or img2img picker globals, depending on ``opt_type``. On a non-empty
    response ``txt2img_inference_job_ids`` is rebuilt in place with
    ``time-->type-->status-->id`` labels sorted by time, ascending. An empty
    response or any exception yields an update with no choices.
    """
    logger.debug(
        f"query_inference_job_list start!!{status},{task_type},{endpoint},{checkpoint},{start_time_picker_txt_value},{end_time_picker_txt_value} {show_all_inference_job}")
    try:
        body_params = {}
        for filter_key, filter_value in (('status', status), ('task_type', task_type)):
            if filter_value:
                body_params[filter_key] = filter_value

        # Pick the time window that belongs to the requesting tab.
        if opt_type == 'txt2img':
            window_start, window_end = start_time_picker_txt_value, end_time_picker_txt_value
        elif opt_type == 'img2img':
            window_start, window_end = start_time_picker_img_value, end_time_picker_img_value
        else:
            window_start, window_end = None, None
        if window_start:
            body_params['start_time'] = window_start
        if window_end:
            body_params['end_time'] = window_end

        if endpoint:
            # Only the part before the first "+" is sent as the endpoint.
            body_params['endpoint'] = endpoint.split("+")[0]
        if checkpoint:
            body_params['checkpoint'] = checkpoint
        body_params['limit'] = 10 if not show_all_inference_job else -1

        jobs = server_request_post('inference/query-inference-jobs', body_params).json()
        logger.debug(jobs)
        if not jobs:
            logger.info("The API response is empty.")
            return gr.Dropdown.update(choices=[])

        # Clear the shared list before it is refilled.
        txt2img_inference_job_ids.clear()
        labelled_jobs = []
        for job in jobs:
            # Jobs that have not completed are ordered by their start time.
            job_time = job.get('completeTime')
            if job_time is None:
                job_time = job.get('startTime')
            job_status = job.get('status')
            job_type = job.get('taskType', 'txt2img')
            job_id = job.get('InferenceJobId')
            labelled_jobs.append((job_time, f"{job_time}-->{job_type}-->{job_status}-->{job_id}"))

        labelled_jobs.sort(key=lambda entry: entry[0])
        txt2img_inference_job_ids.extend(label for _, label in labelled_jobs)
        return gr.Dropdown.update(choices=txt2img_inference_job_ids)
    except Exception as e:
        logger.error("Exception occurred when fetching inference_job_ids")
        logger.error(e)
        return gr.Dropdown.update(choices=[])
def get_inference_job(inference_job_id):
    """Fetch one inference job from the backend and return the decoded JSON body."""
    job_response = server_request(f'inference/get-inference-job?jobID={inference_job_id}')
    job_details = job_response.json()
    return job_details
def get_inference_job_image_output(inference_job_id):
    """Return the image-output entries of an inference job as a list of strings.

    Any exception is logged and an empty list is returned instead.
    """
    try:
        reply = server_request(f'inference/get-inference-job-image-output?jobID={inference_job_id}')
        return [str(entry) for entry in reply.json()]
    except Exception as e:
        logger.error(f"An error occurred while getting inference job image output: {e}")
        return []
def get_inference_job_param_output(inference_job_id):
    """Return the parameter-output entries of an inference job as a list of strings.

    Any exception is logged and an empty list is returned instead.
    """
    try:
        reply = server_request(f'inference/get-inference-job-param-output?jobID={inference_job_id}')
        return [str(entry) for entry in reply.json()]
    except Exception as e:
        logger.error(f"An error occurred while getting inference job param output: {e}")
        return []
def download_images(image_urls: list, local_directory: str):
    """Download every URL into ``local_directory`` and return the saved paths.

    The directory is created when missing. The file name is the last path
    segment of the URL with any query string removed. A URL whose request
    fails is logged and skipped.
    """
    if not os.path.exists(local_directory):
        os.makedirs(local_directory)

    saved_paths = []
    for image_url in image_urls:
        try:
            reply = requests.get(image_url)
            reply.raise_for_status()
            # Drop the query string from the last path segment.
            file_name = os.path.basename(image_url).partition('?')[0]
            target_path = os.path.join(local_directory, file_name)
            with open(target_path, 'wb') as target_file:
                target_file.write(reply.content)
            saved_paths.append(target_path)
        except requests.exceptions.RequestException as e:
            logger.error(f"Error downloading image {image_url}: {e}")
    return saved_paths
def download_images_to_json(image_urls: list):
    """Fetch each URL as JSON and collect the value stored under its ``info`` key.

    A URL whose request fails is logged and skipped, so the result may be
    shorter than ``image_urls``.
    """
    collected_info = []
    for json_url in image_urls:
        try:
            reply = requests.get(json_url)
            reply.raise_for_status()
            collected_info.append(reply.json()['info'])
        except requests.exceptions.RequestException as e:
            logger.error(f"Error downloading image {json_url}: {e}")
    return collected_info
def download_images_to_pil(image_urls: list):
    """Download each URL and return the contents opened as PIL images.

    A URL whose request fails is logged and skipped. PIL is imported only
    when there is at least one successful download to decode.
    """
    pil_images = []
    for image_url in image_urls:
        try:
            reply = requests.get(image_url)
            reply.raise_for_status()
            import io
            from PIL import Image
            pil_images.append(Image.open(io.BytesIO(reply.content)))
        except requests.exceptions.RequestException as e:
            logger.error(f"Error downloading image {image_url}: {e}")
    return pil_images
def get_model_list_by_type(model_type, username=""):
    """Return the unique names of active checkpoints of the given type(s).

    ``model_type`` is either one type string or a list of them. As a side
    effect, ``checkpoint_info[type][name]`` is set to the checkpoint's S3
    position for every name found. An empty list is returned when the API
    settings are missing, when the response has no ``checkpoints`` key, or
    when any exception occurs.
    """
    base_url = get_variable_from_json('api_gateway_url')
    token = get_variable_from_json('api_token')
    # Without both settings there is nothing to query.
    if base_url is None or token is None:
        logger.info("api_gateway_url or api_key is not set")
        return []

    if not base_url.endswith('/'):
        base_url += '/'
    type_query = ("&types=" + "&types=".join(model_type)
                  if isinstance(model_type, list)
                  else f"&types={model_type}")
    url = base_url + "checkpoints?status=Active" + type_query

    try:
        encode_type = "utf-8"
        # The user name travels hex-encoded in the Authorization header.
        bearer = base64.b16encode(username.encode(encode_type)).decode(encode_type)
        response = requests.get(url=url, headers={
            'x-api-key': token,
            'Authorization': f'Bearer {bearer}'
        })
        response.raise_for_status()
        body = response.json()
        if "checkpoints" not in body.keys():
            return []

        found_names = []
        for entry in body["checkpoints"]:
            if "name" not in entry or entry["name"] is None or entry["type"] is None:
                continue
            entry_type = entry["type"]
            for model_file in entry["name"]:
                checkpoint_info[entry_type][model_file] = f"{entry['s3Location']}/{model_file}"
                found_names.append(model_file)
        return list(set(found_names))
    except Exception as e:
        logger.error(f"Error fetching model list: {e}")
        return []
def get_checkpoints_by_type(model_type):
    """Return ``[name, created]`` pairs for active checkpoints of the given type(s).

    ``model_type`` is either one type string or a list of them. ``created``
    is the checkpoint's ``created`` timestamp converted with
    ``datetime.fromtimestamp``. When a name appears more than once, the
    first occurrence wins.

    Returns:
        A list of ``[checkpoint_name, created_datetime]`` lists, or an empty
        list when the response has no ``checkpoints`` key or any error occurs.
    """
    url = "checkpoints?status=Active"
    if isinstance(model_type, list):
        url += "&types=" + "&types=".join(model_type)
    else:
        url += f"&types={model_type}"
    try:
        response = server_request(url)
        response.raise_for_status()
        json_response = response.json()
        if "checkpoints" not in json_response:
            return []
        checkpoint_dict = {}
        for ckpt in json_response["checkpoints"]:
            if ckpt.get("name") is None:
                continue
            created = datetime.fromtimestamp(ckpt['created'])
            for ckpt_name in ckpt["name"]:
                # setdefault keeps the first entry seen for a name.
                checkpoint_dict.setdefault(ckpt_name, [ckpt_name, created])
        return list(checkpoint_dict.values())
    except Exception as e:
        # Use the module logger so the configured level and name apply.
        logger.error(f"Error fetching checkpoints list: {e}")
        return []
def update_sd_checkpoints(username):
    """Return the names of active Stable-diffusion checkpoints for ``username``.

    The user name is forwarded to ``get_model_list_by_type`` so the request
    is made on behalf of that user; a missing user name falls back to the
    empty string, which is that function's default.
    """
    return get_model_list_by_type(["Stable-diffusion"], username or "")
def get_texual_inversion_list():
    """Return the names of active ``embeddings`` checkpoints."""
    return get_model_list_by_type("embeddings")
def get_lora_list():
    """Return the names of active ``Lora`` checkpoints."""
    return get_model_list_by_type("Lora")
def get_hypernetwork_list():
    """Return the names of active ``hypernetworks`` checkpoints."""
    return get_model_list_by_type("hypernetworks")
def get_controlnet_model_list():
    """Return the names of active ``ControlNet`` checkpoints."""
    return get_model_list_by_type("ControlNet")
def refresh_all_models(username):
    """Rebuild ``checkpoint_info`` from the backend for every checkpoint type.

    For each entry of ``checkpoint_type`` the active checkpoints are fetched
    and ``checkpoint_info[type]`` is replaced with a fresh
    ``{checkpoint name: S3 location}`` mapping. Any exception is logged and
    stops the refresh; types already processed keep their new content.

    Args:
        username: Sent hex-encoded in the Authorization header.
    """
    api_gateway_url = get_variable_from_json('api_gateway_url')
    api_key = get_variable_from_json('api_token')
    encode_type = "utf-8"
    try:
        # The relative path is appended verbatim, so the base URL needs a
        # trailing slash, as in the other request helpers of this module.
        if not api_gateway_url.endswith('/'):
            api_gateway_url += '/'
        bearer = base64.b16encode(username.encode(encode_type)).decode(encode_type)
        request_headers = {
            'x-api-key': api_key,
            'Authorization': f'Bearer {bearer}',
        }
        for rp in checkpoint_type:
            url = api_gateway_url + f"checkpoints?status=Active&types={rp}"
            response = requests.get(url=url, headers=request_headers)
            json_response = response.json()
            logger.debug(f"response url json for model {rp} is {json_response}")
            # Reset once per type, before any checkpoint of that type is added.
            checkpoint_info[rp] = {}
            if "checkpoints" not in json_response:
                continue
            for ckpt in json_response["checkpoints"]:
                if ckpt.get("name") is None:
                    continue
                # Several checkpoint records can share one type; add to the
                # existing mapping instead of wiping it for every record.
                type_entries = checkpoint_info.setdefault(ckpt["type"], {})
                for ckpt_name in ckpt["name"]:
                    type_entries[ckpt_name] = f"{ckpt['s3Location']}/{ckpt_name.split(os.sep)[-1]}"
    except Exception as e:
        logger.error(f"Error refresh all models: {e}")
def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path,
                              controlnet_model_path, vae_path, pr: gradio.Request):
    """Upload the given local model files to S3 through the checkpoint API.

    The six path arguments line up with ``checkpoint_type``; empty paths are
    skipped, and so are models whose file name is already present in
    ``checkpoint_info`` for that type. Each remaining file is registered via
    POST ``checkpoint``, copied into the repo layout, packed with ``tar``,
    uploaded in parts with the pre-signed URLs from the response, and then
    marked ``Active`` with a PUT to the same URL.

    Returns:
        A 7-tuple: the accumulated log text followed by six ``None`` values.
    """
    log = "start upload model to s3:"
    local_paths = [sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path, controlnet_model_path,
                   vae_path]
    logger.info(f"Refresh checkpoints before upload to get rid of duplicate uploads...")
    refresh_all_models(pr.username)
    for lp, rp in zip(local_paths, checkpoint_type):
        if lp == "" or not lp:
            continue
        logger.debug(f"lp is {lp}")
        model_name = lp.split(os.sep)[-1]
        exist_model_list = list(checkpoint_info[rp].keys())
        if model_name in exist_model_list:
            logger.info(f"!!!skip to upload duplicate model {model_name}")
            continue
        part_size = 1000 * 1024 * 1024
        file_size = os.stat(lp)
        parts_number = math.ceil(file_size.st_size / part_size)
        logger.info(f'!!!!!!!!!!{file_size} {parts_number}')
        # local_tar_path = f'{model_name}.tar'
        local_tar_path = model_name
        payload = {
            "checkpoint_type": rp,
            "filenames": [{
                "filename": local_tar_path,
                "parts_number": parts_number
            }],
            "params": {
                "message": "placeholder for chkpts upload test",
                "creator": pr.username
            }
        }
        api_gateway_url = get_variable_from_json('api_gateway_url')
        # Check if api_url ends with '/', if not append it
        if not api_gateway_url.endswith('/'):
            api_gateway_url += '/'
        api_key = get_variable_from_json('api_token')
        logger.info(f'!!!!!!api_gateway_url {api_gateway_url}')
        url = str(api_gateway_url) + "checkpoint"
        logger.debug(f"Post request for upload s3 presign url: {url}")
        response = requests.post(url=url, json=payload, headers={'x-api-key': api_key})
        try:
            response.raise_for_status()
            json_response = response.json()
            logger.debug(f"Response json {json_response}")
            s3_base = json_response["checkpoint"]["s3_location"]
            checkpoint_id = json_response["checkpoint"]["id"]
            logger.debug(f"Upload to S3 {s3_base}")
            logger.debug(f"Checkpoint ID: {checkpoint_id}")
            s3_signed_urls_resp = json_response["s3PresignUrl"][local_tar_path]
            # Upload src model to S3.
            # Embeddings are placed directly under "embeddings"; every other
            # type goes under "models/<type>".
            if rp != "embeddings":
                local_model_path_in_repo = os.sep.join(['models', rp, model_name])
            else:
                local_model_path_in_repo = os.sep.join([rp, model_name])
            logger.debug("Pack the model file.")
            cp(lp, local_model_path_in_repo, recursive=True)
            archive_members = [local_model_path_in_repo]
            if rp == "Stable-diffusion":
                # A .yaml file next to the model with the same stem is packed too.
                model_yaml_name = model_name.split('.')[0] + ".yaml"
                local_model_yaml_path = os.sep.join([*lp.split(os.sep)[:-1], model_yaml_name])
                local_model_yaml_path_in_repo = os.sep.join(["models", rp, model_yaml_name])
                if os.path.isfile(local_model_yaml_path):
                    cp(local_model_yaml_path, local_model_yaml_path_in_repo, recursive=True)
                    archive_members.append(local_model_yaml_path_in_repo)
            tar(mode='c', archive=local_tar_path, sfiles=archive_members, verbose=True)
            multiparts_tags = upload_multipart_files_to_s3_by_signed_url(
                local_tar_path,
                s3_signed_urls_resp,
                part_size
            )
            payload = {
                "checkpoint_id": checkpoint_id,
                "status": "Active",
                "multi_parts_tags": {local_tar_path: multiparts_tags}
            }
            # Start creating model on cloud.
            response = requests.put(url=url, json=payload, headers={'x-api-key': api_key})
            logger.debug(response)
            # Append so the result of every uploaded model stays in the log.
            log += f"\n finish upload {local_tar_path} to {s3_base}"
            # os.system(f"rm {local_tar_path}")
            rm(local_tar_path, recursive=True)
        except Exception as e:
            logger.error(f"fail to upload model {lp}, error: {e}")
            # Report the failure in the returned log as well, not only in the logger.
            log += f"\n fail to upload model {lp}, error: {e}"
    logger.debug(f"Refresh checkpoints after upload...")
    refresh_all_models(pr.username)
    return log, None, None, None, None, None, None
def sagemaker_upload_model_s3_local():
    """Return the fixed status text ``"Start upload:"``."""
    return "Start upload:"
def sagemaker_upload_model_s3_url(model_type: str, url_list: str, params: str, pr: gradio.Request):
    """Ask the backend to fetch model files from URLs and register them.

    Args:
        model_type: A model-type label; must be a key of ``modelTypeMap``.
        url_list: Comma-separated URLs (http, https or ftp), without whitespace.
        params: Optional JSON object text with extra parameters.
        pr: The Gradio request; its ``username`` is sent as ``creator``.

    Returns:
        A status message: a validation hint, ``"uploading……"``, or
        ``"upload success!"`` when the backend reports the checkpoint as
        ``Active``.
    """
    model_type = modelTypeMap.get(model_type)
    if not model_type:
        return "Please choose the model type."
    url_pattern = r'(https?|ftp)://[^\s/$.?#].[^\s]*'
    # ``url_list or ''`` turns a missing value into a validation failure
    # instead of a TypeError inside re.match.
    if re.match(f'^{url_pattern}$', url_list or ''):
        url_list = url_list.split(',')
    else:
        return "Please fill in right url list."
    if params:
        try:
            params_dict = json.loads(params)
        except json.JSONDecodeError as e:
            logger.error(f"Invalid params json: {e}")
            return "Please fill in right params in JSON format."
        # 'creator' is added as a key below, so only a JSON object is usable.
        if not isinstance(params_dict, dict):
            return "Please fill in right params in JSON format."
    else:
        params_dict = {}
    params_dict['creator'] = pr.username
    body_params = {'checkpointType': model_type, 'modelUrl': url_list, 'params': params_dict}
    response = server_request_post('upload_checkpoint', body_params)
    response_data = response.json()
    logger.info(f"sagemaker_upload_model_s3_url response:{response_data}")
    log = "uploading……"
    if 'checkpoint' in response_data:
        if response_data['checkpoint'].get('status') == 'Active':
            log = "upload success!"
    return log
def generate_on_cloud(sagemaker_endpoint):
    """Log the checkpoint info and the endpoint, then return a fixed failure message as HTML."""
    logger.info(f"checkpiont_info {checkpoint_info}")
    logger.info(f"sagemaker endpoint {sagemaker_endpoint}")
    return plaintext_to_html("failed to check endpoint")
# create a global event loop and apply the patch to allow nested event loops in single thread
# `loop` is shared by async_loop_wrapper and async_loop_wrapper_with_input,
# which replace it with a new loop if they find it closed.
loop = asyncio.get_event_loop()
nest_asyncio.apply()
# The wrappers wait while more than this many unfinished tasks are on the loop.
MAX_RUNNING_LIMIT = 10
def async_loop_wrapper(f):
    """Run the coroutine produced by ``f()`` on the shared loop and return its result.

    If the loop is already running, wait in one-second steps while more than
    ``MAX_RUNNING_LIMIT`` tasks are unfinished. If it is not running and has
    been closed, replace it with a new loop first.
    """
    global loop
    if not loop.is_running():
        if loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            # log this event since it should never happen
            logger.debug('Event loop was closed, created a new one')
    else:
        # Throttle: wait until the number of unfinished tasks is within the limit.
        while True:
            unfinished = [task for task in asyncio.all_tasks(loop) if not task.done()]
            if len(unfinished) <= MAX_RUNNING_LIMIT:
                break
            logger.debug(f'Waiting for {MAX_RUNNING_LIMIT} running tasks to complete')
            time.sleep(1)
    return loop.run_until_complete(f())
def async_loop_wrapper_with_input(sagemaker_endpoint, type):
    """Run ``call_remote_inference(sagemaker_endpoint, type)`` on the shared loop.

    If the loop is already running, wait in one-second steps while more than
    ``MAX_RUNNING_LIMIT`` tasks are unfinished. If it is not running and has
    been closed, replace it with a new loop first.
    """
    global loop
    if not loop.is_running():
        if loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            # log this event since it should never happen
            logger.debug('Event loop was closed, created a new one')
    else:
        # Throttle: wait until the number of unfinished tasks is within the limit.
        while True:
            unfinished = [task for task in asyncio.all_tasks(loop) if not task.done()]
            if len(unfinished) <= MAX_RUNNING_LIMIT:
                break
            logger.debug(f'Waiting for {MAX_RUNNING_LIMIT} running tasks to complete')
            time.sleep(1)
    return loop.run_until_complete(call_remote_inference(sagemaker_endpoint, type))
def call_interrogate_clip(sagemaker_endpoint, init_img, sketch, init_img_with_mask, inpaint_color_sketch,
                          init_img_inpaint, init_mask_inpaint):
    """Run an ``interrogate_clip`` task on the endpoint; the image arguments are not used here."""
    task_name = 'interrogate_clip'
    return async_loop_wrapper_with_input(sagemaker_endpoint, task_name)
def call_interrogate_deepbooru(sagemaker_endpoint, init_img, sketch, init_img_with_mask, inpaint_color_sketch,
                               init_img_inpaint, init_mask_inpaint):
    """Run an ``interrogate_deepbooru`` task on the endpoint; the image arguments are not used here."""
    task_name = 'interrogate_deepbooru'
    return async_loop_wrapper_with_input(sagemaker_endpoint, task_name)
async def call_remote_inference(sagemaker_endpoint, type):
    """Submit an inference task to the chosen endpoint and return its result.

    Args:
        sagemaker_endpoint: ``"<name>+<status>..."`` as selected in the UI, the
            literal ``'FAILURE'``, or an empty value.
        type: The task type sent to the backend as ``task_type``.

    Returns:
        ``(image_list, info_text, html)`` with a failure message when the
        endpoint is unusable; otherwise the value of
        ``process_result_by_inference_id``. ``None`` is returned when the
        response is not JSON or when processing the result raises.

    Note:
        The request payload is the module-level ``checkpoint_info`` dict
        itself, so ``sagemaker_endpoint`` and ``task_type`` are stored in it.
    """
    logger.debug(f"chosen ep {sagemaker_endpoint}")
    logger.debug(f"inference type is {type}")
    if sagemaker_endpoint == 'FAILURE':
        return [], '', plaintext_to_html("Failed upload the config to cloud ")
    # A missing value, or one without a "+<status>" part, is handled like an
    # endpoint that is not in service instead of raising.
    endpoint_fields = (sagemaker_endpoint or '').split("+")
    if len(endpoint_fields) < 2 or endpoint_fields[1] != "InService":
        return [], '', plaintext_to_html("Failed! Please choose the endpoint in 'InService' states ")
    api_gateway_url = get_variable_from_json('api_gateway_url')
    # Check if api_url ends with '/', if not append it
    if not api_gateway_url.endswith('/'):
        api_gateway_url += '/'
    api_key = get_variable_from_json('api_token')
    # stage 2: inference using endpoint_name
    headers = {
        "x-api-key": api_key,
        "Content-Type": "application/json"
    }
    checkpoint_info['sagemaker_endpoint'] = endpoint_fields[0]
    payload = checkpoint_info
    payload['task_type'] = type
    logger.info(f"checkpointinfo is {payload}")
    inference_url = f"{api_gateway_url}inference/run-sagemaker-inference"
    response = requests.post(inference_url, json=payload, headers=headers)
    logger.info(f"Raw server response: {response.text}")
    try:
        r = response.json()
    except JSONDecodeError as e:
        logger.error(f"Failed to decode JSON response: {e}")
        logger.error(f"Raw server response: {response.text}")
    else:
        inference_id = r.get('inference_id')  # Assuming the response contains 'inference_id' field
        try:
            return process_result_by_inference_id(inference_id)
        except Exception as e:
            logger.error(f"Failed to get inference job {inference_id}, error: {e}")
def process_result_by_inference_id(inference_id):
    """Wait for an inference job to finish and collect its outputs.

    The job is polled every three seconds while its status is
    ``inprogress``. For txt2img/img2img jobs the images and the parameter
    JSON are downloaded; for interrogate jobs the caption is returned.

    Returns:
        ``(image_list, info_text, html, text)``. ``text`` is the caption for
        interrogate jobs and the plain info text otherwise.

    Raises:
        Exception: When the job record has no ``taskType`` key.
    """
    image_list = []
    info_text = ''
    infotexts = f"Inference id is {inference_id}, please check all historical inference result in 'Inference Job' dropdown list"

    resp = get_inference_job(inference_id)
    if resp is None:
        logger.info(f"get_inference_job resp is null")
        return image_list, info_text, plaintext_to_html(infotexts), infotexts
    logger.debug(f"get_inference_job resp is {resp}")
    if 'taskType' not in resp:
        raise Exception(resp)
    if resp['taskType'] not in ['txt2img', 'img2img', 'interrogate_clip', 'interrogate_deepbooru']:
        return image_list, info_text, plaintext_to_html(infotexts), infotexts

    while resp and resp['status'] == "inprogress":
        time.sleep(3)
        resp = get_inference_job(inference_id)
    # The loop also ends on an empty response; handle it like a missing one
    # rather than failing on the key lookups below.
    if not resp:
        logger.info(f"get_inference_job resp is null.")
        return image_list, info_text, plaintext_to_html(infotexts), infotexts

    if resp['status'] == "failed":
        infotexts = f"Inference job {inference_id} is failed, error message: {resp['sagemakerRaw']}"
        return image_list, info_text, plaintext_to_html(infotexts), infotexts
    if resp['status'] != "succeed":
        logger.debug(f"inference job status is {resp['status']}")
        return image_list, info_text, plaintext_to_html(infotexts), infotexts

    if resp['taskType'] in ['interrogate_clip', 'interrogate_deepbooru']:
        # return with default value, including image_list, info_text, infotexts
        return image_list, info_text, plaintext_to_html(infotexts), resp['caption']

    images = get_inference_job_image_output(inference_id.strip())
    inference_param_json_list = get_inference_job_param_output(inference_id)
    if resp['taskType'] in ['txt2img', 'img2img']:
        image_list = download_images_to_pil(images)
        # The download helper skips failed URLs, so the list can be empty.
        downloaded_params = download_images_to_json(inference_param_json_list)
        json_file = downloaded_params[0] if downloaded_params else None
        if json_file:
            info_text = json_file
            infotexts = f"Inference id is {inference_id}\n" + json.loads(info_text)["infotexts"][0]
        else:
            logger.debug(f"File {json_file} does not exist.")
            info_text = 'something wrong when trying to download the inference parameters'
            infotexts = info_text
    return image_list, info_text, plaintext_to_html(infotexts), infotexts
def modelmerger_on_cloud_func(primary_model_name, secondary_model_name, teritary_model_name):
    """Send a model-merge request for the three named models to the backend.

    Returns:
        An empty list when no API gateway URL is configured; otherwise
        ``None``. The response is only logged.
    """
    logger.debug(f"function under development, current checkpoint_info is {checkpoint_info}")
    api_gateway_url = get_variable_from_json('api_gateway_url')
    # Check for a missing URL before touching it; calling .endswith on None
    # would raise before the check could ever run.
    if api_gateway_url is None:
        logger.debug(f"modelmerger: failed to get the api-gateway url, can not fetch remote data")
        return []
    # Check if api_url ends with '/', if not append it
    if not api_gateway_url.endswith('/'):
        api_gateway_url += '/'
    api_key = get_variable_from_json('api_token')
    modelmerge_url = f"{api_gateway_url}inference/run-model-merge"
    payload = {
        "primary_model_name": primary_model_name,
        "secondary_model_name": secondary_model_name,
        "tertiary_model_name": teritary_model_name
    }
    headers = {
        "x-api-key": api_key,
        "Content-Type": "application/json"
    }
    response = requests.post(modelmerge_url, json=payload, headers=headers)
    try:
        r = response.json()
    except JSONDecodeError as e:
        logger.error(f"Failed to decode JSON response: {e}")
        logger.error(f"Raw server response: {response.text}")
    else:
        logger.debug(f"response for rest api {r}")
# def txt2img_config_save():
# # placeholder for saving txt2img config
# pass
def displayEndpointInfo(input_string: str):
    """Return HTML describing a failed endpoint deployment.

    ``input_string`` has the form ``"<job id>+<status>..."``. An empty value
    returns ``None``. Any status other than ``failed``, or a value without a
    ``+``, yields an empty paragraph. For a failed job the deployment record
    is fetched and its ``error`` is shown, falling back to its
    ``EndpointDeploymentJobId`` when there is no ``error`` key.
    """
    logger.debug(f"selected value is {input_string}")
    if not input_string:
        return None

    fields = input_string.split('+')
    if len(fields) < 2 or fields[1] != 'failed':
        return plaintext_to_html("")

    job_record = server_request(f'inference/get-endpoint-deployment-job?jobID={fields[0]}').json()
    shown_key = "error" if "error" in job_record else "EndpointDeploymentJobId"
    return plaintext_to_html(job_record[shown_key])
def update_txt2imgPrompt_from_TextualInversion(selected_items, txt2img_prompt):
    """Sync the prompt with the selected embeddings, which are inserted as bare names."""
    return update_txt2imgPrompt_from_model_select(
        selected_items, txt2img_prompt, model_name='embeddings', with_angle_brackets=False)
def update_txt2imgPrompt_from_Hypernetworks(selected_items, txt2img_prompt):
    """Sync the prompt with the selected hypernetworks, which are inserted as angle-bracket tags."""
    return update_txt2imgPrompt_from_model_select(
        selected_items, txt2img_prompt, model_name='hypernetworks', with_angle_brackets=True)
def update_txt2imgPrompt_from_Lora(selected_items, txt2img_prompt):
    """Sync the prompt with the selected LoRA models, which are inserted as angle-bracket tags."""
    return update_txt2imgPrompt_from_model_select(
        selected_items, txt2img_prompt, model_name='Lora', with_angle_brackets=True)
def update_txt2imgPrompt_from_model_select(selected_items, txt2img_prompt, model_name='embeddings',
                                           with_angle_brackets=False):
    """Rewrite the prompt so it references exactly the selected models.

    Every known model of ``model_name`` is first removed from the prompt,
    then the selected ones are appended. Model names are cut at the first
    ``.`` to drop the file extension.

    Args:
        selected_items: File names chosen in the dropdown,
            e.g. ``['FastNegativeV2.pt']``.
        txt2img_prompt: The current prompt text.
        model_name: Checkpoint type used to look up the known models;
            ``'Lora'`` and ``'hypernetworks'`` get a ``lora:`` / ``hypernet:``
            tag prefix.
        with_angle_brackets: When True, models appear as ``<prefix name:weight>``
            tags; otherwise as bare names.

    Returns:
        The updated prompt with surrounding whitespace stripped.
    """
    logger.debug(selected_items)  # example ['FastNegativeV2.pt']
    logger.debug(txt2img_prompt)
    # One lookup only; it is logged instead of fetching a second list just
    # for the debug output.
    full_dropdown_items = get_model_list_by_type(model_name)  # example ['FastNegativeV2.pt', 'okuryl3nko.pt']
    logger.debug(full_dropdown_items)
    # Remove extensions from selected_items and full_dropdown_items
    selected_names = [item.split('.')[0] for item in (selected_items or [])]
    known_names = [item.split('.')[0] for item in full_dropdown_items]
    type_str = {'Lora': 'lora:', 'hypernetworks': 'hypernet:'}.get(model_name, '')
    # Remove every known model from the prompt.
    for item in known_names:
        if with_angle_brackets:
            # Escape the name so regex metacharacters in it match literally,
            # and accept integer as well as decimal weights (":1", ":0.8").
            tag_pattern = rf'<{re.escape(type_str + item)}:\d+(?:\.\d+)?>'
            txt2img_prompt = re.sub(tag_pattern, "", txt2img_prompt).strip()
        else:
            txt2img_prompt = txt2img_prompt.replace(item, "").strip()
    # Append the selected models.
    for item in selected_names:
        if with_angle_brackets:
            txt2img_prompt += ' ' + '<' + type_str + item + ':1>'
        else:
            txt2img_prompt += ' ' + item
    # Remove any leading or trailing whitespace
    return txt2img_prompt.strip()
def fake_gan(selected_value, original_prompt):
    """Load the outputs of the inference job selected in the job dropdown.

    ``selected_value`` is a ``time-->type-->status-->id`` label. A missing
    value, or a label without all four fields (such as the placeholder
    dropdown entry), yields empty outputs and the unchanged prompt.

    Returns:
        ``(image_list, info_text, html, prompt_txt)``. For txt2img/img2img
        jobs the images and generation info are filled in; for interrogate
        jobs ``prompt_txt`` is the job's caption. A job still in progress
        returns ``([], [], html, '')``.
    """
    logger.debug(f"selected value is {selected_value}")
    logger.debug(f"original prompt is {original_prompt}")
    # Defaults, so every path below (including unknown task types) has
    # something to return.
    image_list = []
    info_text = ''
    infotexts = ''
    prompt_txt = original_prompt

    parts = selected_value.split("-->") if selected_value is not None else []
    if len(parts) >= 4:
        # Extract the InferenceJobId value
        inference_job_id = parts[3].strip()
        inference_job_status = parts[2].strip()
        inference_job_taskType = parts[1].strip()
        if inference_job_status == 'inprogress':
            return [], [], plaintext_to_html('inference still in progress'), ''
        if inference_job_taskType in ["txt2img", "img2img"]:
            images = get_inference_job_image_output(inference_job_id.strip())
            inference_param_json_list = get_inference_job_param_output(inference_job_id)
            image_list = download_images_to_pil(images)
            # The download helper skips failed URLs, so the list can be empty.
            downloaded_params = download_images_to_json(inference_param_json_list)
            json_file = downloaded_params[0] if downloaded_params else None
            if json_file:
                info_text = json_file
                infotexts = f"Inference id is {inference_job_id}\n" + json.loads(info_text)["infotexts"][0]
            else:
                logger.debug(f"File {json_file} does not exist.")
                info_text = 'something wrong when trying to download the inference parameters'
                infotexts = info_text
        elif inference_job_taskType in ["interrogate_clip", "interrogate_deepbooru"]:
            job_status = get_inference_job(inference_job_id)
            logger.debug(job_status)
            prompt_txt = job_status['caption']
    return image_list, info_text, plaintext_to_html(infotexts), prompt_txt
def init_refresh_resource_list_from_cloud(username):
    """Refresh the checkpoint info when an API gateway URL is set and auth is disabled."""
    logger.debug(f"start refreshing resource list from cloud")
    if get_variable_from_json('api_gateway_url') is None:
        logger.debug(f"there is no api-gateway url and token in local file,")
        return
    if cloud_auth_manager.enableAuth:
        logger.debug('auth enabled, not preload load any model')
        return
    refresh_all_models(username)
def on_txt_time_change(start_time, end_time):
    """Store the txt2img time window and re-run the job query with the saved txt2img filters."""
    global start_time_picker_txt_value, end_time_picker_txt_value
    logger.debug(f"!!!!!!!!!on_txt_time_change!!!!!!{start_time},{end_time}")
    start_time_picker_txt_value, end_time_picker_txt_value = start_time, end_time
    return query_inference_job_list(txt_task_type, txt_status, txt_endpoint, txt_checkpoint, "txt2img")
def on_img_time_change(start_time, end_time):
    """Store the img2img time window and re-run the job query with the saved img2img filters."""
    global start_time_picker_img_value, end_time_picker_img_value
    logger.debug(f"!!!!!!!!!on_img_time_change!!!!!!{start_time},{end_time}")
    start_time_picker_img_value, end_time_picker_img_value = start_time, end_time
    return query_inference_job_list(img_task_type, img_status, img_endpoint, img_checkpoint, "img2img")
def load_inference_job_list(target_task_type, username, usertoken):
    """Build the dropdown choices for the cloud inference jobs of one task type.

    Args:
        target_task_type: Task type to keep. Jobs that carry no ``taskType``
            field are treated as ``'txt2img'``.
        username: Forwarded to ``api_manager.list_all_inference_jobs_on_cloud``.
        usertoken: Forwarded to ``api_manager.list_all_inference_jobs_on_cloud``.

    Returns:
        A list that starts with ``None_Option_For_On_Cloud_Model`` followed by
        one ``"<time>--><type>--><status>--><id>"`` label per matching job,
        ordered by time ascending. The time is the job's ``completeTime`` or,
        when that is missing, its ``startTime``. Jobs with neither timestamp
        are listed first instead of breaking the sort.
    """
    # Tolerate an empty/None response so the dropdown still gets its default.
    jobs = api_manager.list_all_inference_jobs_on_cloud(username, usertoken) or []

    labelled = []
    for job in jobs:
        task_type = job.get('taskType', 'txt2img')
        if task_type != target_task_type:
            continue

        # Jobs that have not completed yet fall back to their start time.
        complete_time = job.get('completeTime')
        if complete_time is None:
            complete_time = job.get('startTime')

        status = job.get('status')
        inference_job_id = job.get('InferenceJobId')
        labelled.append(
            (complete_time, f"{complete_time}-->{task_type}-->{status}-->{inference_job_id}")
        )

    # Sort on the timestamp only (stable, ascending). The leading flag keeps
    # None timestamps from ever being compared with real ones, which would
    # raise TypeError on Python 3.
    labelled.sort(key=lambda entry: (entry[0] is not None, entry[0]))

    return [None_Option_For_On_Cloud_Model, *(label for _, label in labelled)]
def load_model_list(username, user_token):
    """Return the de-duplicated names of the models listed on cloud.

    When ``'sd_model_checkpoint'`` is part of ``opts.quicksettings_list`` the
    list is prefixed with ``None_Option_For_On_Cloud_Model``. The order of
    the model names themselves is unspecified (they pass through a set).
    """
    has_quicksetting = 'sd_model_checkpoint' in opts.quicksettings_list
    choices = [None_Option_For_On_Cloud_Model] if has_quicksetting else []

    unique_names = {
        entry['name']
        for entry in api_manager.list_models_on_cloud(username, user_token)
    }
    choices.extend(unique_names)
    return choices
def load_lora_models(username, user_token):
    """Return the de-duplicated names of the 'Lora' models on cloud (unordered)."""
    lora_entries = api_manager.list_models_on_cloud(username, user_token, types='Lora')
    unique_names = {entry['name'] for entry in lora_entries}
    return list(unique_names)
def load_hypernetworks_models(username, user_token):
    """Return the de-duplicated names of the 'hypernetworks' models on cloud (unordered)."""
    hypernet_entries = api_manager.list_models_on_cloud(username, user_token, types='hypernetworks')
    unique_names = {entry['name'] for entry in hypernet_entries}
    return list(unique_names)
def load_vae_list(username, user_token):
    """Return the VAE dropdown choices for the cloud.

    The list always starts with the fixed entries ``'Automatic'`` and
    ``'None'``, followed by the de-duplicated names of the 'VAE' models on
    cloud in unspecified order.
    """
    vae_entries = api_manager.list_models_on_cloud(username, user_token, types='VAE')
    unique_names = {entry['name'] for entry in vae_entries}
    return ['Automatic', 'None', *unique_names]
def load_controlnet_list(username, user_token):
    """Return the ControlNet choices for the cloud.

    The list starts with the fixed entry ``'None'``, followed by the
    de-duplicated names of the 'ControlNet' models on cloud in unspecified
    order.
    """
    controlnet_entries = api_manager.list_models_on_cloud(username, user_token, types='ControlNet')
    unique_names = {entry['name'] for entry in controlnet_entries}
    return ['None', *unique_names]
def load_xyz_controlnet_list(username, user_token):
    """Return the ControlNet choices with file extensions stripped.

    The list starts with the fixed entry ``'None'``, followed by the
    de-duplicated 'ControlNet' model names on cloud, each reduced to the part
    before its extension (``os.path.splitext``), in unspecified order.
    """
    controlnet_entries = api_manager.list_models_on_cloud(username, user_token, types='ControlNet')
    unique_stems = set()
    for entry in controlnet_entries:
        stem, _extension = os.path.splitext(entry['name'])
        unique_stems.add(stem)
    return ['None', *unique_stems]
def load_embeddings_list(username, user_token):
    """Return the de-duplicated names of the 'embeddings' models on cloud (unordered)."""
    embedding_entries = api_manager.list_models_on_cloud(username, user_token, types='embeddings')
    unique_names = {entry['name'] for entry in embedding_entries}
    return list(unique_names)
def create_ui(is_img2img):
    """Build the SageMaker inference panel for the txt2img or img2img tab.

    Creates a ``gr.Blocks`` that holds three dropdowns -- the Stable Diffusion
    checkpoint used on cloud, the VAE used on cloud and the inference jobs of
    the current task type -- each paired with a refresh button made by
    ``create_refresh_button_by_user``. A handler registered with
    ``sagemaker_inference_tab.load`` fills these dropdowns and the model state
    for the requesting user. An accordion created with ``visible=False``
    carries the controls for merging checkpoints in the cloud.

    Args:
        is_img2img: Truthy to build the img2img variant. Selects the
            ``'img2img'`` task type for the job list (``'txt2img'``
            otherwise) and is passed as ``visible`` to one HTML row.

    Returns:
        The tuple ``(sd_model_on_cloud_dropdown, sd_vae_on_cloud_dropdown,
        inference_job_dropdown, primary_model_name, secondary_model_name,
        tertiary_model_name, modelmerger_merge_on_cloud,
        lora_and_hypernet_models_state)``.

    Side effects:
        Rebinds the module-level ``inference_job_dropdown``,
        ``primary_model_name``, ``secondary_model_name``,
        ``tertiary_model_name`` and ``modelmerger_merge_on_cloud``.
    """
    # NOTE(review): the nesting of the layout blocks below was reconstructed
    # from syntax only -- confirm it against the repository before merging.
    # Declared global but never assigned in this function.
    global txt2img_gallery, txt2img_generation_info
    # Only referenced by the commented-out code further down.
    import modules.ui
    # init_refresh_resource_list_from_cloud()
    with gr.Blocks() as sagemaker_inference_tab:
        # NOTE(review): this literal is split across two lines here; markup
        # was presumably lost from the string -- confirm the original text.
        gr.HTML('Amazon SageMaker Inference
')
        sagemaker_html_log = gr.HTML(elem_id=f'html_log_sagemaker')
        # Task type used to filter the inference job dropdown.
        inference_task_type = 'txt2img' if not is_img2img else 'img2img'
        with gr.Column():
            with gr.Row():
                # State filled by the load handler with the per-user model lists.
                lora_and_hypernet_models_state = gr.State({})
                sd_model_on_cloud_dropdown = gr.Dropdown(choices=[], value=None_Option_For_On_Cloud_Model,
                                                         label='Stable Diffusion Checkpoint Used on Cloud')
                # The username is passed as the user token as well.
                create_refresh_button_by_user(sd_model_on_cloud_dropdown,
                                              lambda *args: None,
                                              lambda username: {
                                                  'choices': load_model_list(username, username)
                                              }, 'refresh_cloud_model_down')
            with gr.Row():
                sd_vae_on_cloud_dropdown = gr.Dropdown(choices=[], value='Automatic',
                                                       label='SD Vae on Cloud')
                create_refresh_button_by_user(sd_vae_on_cloud_dropdown,
                                              lambda *args: None,
                                              lambda username: {
                                                  'choices': load_vae_list(username, username)
                                              }, 'refresh_cloud_vae_down')
            # Row that is only visible on the img2img tab.
            with gr.Row(visible=is_img2img):
                # NOTE(review): literal split across two lines; its markup was
                # presumably lost -- confirm the original text.
                gr.HTML('
')
            # with gr.Row(visible=is_img2img):
            # global generate_on_cloud_button_with_js
            # # if not is_img2img:
            # # generate_on_cloud_button_with_js = gr.Button(value="Generate on Cloud", variant='primary', elem_id="generate_on_cloud_with_cloud_config_button",queue=True, show_progress=True)
            # global generate_on_cloud_button_with_js_img2img
            # global interrogate_clip_on_cloud_button
            # global interrogate_deep_booru_on_cloud_button
            #
            # interrogate_clip_on_cloud_button = gr.Button(value="Interrogate CLIP", variant='primary',
            # elem_id="interrogate_clip_on_cloud_button", visible=False)
            # interrogate_deep_booru_on_cloud_button = gr.Button(value="Interrogte DeepBooru", variant='primary',
            # elem_id="interrogate_deep_booru_on_cloud_button", visible=False)
            # with gr.Row():
            # gr.HTML('
            # NOTE(review): the next line looks like the tail of the
            # commented-out gr.HTML call above -- confirm and comment it out.
')
            with gr.Row():
                global inference_job_dropdown
                # global txt2img_inference_job_ids
                inference_job_dropdown = gr.Dropdown(choices=[], value=None_Option_For_On_Cloud_Model,
                                                     label="Inference Job: Time-Type-Status-Uuid")
                create_refresh_button_by_user(inference_job_dropdown,
                                              lambda *args: None,
                                              lambda username: {
                                                  'choices': load_inference_job_list(inference_task_type, username, username)
                                              }, 'refresh_inference_job_down')
            # inference_job_dropdown = gr.Dropdown(choices=txt2img_inference_job_ids,
            # label="Inference Job: Time-Type-Status-Uuid",
            # elem_id="txt2img_inference_job_ids_dropdown"
            # )
            # txt2img_inference_job_ids_refresh_button = create_refresh_button(inference_job_dropdown,
            # query_inference_job_list,
            # lambda: {
            # "choices": txt2img_inference_job_ids,
            # "value": None},
            # "refresh_txt2img_inference_job_ids")
            # fixme: inference filters need to be fixed
            # with gr.Row():
            # inference_job_filter = gr.Checkbox(
            # label="Advanced Inference Job filter", value=False, visible=True
            # )
            # inference_job_page = gr.Checkbox(label="Show All(unchecked: max 10 items)",
            # elem_id="inference_job_page_checkbox", value=False)
            # with gr.Row(variant='panel', visible=False) as filter_row:
            # with gr.Column(scale=1):
            # gr.HTML(value="Inference Job type filters")
            # with gr.Column(scale=2):
            # with gr.Row():
            # task_type_choices = ["txt2img", "img2img", "interrogate_clip", "interrogate_deepbooru"]
            # task_type_dropdown = gr.Dropdown(label="Task Type", choices=task_type_choices,
            # elem_id="task_type_ids_dropdown")
            # status_choices = ["succeed", "inprogress", "failed"]
            # status_dropdown = gr.Dropdown(label="Status", choices=status_choices,
            # elem_id="task_status_dropdown")
            # # with gr.Row():
            # # sagemaker_endpoint_filter = gr.Dropdown(api_manager.list_all_sagemaker_endpoints(),
            # # label="SageMaker Endpoint",
            # # elem_id="sagemaker_endpoint_dropdown")
            # # modules.ui.create_refresh_button(sagemaker_endpoint_filter, lambda: None,
            # # lambda: {"choices": api_manager.list_all_sagemaker_endpoints()},
            # # "refresh_sagemaker_endpoints")
            #
            # with gr.Row():
            # sd_checkpoint_filter = gr.Dropdown(label="Checkpoint", choices=sorted(update_sd_checkpoints()),
            # elem_id="stable_diffusion_checkpoint_dropdown")
            # modules.ui.create_refresh_button(sd_checkpoint_filter, update_sd_checkpoints,
            # lambda: {"choices": sorted(update_sd_checkpoints())},
            # "refresh_sd_checkpoints")
            # if is_img2img:
            # with gr.Row():
            # start_time_picker_img = gr.HTML(elem_id="start_timepicker_img_e",
            # value="""
            #
            # Start Time
            #
            #
            # """)
            # end_time_picker_img = gr.HTML(elem_id="end_timepicker_img_e",
            # value="""
            #
            # End Time
            #
            #
            # """)
            # start_time_picker_img_hidden = gr.Button(elem_id="start_time_picker_img_hidden",
            # visible=True)
            # end_time_picker_img_hidden = gr.Button(elem_id="end_time_picker_img_hidden",
            # visible=True)
            # start_time_picker_img_hidden.click(fn=on_img_time_change,
            # _js='get_time_img_value',
            # inputs=[start_time_picker_img, end_time_picker_img],
            # outputs=inference_job_dropdown)
            # end_time_picker_img_hidden.click(fn=on_img_time_change,
            # _js='get_time_img_value',
            # inputs=[start_time_picker_img, end_time_picker_img],
            # outputs=inference_job_dropdown
            # )
            # task_type_dropdown.change(fn=query_img_inference_job_list,
            # inputs=[task_type_dropdown, status_dropdown,
            # sagemaker_endpoint_filter,
            # sd_checkpoint_filter], outputs=inference_job_dropdown)
            # status_dropdown.change(fn=query_img_inference_job_list,
            # inputs=[task_type_dropdown, status_dropdown, sagemaker_endpoint_filter,
            # sd_checkpoint_filter], outputs=inference_job_dropdown)
            #
            # sagemaker_endpoint_filter.change(fn=query_img_inference_job_list,
            # inputs=[task_type_dropdown, status_dropdown,
            # sagemaker_endpoint_filter,
            # sd_checkpoint_filter], outputs=inference_job_dropdown)
            #
            # sd_checkpoint_filter.change(fn=query_img_inference_job_list,
            # inputs=[task_type_dropdown, status_dropdown,
            # sagemaker_endpoint_filter,
            # sd_checkpoint_filter], outputs=inference_job_dropdown)
            # else:
            # with gr.Row():
            # start_time_picker_text = gr.HTML(elem_id="start_timepicker_text_e",
            # value="""
            #
            # Start Time
            #
            #
            # """)
            # end_time_picker_text = gr.HTML(elem_id="end_timepicker_text_e",
            # value="""
            #
            # End Time
            #
            #
            # """)
            # start_time_picker_button_hidden = gr.Button(elem_id="start_time_picker_button_hidden",
            # visible=False)
            # end_time_picker_button_hidden = gr.Button(elem_id="end_time_picker_button_hidden",
            # visible=False)
            # start_time_picker_button_hidden.click(fn=on_txt_time_change,
            # _js='get_time_button_value',
            # inputs=[start_time_picker_text, end_time_picker_text],
            # outputs=inference_job_dropdown)
            # end_time_picker_button_hidden.click(fn=on_txt_time_change,
            # _js='get_time_button_value',
            # inputs=[start_time_picker_text, end_time_picker_text],
            # outputs=inference_job_dropdown)
            # task_type_dropdown.change(fn=query_txt_inference_job_list,
            # inputs=[task_type_dropdown, status_dropdown,
            # sagemaker_endpoint_filter, sd_checkpoint_filter],
            # outputs=inference_job_dropdown)
            # status_dropdown.change(fn=query_txt_inference_job_list,
            # inputs=[task_type_dropdown, status_dropdown,
            # sagemaker_endpoint_filter,
            # sd_checkpoint_filter], outputs=inference_job_dropdown)
            # sagemaker_endpoint_filter.change(fn=query_txt_inference_job_list,
            # inputs=[task_type_dropdown, status_dropdown,
            # sagemaker_endpoint_filter,
            # sd_checkpoint_filter], outputs=inference_job_dropdown)
            # sd_checkpoint_filter.change(fn=query_txt_inference_job_list,
            # inputs=[task_type_dropdown, status_dropdown,
            # sagemaker_endpoint_filter, sd_checkpoint_filter],
            # outputs=inference_job_dropdown)
            #
            # def toggle_new_rows(create_from):
            # global start_time_picker_txt_value
            # start_time_picker_txt_value = None
            # global end_time_picker_txt_value
            # end_time_picker_txt_value = None
            # return [gr.update(visible=create_from), None, None, None, None]
            #
            # inference_job_filter.change(
            # fn=toggle_new_rows,
            # inputs=[inference_job_filter],
            # outputs=[filter_row, task_type_dropdown, status_dropdown, sagemaker_endpoint_filter,
            # sd_checkpoint_filter],
            # )
            # hidden_check_type = gr.Textbox(elem_id="hidden_check_type", value=is_img2img, visible=False)
            # inference_job_page.change(fn=query_page_inference_job_list,
            # inputs=[task_type_dropdown, status_dropdown,
            # sagemaker_endpoint_filter,
            # sd_checkpoint_filter, hidden_check_type,
            # inference_job_page],
            #
            # outputs=inference_job_dropdown)
        def setup_inference_for_plugin(pr: gr.Request):
            """Load the requesting user's cloud resources for the panel.

            ``pr.username`` is passed to every loader as both the username
            and the token. Returns, in the order of the ``outputs`` of the
            load event below: the model-list dict for the state, and updates
            for the checkpoint, inference job and VAE dropdowns.
            """
            models_on_cloud = load_model_list(pr.username, pr.username)
            vae_model_on_cloud = load_vae_list(pr.username, pr.username)
            lora_models_on_cloud = load_lora_models(username=pr.username, user_token=pr.username)
            hypernetworks_models_on_cloud = load_hypernetworks_models(pr.username, pr.username)
            controlnet_list = load_controlnet_list(pr.username, pr.username)
            controlnet_xyz_list = load_xyz_controlnet_list(pr.username, pr.username)
            inference_jobs = load_inference_job_list(inference_task_type, pr.username, pr.username)
            # Despite its name, this dict carries every model list, not only
            # Lora and hypernetwork models.
            lora_hypernets = {
                'lora': lora_models_on_cloud,
                'hypernet': hypernetworks_models_on_cloud,
                'controlnet': controlnet_list,
                'controlnet_xyz': controlnet_xyz_list,
                'vae': vae_model_on_cloud,
                'sd': models_on_cloud,
            }
            # The checkpoint dropdown selects the first loaded entry, or the
            # "don't use" option when nothing was loaded.
            return lora_hypernets, \
                gr.update(choices=models_on_cloud, value=models_on_cloud[0] if models_on_cloud and len(models_on_cloud) > 0 else None_Option_For_On_Cloud_Model), \
                gr.update(choices=inference_jobs), \
                gr.update(choices=vae_model_on_cloud)
        # Populate the state and dropdowns through the Blocks' load event.
        sagemaker_inference_tab.load(fn=setup_inference_for_plugin, inputs=[],
                                     outputs=[
                                         lora_and_hypernet_models_state,
                                         sd_model_on_cloud_dropdown,
                                         inference_job_dropdown,
                                         sd_vae_on_cloud_dropdown
                                     ])
        with gr.Group():
            # The merge accordion is created with visible=False.
            with gr.Accordion("Open for Checkpoint Merge in the Cloud!", visible=False, open=False):
                # Rebinds sagemaker_html_log; the elem_id is the same as the
                # log element created at the top of the Blocks.
                sagemaker_html_log = gr.HTML(elem_id=f'html_log_sagemaker')
                with FormRow(elem_id="modelmerger_models_in_the_cloud"):
                    global primary_model_name
                    primary_model_name = gr.Dropdown(elem_id="modelmerger_primary_model_name_in_the_cloud",
                                                     label="Primary model (A) in the cloud")
                    create_refresh_button_by_user(primary_model_name, lambda *args: None,
                                                  lambda username: {"choices": sorted(update_sd_checkpoints(username))},
                                                  "refresh_checkpoint_A_in_the_cloud")
                    global secondary_model_name
                    secondary_model_name = gr.Dropdown(elem_id="modelmerger_secondary_model_name_in_the_cloud",
                                                       label="Secondary model (B) in the cloud")
                    create_refresh_button_by_user(secondary_model_name, lambda *args: None,
                                                  lambda username: {"choices": sorted(update_sd_checkpoints(username))},
                                                  "refresh_checkpoint_B_in_the_cloud")
                    global tertiary_model_name
                    tertiary_model_name = gr.Dropdown(elem_id="modelmerger_tertiary_model_name_in_the_cloud",
                                                      label="Tertiary model (C) in the cloud")
                    create_refresh_button_by_user(tertiary_model_name, lambda *args: None,
                                                  lambda username: {"choices": sorted(update_sd_checkpoints(username))},
                                                  "refresh_checkpoint_C_in_the_cloud")
                with gr.Row():
                    global modelmerger_merge_on_cloud
                    modelmerger_merge_on_cloud = gr.Button(elem_id="modelmerger_merge_in_the_cloud", value="Merge on Cloud",
                                                           variant='primary')
    return sd_model_on_cloud_dropdown, sd_vae_on_cloud_dropdown, inference_job_dropdown, primary_model_name, secondary_model_name, tertiary_model_name, modelmerger_merge_on_cloud, lora_and_hypernet_models_state