import logging import os import html import boto3 import time import json import gradio import requests import gradio as gr from aws_extension.cloud_api_manager.api_logger import ApiLogger from aws_extension.constant import MODEL_TYPE import utils from aws_extension.auth_service.simple_cloud_auth import cloud_auth_manager from aws_extension.cloud_api_manager.api_manager import api_manager from aws_extension.cloud_api_manager.api import api from aws_extension.sagemaker_ui_utils import create_refresh_button_by_user from modules.shared import opts from modules.ui_components import FormRow from utils import get_variable_from_json, upload_multipart_files_to_s3_by_signed_url, has_config from requests.exceptions import JSONDecodeError from datetime import datetime import math import re from modules.ui_components import ToolButton import asyncio import nest_asyncio from utils import cp, tar, rm logger = logging.getLogger(__name__) logger.setLevel(utils.LOGGING_LEVEL) None_Option_For_On_Cloud_Model = "don't use on cloud inference" None_Option_For_Infer_Job = "No Selected" inference_job_dropdown = None embedding_dropdown = None hypernet_dropdown = None lora_dropdown = None # sagemaker_endpoint = None modelmerger_merge_on_cloud = None interrogate_clip_on_cloud_button = None interrogate_deep_booru_on_cloud_button = None primary_model_name = None secondary_model_name = None tertiary_model_name = None # TODO: convert to dynamically init the following variables txt2img_inference_job_ids = [] sd_checkpoints = [] textual_inversion_list = [] lora_list = [] hyperNetwork_list = [] ControlNet_model_list = [] # Initial checkpoints information checkpoint_info = {} checkpoint_type = ["Stable-diffusion", "embeddings", "Lora", "hypernetworks", "ControlNet", "VAE"] checkpoint_name = ["stable_diffusion", "embeddings", "lora", "hypernetworks", "controlnet", "VAE"] stable_diffusion_list = [] embeddings_list = [] hypernetworks_list = [] controlnet_list = [] for ckpt_type, ckpt_name in zip(checkpoint_type, checkpoint_name): checkpoint_info[ckpt_type] = {} # get api_gateway_url api_gateway_url = get_variable_from_json('api_gateway_url') api_key = get_variable_from_json('api_token') start_time_picker_img_value = None end_time_picker_img_value = None start_time_picker_txt_value = None end_time_picker_txt_value = None txt_task_type = None txt_status = None txt_endpoint = None txt_checkpoint = None img_task_type = None img_status = None img_endpoint = None img_checkpoint = None show_all_inference_job = False modelTypeMap = { 'SD Checkpoints': 'Stable-diffusion', 'Textual Inversion': 'embeddings', 'LoRA model': 'Lora', 'ControlNet model': 'ControlNet', 'Hypernetwork': 'hypernetworks', 'VAE': 'VAE' } def plaintext_to_html(text): text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" return text def server_request_get(path, params): api_gateway_url = get_variable_from_json('api_gateway_url') # Check if api_url ends with '/', if not append it if not api_gateway_url.endswith('/'): api_gateway_url += '/' api_key = get_variable_from_json('api_token') headers = { "x-api-key": api_key, "Content-Type": "application/json" } list_endpoint_url = f'{api_gateway_url}{path}' response = requests.get(list_endpoint_url, params=params, headers=headers) return response def get_s3_file_names(bucket, folder): """Get a list of file names from an S3 bucket and folder.""" s3 = boto3.resource('s3') bucket = s3.Bucket(bucket) objects = bucket.objects.filter(Prefix=folder) names = [obj.key for obj in objects] return names def get_current_date(): today = datetime.today() formatted_date = today.strftime('%Y-%m-%d') return formatted_date def server_request(path): api_gateway_url = get_variable_from_json('api_gateway_url') if not has_config(): return [] # Check if api_url ends with '/', if not append it if not api_gateway_url.endswith('/'): api_gateway_url += '/' api_key = get_variable_from_json('api_token') headers = { "x-api-key": api_key, "Content-Type": "application/json" } list_endpoint_url = f'{api_gateway_url}{path}' response = requests.get(list_endpoint_url, headers=headers) return response def datetime_to_short_form(datetime_str): dt = get_strptime(datetime_str) short_form = dt.strftime("%Y-%m-%d-%H-%M-%S") return short_form def query_page_inference_job_list(task_type: str, status: str, endpoint: str, checkpoint: str, is_img2img: bool, show_all: bool): global show_all_inference_job show_all_inference_job = show_all if is_img2img: global img_task_type img_task_type = task_type global img_status img_status = status global img_endpoint img_endpoint = endpoint global img_checkpoint img_checkpoint = checkpoint opt_type = 'img2img' logger.debug(f"{opt_type} {img_task_type} {img_status} {img_endpoint} {img_checkpoint} {show_all}") return query_inference_job_list(task_type, status, endpoint, checkpoint, opt_type) else: global txt_task_type txt_task_type = task_type global txt_status txt_status = txt_status global txt_endpoint txt_endpoint = endpoint global txt_checkpoint txt_checkpoint = checkpoint opt_type = 'txt2img' logger.debug(f"{opt_type} {txt_task_type} {txt_status} {txt_endpoint} {txt_checkpoint} {show_all}") return query_inference_job_list(task_type, status, endpoint, checkpoint, opt_type) def query_img_inference_job_list(task_type: str, status: str, endpoint: str, checkpoint: str): opt_type = 'img2img' global img_task_type img_task_type = task_type global img_status img_status = status global img_endpoint img_endpoint = endpoint global img_checkpoint img_checkpoint = checkpoint return query_inference_job_list(task_type, status, endpoint, checkpoint, opt_type) def query_txt_inference_job_list(task_type: str, status: str, endpoint: str, checkpoint: str): opt_type = 'txt2img' global txt_task_type txt_task_type = task_type global txt_status txt_status = status global txt_endpoint txt_endpoint = endpoint global txt_checkpoint txt_checkpoint = checkpoint return query_inference_job_list(task_type, status, endpoint, checkpoint, opt_type) def query_inference_job_list(task_type: str = '', status: str = '', endpoint: str = '', checkpoint: str = '', opt_type: str = ''): logger.debug( f"query_inference_job_list start!!{status},{task_type},{endpoint},{checkpoint},{start_time_picker_txt_value},{end_time_picker_txt_value} {show_all_inference_job}") try: body_params = {} if status: body_params['status'] = status if task_type: body_params['task_type'] = task_type if opt_type == 'txt2img': if start_time_picker_txt_value: body_params['start_time'] = start_time_picker_txt_value if end_time_picker_txt_value: body_params['end_time'] = end_time_picker_txt_value elif opt_type == 'img2img': if start_time_picker_img_value: body_params['start_time'] = start_time_picker_img_value if end_time_picker_img_value: body_params['end_time'] = end_time_picker_img_value if endpoint: endpoint_name_array = endpoint.split("+") if len(endpoint_name_array) > 0: body_params['endpoint'] = endpoint_name_array[0] if checkpoint: body_params['checkpoint'] = checkpoint body_params['limit'] = -1 if show_all_inference_job else 10 response = server_request_get(f'inferences', body_params) r = response.json()['data']['inferences'] logger.debug(r) if r: txt2img_inference_job_ids.clear() # Clear the existing list before appending new values temp_list = [] for obj in r: if obj.get('completeTime') is None: complete_time = obj.get('startTime') else: complete_time = obj.get('completeTime') status = obj.get('status') task_type = obj.get('taskType', 'txt2img') inference_job_id = obj.get('InferenceJobId') combined_string = f"{complete_time}-->{task_type}-->{status}-->{inference_job_id}" temp_list.append((complete_time, combined_string)) # Sort the list based on completeTime in ascending order sorted_list = sorted(temp_list, key=lambda x: x[0], reverse=False) # Append the sorted combined strings to the txt2img_inference_job_ids list for item in sorted_list: txt2img_inference_job_ids.append(item[1]) # inference_job_dropdown.update(choices=txt2img_inference_job_ids) return gr.Dropdown.update(choices=txt2img_inference_job_ids) else: logger.info("The API response is empty.") return gr.Dropdown.update(choices=[]) except Exception as e: logger.error("Exception occurred when fetching inference_job_ids") logger.error(e) return gr.Dropdown.update(choices=[]) def get_inference_job(inference_job_id): url = f'inferences/{inference_job_id}' response = server_request(url) logger.debug(f"get_inference_job response {response}") infer_id = "" if 'data' in response.json(): infer_id = response.json()['data']['InferenceJobId'] api_logger = ApiLogger( action='inference', append=True, infer_id=infer_id ) headers = { "x-api-key": get_variable_from_json('api_token'), "Content-Type": "application/json" } api_gateway_url = get_variable_from_json('api_gateway_url') api_logger.req_log(sub_action="GetInferenceJob", method='GET', path=f"{api_gateway_url}{url}", headers=headers, response=response, desc=f"Get inference job detail from cloud by ID ({inference_job_id}), " f"end request if data.status == succeed, " f"ID from previous step: CreateInference -> data -> inference -> id") if 'data' not in response.json(): raise Exception(response.json()) return response.json()['data'] def download_images(image_urls: list, local_directory: str): if not os.path.exists(local_directory): os.makedirs(local_directory) image_list = [] for url in image_urls: try: response = requests.get(url) response.raise_for_status() image_name = os.path.basename(url).split('?')[0] local_path = os.path.join(local_directory, image_name) with open(local_path, 'wb') as f: f.write(response.content) image_list.append(local_path) except requests.exceptions.RequestException as e: logger.error(f"Error downloading image {url}: {e}") return image_list def download_images_to_json(image_urls: list): results = [] for url in image_urls: try: response = requests.get(url) response.raise_for_status() json_resp = response.json() results.append(json_resp['info']) except requests.exceptions.RequestException as e: logger.error(f"Error downloading image {url}: {e}") return results def download_images_to_pil(image_urls: list): image_list = [] for url in image_urls: try: response = requests.get(url) response.raise_for_status() from PIL import Image import io pil_image = Image.open(io.BytesIO(response.content)) image_list.append(pil_image) except requests.exceptions.RequestException as e: logger.error(f"Error downloading image {url}: {e}") return image_list def get_model_list_by_type(model_type, username=""): api_gateway_url = get_variable_from_json('api_gateway_url') api_key = get_variable_from_json('api_token') # check if api_gateway_url and api_key are set if not has_config(): logger.info("api_gateway_url or api_key is not set") return [] # Check if api_url ends with '/', if not append it if not api_gateway_url.endswith('/'): api_gateway_url += '/' url = api_gateway_url + f"checkpoints?status=Active" if isinstance(model_type, list): url += "&types=" + "&types=".join(model_type) else: url += f"&types={model_type}" try: response = requests.get(url=url, headers={ 'x-api-key': api_key, 'username': username }) response.raise_for_status() json_response = response.json() if "checkpoints" not in json_response.keys(): return [] checkpoint_list = [] for ckpt in json_response["checkpoints"]: if "name" not in ckpt: continue if ckpt["name"] is None: continue if ckpt["type"] is None: continue ckpt_type = ckpt["type"] for ckpt_name in ckpt["name"]: ckpt_s3_pos = f"{ckpt['s3Location']}/{ckpt_name}" checkpoint_info[ckpt_type][ckpt_name] = ckpt_s3_pos checkpoint_list.append(ckpt_name) unique_list = list(set(checkpoint_list)) return unique_list except Exception as e: logger.error(f"Error fetching model list: {e}") return [] def get_checkpoints_by_type(model_type): url = "checkpoints?status=Active" if isinstance(model_type, list): url += "&types=" + "&types=".join(model_type) else: url += f"&types={model_type}" try: response = server_request(url) response.raise_for_status() json_response = response.json() if "checkpoints" not in json_response.keys(): return [] checkpoint_dict = {} for ckpt in json_response["checkpoints"]: if "name" not in ckpt: continue if ckpt["name"] is None: continue create_time = ckpt['created'] created = datetime.fromtimestamp(create_time) for ckpt_name in ckpt["name"]: checkpoint = [ckpt_name, created] if ckpt_name not in checkpoint_dict: checkpoint_dict[ckpt_name] = checkpoint checkpoint_list = list(checkpoint_dict.values()) return checkpoint_list except Exception as e: logging.error(f"Error fetching checkpoints list: {e}") return [] def update_sd_checkpoints(username): model_type = ["Stable-diffusion"] return get_model_list_by_type(model_type) def get_texual_inversion_list(): model_type = "embeddings" return get_model_list_by_type(model_type) def get_lora_list(): model_type = "Lora" return get_model_list_by_type(model_type) def get_hypernetwork_list(): model_type = "hypernetworks" return get_model_list_by_type(model_type) def get_controlnet_model_list(): model_type = "ControlNet" return get_model_list_by_type(model_type) def refresh_all_models(username): api_gateway_url = get_variable_from_json('api_gateway_url') api_key = get_variable_from_json('api_token') if not has_config(): return [] try: for rp, name in zip(checkpoint_type, checkpoint_name): url = api_gateway_url + f"checkpoints?status=Active&types={rp}" response = requests.get(url=url, headers={ 'x-api-key': api_key, 'username': username, }) json_response = response.json() logger.debug(f"response url json for model {rp} is {json_response}") checkpoint_info[rp] = {} if "checkpoints" not in json_response.keys(): continue for ckpt in json_response["checkpoints"]: if "name" not in ckpt: continue if ckpt["name"] is None: continue ckpt_type = ckpt["type"] checkpoint_info[ckpt_type] = {} for ckpt_name in ckpt["name"]: ckpt_s3_pos = f"{ckpt['s3Location']}/{ckpt_name.split(os.sep)[-1]}" checkpoint_info[ckpt_type][ckpt_name] = ckpt_s3_pos except Exception as e: logger.error(f"Error refresh all models: {e}") def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path, controlnet_model_path, vae_path, pr: gradio.Request): if not has_config(): return "Please config api url and token", None, None, None, None, None, None log = "start upload model to s3:" local_paths = [sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path, controlnet_model_path, vae_path] # check parameters params_empty = True for local_path in local_paths: if local_path: params_empty = False break if params_empty: return "Please choose at least one model to upload.", None, None, None, None, None, None logger.info(f"Refresh checkpoints before upload to get rid of duplicate uploads...") refresh_all_models(pr.username) for lp, rp in zip(local_paths, checkpoint_type): if lp == "" or not lp: continue logger.debug(f"lp is {lp}") model_name = lp.split(os.sep)[-1] exist_model_list = list(checkpoint_info[rp].keys()) if model_name in exist_model_list: logger.info(f"!!!skip to upload duplicate model {model_name}") continue part_size = 1000 * 1024 * 1024 file_size = os.stat(lp) parts_number = math.ceil(file_size.st_size / part_size) logger.info(f'!!!!!!!!!!{file_size} {parts_number}') # local_tar_path = f'{model_name}.tar' local_tar_path = model_name payload = { "checkpoint_type": rp, "filenames": [{ "filename": local_tar_path, "parts_number": parts_number }], "params": { "message": "placeholder for chkpts upload test", "creator": pr.username } } api_gateway_url = get_variable_from_json('api_gateway_url') # Check if api_url ends with '/', if not append it if not api_gateway_url.endswith('/'): api_gateway_url += '/' api_key = get_variable_from_json('api_token') logger.info(f'!!!!!!api_gateway_url {api_gateway_url}') url = str(api_gateway_url) + "checkpoints" logger.debug(f"Post request for upload s3 presign url: {url}") response = requests.post(url=url, json=payload, headers={'x-api-key': api_key, "username": pr.username}) if response.status_code not in [201, 202]: logger.error(f"create_checkpoint: {response.json()}") return response.json()['message'], None, None, None, None, None, None try: json_response = response.json()['data'] logger.debug(f"Response json {json_response}") s3_base = json_response["checkpoint"]["s3_location"] checkpoint_id = json_response["checkpoint"]["id"] logger.debug(f"Upload to S3 {s3_base}") logger.debug(f"Checkpoint ID: {checkpoint_id}") s3_signed_urls_resp = json_response["s3PresignUrl"][local_tar_path] # Upload src model to S3. if rp != "embeddings": local_model_path_in_repo = os.sep.join(['models', rp, model_name]) else: local_model_path_in_repo = os.sep.join([rp, model_name]) logger.debug("Pack the model file.") cp(lp, local_model_path_in_repo, recursive=True) if rp == "Stable-diffusion": model_yaml_name = model_name.split('.')[0] + ".yaml" local_model_yaml_path = os.sep.join([*lp.split(os.sep)[:-1], model_yaml_name]) local_model_yaml_path_in_repo = os.sep.join(["models", rp, model_yaml_name]) if os.path.isfile(local_model_yaml_path): cp(local_model_yaml_path, local_model_yaml_path_in_repo, recursive=True) tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo, local_model_yaml_path_in_repo], verbose=True) else: tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo], verbose=True) else: tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo], verbose=True) multiparts_tags = upload_multipart_files_to_s3_by_signed_url( local_tar_path, s3_signed_urls_resp, part_size ) logger.debug(f"multiparts_tags {multiparts_tags}") payload = { "status": "Active", "multi_parts_tags": {local_tar_path: multiparts_tags} } # Start creating model on cloud. response = requests.put(url=f"{url}/{checkpoint_id}", json=payload, headers={'x-api-key': api_key}) logger.debug(response) log = f"finish upload {local_tar_path} to {s3_base}" # os.system(f"rm {local_tar_path}") rm(local_tar_path, recursive=True) except Exception as e: logger.error(f"fail to upload model {lp}, error: {e}") refresh_all_models(pr.username) return str(e), None, None, None, None, None, None logger.debug(f"Refresh checkpoints after upload...") refresh_all_models(pr.username) return log, None, None, None, None, None, None def sagemaker_upload_model_s3_local(): log = "Start upload:" return log def check_url(url: str): url = url.replace('\n', '') return url.strip() def sagemaker_upload_model_s3_url(model_type: str, url_list: str, description: str, pr: gradio.Request): if not url_list: return "Please fill the url." model_type = modelTypeMap.get(model_type) if not model_type: return "Please choose the model type." if description: params_dict = { 'message': description } else: params_dict = {} params_dict['creator'] = pr.username url_list = url_list.split(',') modified_urls = [check_url(url) for url in url_list] unique_urls = list(set(modified_urls)) for url in unique_urls: url_pattern = r'(https?|ftp)://[^\s/$.?#].[^\s]*' if not re.match(f'^{url_pattern}$', url): return f"{url} is not a valid url." data = {'checkpoint_type': model_type, 'urls': unique_urls, 'params': params_dict} api.set_username(pr.username) response = api.create_checkpoint(data=data) return response.json()['message'] def generate_on_cloud(sagemaker_endpoint): logger.info(f"checkpoint_info {checkpoint_info}") logger.info(f"sagemaker endpoint {sagemaker_endpoint}") text = "failed to check endpoint" return plaintext_to_html(text) # create a global event loop and apply the patch to allow nested event loops in single thread loop = asyncio.get_event_loop() nest_asyncio.apply() MAX_RUNNING_LIMIT = 10 def async_loop_wrapper(f): global loop # check if there are any running or queued tasks inside the event loop if loop.is_running(): # Calculate the number of running tasks while len([task for task in asyncio.all_tasks(loop) if not task.done()]) > MAX_RUNNING_LIMIT: logger.debug(f'Waiting for {MAX_RUNNING_LIMIT} running tasks to complete') time.sleep(1) else: # check if loop is closed and create a new one if loop.is_closed(): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) # log this event since it should never happen logger.debug('Event loop was closed, created a new one') # Add new task to the event loop result = loop.run_until_complete(f()) return result def process_result_by_inference_id(inference_id_or_data, endpoint_type): print(f"process_result_by_inference_id {inference_id_or_data} {endpoint_type}") if endpoint_type == 'Async': resp = get_inference_job(inference_id_or_data) inference_id = inference_id_or_data else: resp = inference_id_or_data inference_id = resp['InferenceJobId'] image_list = [] # Return an empty list if selected_value is None info_text = '' infotexts = f"Inference id is {inference_id}, please check all historical inference result in 'Inference Job' dropdown list" json_list = [] prompt_txt = '' if resp is None: logger.info(f"get_inference_job resp is null") return image_list, info_text, plaintext_to_html(infotexts), infotexts else: logger.debug(f"get_inference_job resp is {resp}") if 'taskType' not in resp: raise Exception(resp) if resp['taskType'] in ['txt2img', 'img2img', 'interrogate_clip', 'interrogate_deepbooru']: while resp and resp['status'] == "inprogress": time.sleep(1) resp = get_inference_job(inference_id) if resp is None: logger.info(f"get_inference_job resp is null.") return image_list, info_text, plaintext_to_html(infotexts), infotexts if resp['status'] == "failed": if 'sagemakerRaw' in resp: infotexts = f"Inference job {inference_id} is failed, error message: {resp['sagemakerRaw']}" return image_list, info_text, plaintext_to_html(infotexts), infotexts elif resp['status'] == "succeed": if resp['taskType'] in ['interrogate_clip', 'interrogate_deepbooru']: prompt_txt = resp['caption'] # return with default value, including image_list, info_text, infotexts return image_list, info_text, plaintext_to_html(infotexts), prompt_txt images = resp['img_presigned_urls'] inference_param_json_list = resp['output_presigned_urls'] # todo: these not need anymore if resp['taskType'] in ['txt2img', 'img2img']: image_list = download_images_to_pil(images) json_file = download_images_to_json(inference_param_json_list)[0] if json_file: info_text = json_file infotexts = f"Inference id is {inference_id}\n{get_infer_job_time(resp)}" + \ json.loads(info_text)["infotexts"][0] else: logger.debug(f"File {json_file} does not exist.") info_text = 'something wrong when trying to download the inference parameters' infotexts = info_text return image_list, info_text, plaintext_to_html(infotexts), infotexts else: logger.debug(f"inference job status is {resp['status']}") return image_list, info_text, plaintext_to_html(infotexts), infotexts else: return image_list, info_text, plaintext_to_html(infotexts), infotexts def update_prompt_with_embedding(selected_items, prompt, lora_and_hypernet_models_state): if MODEL_TYPE.EMBEDDING.value in lora_and_hypernet_models_state: return update_prompt_with_selected_model( selected_items, prompt, MODEL_TYPE.EMBEDDING, lora_and_hypernet_models_state[MODEL_TYPE.EMBEDDING.value] ) return update_prompt_with_selected_model(selected_items, prompt, MODEL_TYPE.EMBEDDING) def update_prompt_with_hypernetwork(selected_items, prompt): return update_prompt_with_selected_model(selected_items, prompt, MODEL_TYPE.HYPER_NETWORK) def update_prompt_with_lora(selected_items, prompt): return update_prompt_with_selected_model(selected_items, prompt, MODEL_TYPE.LORA) def update_prompt_with_selected_model(selected_value, original_prompt, type, state_value = None): """Update txt2img or img2img prompt with selected model name Args: selected_value (gr.Dropdown): the selected dropdown original_prompt (gr.Textbox): the original prompt before updating type: the model type, embedding|lora|hypernetwork Returns: gr.Textbox: The updated prompt """ def _remove_embedding_prompt(state_value, selected_value, prompt_txt): if state_value: for embedding in state_value: if embedding not in selected_value: prompt_txt = prompt_txt.replace(embedding.split(".")[0], "") return prompt_txt def _remove_prompt_by_regex(pattern, prompt_txt): matches = re.findall(pattern, prompt_txt) for match in matches: if match not in existed_item: prompt_txt = prompt_txt.replace(match, "") return prompt_txt logger.info(f"Selected value is {selected_value}, \ original prompt is {original_prompt}, \ type is {type}") prompt_txt = original_prompt existed_item = [] # Compose prompt for Embedding/Lora/Hypernetwork for item in selected_value: model_name = item.split(".")[0] if MODEL_TYPE.LORA == type: model_prompt = f"" elif MODEL_TYPE.HYPER_NETWORK == type: model_prompt = f"" elif MODEL_TYPE.EMBEDDING == type: model_prompt = model_name else: logger.warning(f"The type {type} is not supported, skip it") continue existed_item.append(model_prompt) if model_prompt not in original_prompt: if 0 == len(original_prompt.strip()): prompt_txt = model_prompt else: prompt_txt += f" {model_prompt}" # Remove Embedding/Lora/Hypernetwork string which is not selected pattern = "" if MODEL_TYPE.LORA == type: pattern = r"]*:1>" prompt_txt = _remove_prompt_by_regex(pattern, prompt_txt) elif MODEL_TYPE.HYPER_NETWORK == type: pattern = r"]*:1>" prompt_txt = _remove_prompt_by_regex(pattern, prompt_txt) elif MODEL_TYPE.EMBEDDING == type: prompt_txt = _remove_embedding_prompt(state_value, selected_value, prompt_txt) else: logger.warning(f"The type {type} is not supported, skip it") return prompt_txt return prompt_txt def fake_gan(selected_value, original_prompt): logger.debug(f"selected value is {selected_value}") logger.debug(f"original prompt is {original_prompt}") if selected_value and selected_value != None_Option_For_On_Cloud_Model: delimiter = "-->" parts = selected_value.split(delimiter) # Extract the InferenceJobId value inference_job_id = parts[3].strip() inference_job_status = parts[2].strip() inference_job_taskType = parts[1].strip() if inference_job_status == 'failed': job = get_inference_job(inference_job_id) return [], [], plaintext_to_html(f"inference is failed: {job['sagemakerRaw']}"), original_prompt if inference_job_status != 'succeed': return [], [], plaintext_to_html(f'inference is {inference_job_status}'), original_prompt if inference_job_taskType in ["txt2img", "img2img"]: prompt_txt = original_prompt # output directory mapping to task type job = get_inference_job(inference_job_id) images = job['img_presigned_urls'] inference_param_json_list = job['output_presigned_urls'] image_list = download_images_to_pil(images) images_to_json = download_images_to_json(inference_param_json_list) # maybe param json was deleted if len(images_to_json) == 0: json_file = "" else: json_file = images_to_json[0] if json_file: info_text = json_file infotexts = f"Inference id is {inference_job_id}\n{get_infer_job_time(job)}" + \ json.loads(info_text)["infotexts"][0] else: logger.debug(f"File {json_file} does not exist.") info_text = 'something wrong when trying to download the inference parameters' infotexts = info_text elif inference_job_taskType in ["interrogate_clip", "interrogate_deepbooru"]: job_status = get_inference_job(inference_job_id) logger.debug(job_status) caption = job_status['caption'] prompt_txt = caption image_list = [] # Return an empty list if selected_value is None json_list = [] info_text = '' infotexts = '' else: prompt_txt = original_prompt image_list = [] # Return an empty list if selected_value is None json_list = [] info_text = '' infotexts = '' return image_list, info_text, plaintext_to_html(infotexts), prompt_txt def get_strptime(time_str): time_str = time_str.replace("T", " ") return datetime.strptime(time_str, '%Y-%m-%d %H:%M:%S.%f') def get_infer_job_time(job): string_array = [] inference_type = "" if 'inference_type' in job: inference_type = job['inference_type'] + " " if 'createTime' in job and 'completeTime' in job: complete_time = get_strptime(job['completeTime']) create_time = get_strptime(job['createTime']) duration = complete_time - create_time duration = round(duration.total_seconds(), 2) string = f"End-to-end API Duration: {duration} seconds" start_time = get_strptime(job['startTime']) duration = complete_time - start_time duration = round(duration.total_seconds(), 2) string_array.append(f"{string} (in which {inference_type}Inference: {duration} seconds)") else: if 'startTime' in job and 'completeTime' in job: complete_time = get_strptime(job['completeTime']) start_time = get_strptime(job['startTime']) duration = complete_time - start_time duration = round(duration.total_seconds(), 2) string_array.append(f"{inference_type}Inference Time: {duration} seconds") if 'params' in job: if 'sagemaker_inference_endpoint_name' in job['params']: endpoint_name = job['params']['sagemaker_inference_endpoint_name'] infer_ep_name = f"Endpoint: {endpoint_name}" if 'sagemaker_inference_instance_type' in job['params']: instance_type = job['params']['sagemaker_inference_instance_type'] infer_ep_name += f" ({instance_type})" string_array.append(infer_ep_name) if len(string_array) == 0: return "" return "\n".join(string_array) + "\n" def delete_inference_job(selected_value): logger.debug(f"selected value is {selected_value}") if selected_value and selected_value != None_Option_For_Infer_Job: if selected_value == 'cancelled': return delimiter = "-->" parts = selected_value.split(delimiter) # Extract the InferenceJobId value inference_job_id = parts[3].strip() resp = api.delete_inferences(data={ "inference_id_list": [inference_job_id], }) if resp.status_code != 204: gr.Error(f"Error deleting inference: {resp.json()['message']}") gr.Info(f"{inference_job_id} deleted successfully") file_path = f"{os.getcwd()}/outputs/{inference_job_id}.md" if os.path.exists(file_path): os.remove(file_path) file_path = f"{os.getcwd()}/outputs/{inference_job_id}.html" if os.path.exists(file_path): os.remove(file_path) else: gr.Warning('Please select a inference job to delete') def init_refresh_resource_list_from_cloud(username): logger.debug(f"start refreshing resource list from cloud") if get_variable_from_json('api_gateway_url') is not None: if not cloud_auth_manager.enableAuth: refresh_all_models(username) else: logger.debug('auth enabled, not preload load any model') else: logger.debug(f"there is no api-gateway url and token in local file,") def on_txt_time_change(start_time, end_time): logger.debug(f"!!!!!!!!!on_txt_time_change!!!!!!{start_time},{end_time}") global start_time_picker_txt_value global end_time_picker_txt_value start_time_picker_txt_value = start_time end_time_picker_txt_value = end_time return query_inference_job_list(txt_task_type, txt_status, txt_endpoint, txt_checkpoint, "txt2img") def on_img_time_change(start_time, end_time): logger.debug(f"!!!!!!!!!on_img_time_change!!!!!!{start_time},{end_time}") global start_time_picker_img_value global end_time_picker_img_value start_time_picker_img_value = start_time end_time_picker_img_value = end_time return query_inference_job_list(img_task_type, img_status, img_endpoint, img_checkpoint, "img2img") def load_inference_job_list(target_task_type, username, first_load="first"): inference_jobs = [None_Option_For_On_Cloud_Model] inferences_jobs_list = api_manager.list_all_inference_jobs_on_cloud(target_task_type, username, first_load) temp_list = [] for obj in inferences_jobs_list: if obj.get('createTime') is None: complete_time = obj.get('startTime') else: complete_time = obj.get('createTime') status = obj.get('status') task_type = obj.get('taskType', 'txt2img') inference_job_id = obj.get('InferenceJobId') # Compatible with lower versions of APIs without type filters if target_task_type == task_type: temp_list.append((complete_time, f"{complete_time}-->{task_type}-->{status}-->{inference_job_id}")) # Append the sorted combined strings to the txt2img_inference_job_ids list for item in temp_list: inference_jobs.append(item[1]) return inference_jobs def load_model_list(username): models_on_cloud = [] if 'sd_model_checkpoint' in opts.quicksettings_list: models_on_cloud = [None_Option_For_On_Cloud_Model] models_on_cloud += list(set([model['name'] for model in api_manager.list_models_on_cloud(username)])) return models_on_cloud def load_lora_models(username): data = list(set([model['name'] for model in api_manager.list_models_on_cloud(username, types='Lora')])) data.sort() return data def load_hypernetworks_models(username): return list(set([model['name'] for model in api_manager.list_models_on_cloud(username, types='hypernetworks')])) def load_vae_list(username): vae_model_on_cloud = ['Automatic', 'None'] vae_model_on_cloud += list(set([model['name'] for model in api_manager.list_models_on_cloud(username, types='VAE')])) return vae_model_on_cloud def load_controlnet_list(username): controlnet_model_on_cloud = ['None'] controlnet_model_on_cloud += list( set([model['name'] for model in api_manager.list_models_on_cloud(username, types='ControlNet')])) return controlnet_model_on_cloud def load_xyz_controlnet_list(username): controlnet_model_on_cloud = ['None'] controlnet_model_on_cloud += list( set([os.path.splitext(model['name'])[0] for model in api_manager.list_models_on_cloud(username, types='ControlNet')])) return controlnet_model_on_cloud def load_embeddings_list(username): embedding_model_on_cloud = list( set([model['name'] for model in api_manager.list_models_on_cloud(username, types='embeddings')])) return embedding_model_on_cloud def create_ui(is_img2img): global txt2img_gallery, txt2img_generation_info import modules.ui # init_refresh_resource_list_from_cloud() with gr.Blocks() as sagemaker_inference_tab: gr.HTML('

Amazon SageMaker Inference

') sagemaker_html_log = gr.HTML(elem_id=f'html_log_sagemaker') inference_task_type = 'txt2img' if not is_img2img else 'img2img' with gr.Column(): with gr.Row(): global lora_and_hypernet_models_state lora_and_hypernet_models_state = gr.State({}) sd_model_on_cloud_dropdown = gr.Dropdown(choices=[], value=None_Option_For_On_Cloud_Model, label='Stable Diffusion Checkpoint Used on Cloud') create_refresh_button_by_user(sd_model_on_cloud_dropdown, lambda *args: None, lambda username: { 'choices': load_model_list(username) }, 'refresh_cloud_model_down') infer_endpoint_dropdown = gr.Dropdown(choices=["Async", "Real-time"], value="Async", label='Inference Endpoint Type') with gr.Row(): with gr.Column(): with gr.Row(): sd_vae_on_cloud_dropdown = gr.Dropdown(choices=[], value='Automatic', label='SD Vae on Cloud') create_refresh_button_by_user(sd_vae_on_cloud_dropdown, lambda *args: None, lambda username: { 'choices': load_vae_list(username) }, 'refresh_cloud_vae_down') with gr.Column(): with gr.Row(): # Lora model global lora_dropdown lora_dropdown_local = gr.Dropdown(choices=[], label="Lora model on cloud", multiselect=True) create_refresh_button_by_user(lora_dropdown_local, lambda *args: None, lambda username: { 'choices': load_lora_models(username) }, 'refresh_lora_dropdown') lora_dropdown = lora_dropdown_local with gr.Row(): with gr.Column(): with gr.Row(): # Embedding model global embedding_dropdown embedding_dropdown_local = gr.Dropdown(choices=[], label="Embedding on cloud", multiselect=True) create_refresh_button_by_user(embedding_dropdown_local, lambda *args: None, lambda username: { 'choices': load_embeddings_list(username) }, 'refresh_embedding_dropdown') embedding_dropdown = embedding_dropdown_local with gr.Column(): with gr.Row(): # Hypernetwork model global hypernet_dropdown hypernet_dropdown_local = gr.Dropdown(choices=[], label="Hypernetwork on cloud", multiselect=True) create_refresh_button_by_user(hypernet_dropdown_local, lambda *args: None, lambda username: { 'choices': load_hypernetworks_models(username) }, 'refresh_hypernet_dropdown') hypernet_dropdown = hypernet_dropdown_local with gr.Row(visible=is_img2img): gr.HTML('
') with gr.Row(): global inference_job_dropdown inference_job_dropdown = gr.Dropdown(choices=[], value=None_Option_For_Infer_Job, label="Inference Job Histories: Time-Type-Status-UUID") create_refresh_button_by_user(inference_job_dropdown, lambda *args: None, lambda username: { 'choices': load_inference_job_list(inference_task_type, username, "first") }, 'refresh_inference_job_down') create_refresh_button_by_user(inference_job_dropdown, lambda *args: None, lambda username: { 'choices': load_inference_job_list(inference_task_type, username, "previous") }, 'refresh_inference_job_down_previous', '←') create_refresh_button_by_user(inference_job_dropdown, lambda *args: None, lambda username: { 'choices': load_inference_job_list(inference_task_type, username, "next") }, 'refresh_inference_job_down_next', '→') delete_inference_job_button = ToolButton(value='\U0001F5D1', elem_id="delete_inference_job") delete_inference_job_button.click( _js="delete_inference_job_confirm", fn=delete_inference_job, inputs=[inference_job_dropdown], outputs=[] ) api_inference_job_button = ToolButton(value='API', elem_id="api_inference_job") api_inference_job_cwd = ToolButton(value=os.getcwd(), elem_id="api_inference_job_path", visible=False) api_inference_job_button.click( _js="download_inference_job_api_call", fn=None, inputs=[api_inference_job_cwd, inference_job_dropdown], outputs=[] ) with gr.Row(): def setup_inference_for_plugin(pr: gr.Request): models_on_cloud = load_model_list(pr.username) vae_model_on_cloud = load_vae_list(pr.username) lora_models_on_cloud = load_lora_models(username=pr.username) hypernetworks_models_on_cloud = load_hypernetworks_models(pr.username) embedding_model_on_cloud = load_embeddings_list(pr.username) controlnet_list = load_controlnet_list(pr.username) controlnet_xyz_list = load_xyz_controlnet_list(pr.username) inference_jobs = load_inference_job_list(inference_task_type, pr.username) lora_hypernets = { "lora": lora_models_on_cloud, "hypernet": hypernetworks_models_on_cloud, "controlnet": controlnet_list, "controlnet_xyz": controlnet_xyz_list, "vae": vae_model_on_cloud, "sd": models_on_cloud, "embedding": embedding_model_on_cloud } return lora_hypernets, \ gr.update(choices=models_on_cloud, value=models_on_cloud[0] if models_on_cloud and len(models_on_cloud) > 0 else None_Option_For_On_Cloud_Model), \ gr.update(choices=inference_jobs), \ gr.update(choices=vae_model_on_cloud), \ gr.update(choices=lora_models_on_cloud), \ gr.update(choices=hypernetworks_models_on_cloud), \ gr.update(choices=embedding_model_on_cloud) sagemaker_inference_tab.load(fn=setup_inference_for_plugin, inputs=[], outputs=[ lora_and_hypernet_models_state, sd_model_on_cloud_dropdown, inference_job_dropdown, sd_vae_on_cloud_dropdown, lora_dropdown_local, hypernet_dropdown_local, embedding_dropdown_local ]) with gr.Group(): with gr.Accordion("Open for Checkpoint Merge in the Cloud!", visible=False, open=False): sagemaker_html_log = gr.HTML(elem_id=f'html_log_sagemaker') with FormRow(elem_id="modelmerger_models_in_the_cloud"): global primary_model_name primary_model_name = gr.Dropdown(elem_id="modelmerger_primary_model_name_in_the_cloud", label="Primary model (A) in the cloud") create_refresh_button_by_user(primary_model_name, lambda *args: None, lambda username: {"choices": sorted(update_sd_checkpoints(username))}, "refresh_checkpoint_A_in_the_cloud") global secondary_model_name secondary_model_name = gr.Dropdown(elem_id="modelmerger_secondary_model_name_in_the_cloud", label="Secondary model (B) in the cloud") create_refresh_button_by_user(secondary_model_name, lambda *args: None, lambda username: {"choices": sorted(update_sd_checkpoints(username))}, "refresh_checkpoint_B_in_the_cloud") global tertiary_model_name tertiary_model_name = gr.Dropdown(elem_id="modelmerger_tertiary_model_name_in_the_cloud", label="Tertiary model (C) in the cloud") create_refresh_button_by_user(tertiary_model_name, lambda *args: None, lambda username: {"choices": sorted(update_sd_checkpoints(username))}, "refresh_checkpoint_C_in_the_cloud") with gr.Row(): global modelmerger_merge_on_cloud modelmerger_merge_on_cloud = gr.Button(elem_id="modelmerger_merge_in_the_cloud", value="Merge on Cloud", variant='primary') return sd_model_on_cloud_dropdown, infer_endpoint_dropdown, sd_vae_on_cloud_dropdown, inference_job_dropdown, primary_model_name, secondary_model_name, tertiary_model_name, modelmerger_merge_on_cloud, lora_and_hypernet_models_state