1268 lines
52 KiB
Python
1268 lines
52 KiB
Python
import logging
|
||
import os
|
||
import html
|
||
import boto3
|
||
import time
|
||
import json
|
||
import gradio
|
||
import requests
|
||
import gradio as gr
|
||
|
||
from aws_extension.cloud_api_manager.api_logger import ApiLogger
|
||
from aws_extension.constant import MODEL_TYPE
|
||
|
||
import utils
|
||
from aws_extension.auth_service.simple_cloud_auth import cloud_auth_manager
|
||
from aws_extension.cloud_api_manager.api_manager import api_manager
|
||
from aws_extension.cloud_api_manager.api import api
|
||
from aws_extension.sagemaker_ui_utils import create_refresh_button_by_user
|
||
from modules.shared import opts
|
||
from modules.ui_components import FormRow
|
||
from utils import get_variable_from_json, upload_multipart_files_to_s3_by_signed_url, has_config
|
||
from requests.exceptions import JSONDecodeError
|
||
from datetime import datetime
|
||
import math
|
||
import re
|
||
from modules.ui_components import ToolButton
|
||
import asyncio
|
||
import nest_asyncio
|
||
|
||
from utils import cp, tar, rm
|
||
|
||
# Module logger; its level comes from utils.LOGGING_LEVEL.
logger = logging.getLogger(__name__)
logger.setLevel(utils.LOGGING_LEVEL)

# Sentinel dropdown labels meaning "nothing selected"
# (compared against in fake_gan and delete_inference_job).
None_Option_For_On_Cloud_Model = "don't use on cloud inference"
None_Option_For_Infer_Job = "No Selected"

# UI component handles. All start as None here; presumably they are assigned when the UI
# is built elsewhere — TODO confirm.
inference_job_dropdown = None
embedding_dropdown = None
hypernet_dropdown = None
lora_dropdown = None
# sagemaker_endpoint = None
modelmerger_merge_on_cloud = None
interrogate_clip_on_cloud_button = None
interrogate_deep_booru_on_cloud_button = None

primary_model_name = None
secondary_model_name = None
tertiary_model_name = None

# TODO: convert to dynamically init the following variables
# Refilled in place by query_inference_job_list with "<time>--><taskType>--><status>--><id>"
# labels; used for both txt2img and img2img queries despite the name.
txt2img_inference_job_ids = []

sd_checkpoints = []
textual_inversion_list = []
lora_list = []
hyperNetwork_list = []
ControlNet_model_list = []

# Initial checkpoints information
# checkpoint_info maps checkpoint type -> {model file name -> S3 location}.
checkpoint_info = {}
# Parallel lists: checkpoint_type[i] pairs with checkpoint_name[i] (they are zipped below and in
# refresh_all_models). The checkpoint_type order also matches the six path arguments of
# sagemaker_upload_model_s3, which zips them positionally.
checkpoint_type = ["Stable-diffusion", "embeddings", "Lora", "hypernetworks", "ControlNet", "VAE"]
checkpoint_name = ["stable_diffusion", "embeddings", "lora", "hypernetworks", "controlnet", "VAE"]
stable_diffusion_list = []
embeddings_list = []

hypernetworks_list = []
controlnet_list = []
# Pre-create an empty dict per type so checkpoint_info[<type>] lookups work before the first refresh.
# Note: this loop leaves ckpt_type / ckpt_name bound as module globals.
for ckpt_type, ckpt_name in zip(checkpoint_type, checkpoint_name):
    checkpoint_info[ckpt_type] = {}
# get api_gateway_url
# Read once at import time; the request helpers in this module re-read these values on each call.
api_gateway_url = get_variable_from_json('api_gateway_url')
api_key = get_variable_from_json('api_token')

# Optional time-range filters for inference job queries, read by query_inference_job_list
# (a falsy value means "no filter").
start_time_picker_img_value = None
end_time_picker_img_value = None
start_time_picker_txt_value = None
end_time_picker_txt_value = None

# Last txt2img job-list filters, stored by the query_*_inference_job_list helpers.
txt_task_type = None
txt_status = None
txt_endpoint = None
txt_checkpoint = None

# Last img2img job-list filters, stored by the query_*_inference_job_list helpers.
img_task_type = None
img_status = None
img_endpoint = None
img_checkpoint = None

# When True, job queries send limit=-1 (presumably "no limit") instead of 10.
show_all_inference_job = False

# UI model-type label -> checkpoint type sent to the cloud API
# (used by sagemaker_upload_model_s3_url).
modelTypeMap = {
    'SD Checkpoints': 'Stable-diffusion',
    'Textual Inversion': 'embeddings',
    'LoRA model': 'Lora',
    'ControlNet model': 'ControlNet',
    'Hypernetwork': 'hypernetworks',
    'VAE': 'VAE'
}
|
||
|
||
|
||
def plaintext_to_html(text):
    """Return ``text`` HTML-escaped inside one ``<p>`` element, with each newline shown as ``<br>``."""
    separator = "<br>\n"
    body = separator.join(html.escape(line) for line in text.split('\n'))
    return "<p>" + body + "</p>"
|
||
|
||
|
||
def server_request_get(path, params):
    """Send an authenticated GET to ``<api_gateway_url>/<path>`` with query ``params``.

    Returns the raw ``requests`` response; callers handle status and JSON decoding.
    """
    base_url = get_variable_from_json('api_gateway_url')
    # Make sure exactly one '/' separates the gateway URL from the path
    base_url = base_url if base_url.endswith('/') else base_url + '/'
    request_headers = {
        "x-api-key": get_variable_from_json('api_token'),
        "Content-Type": "application/json",
    }
    return requests.get(f'{base_url}{path}', params=params, headers=request_headers)
|
||
|
||
|
||
def get_s3_file_names(bucket, folder):
    """Return the keys of every object in S3 ``bucket`` whose key starts with ``folder``."""
    matching_objects = boto3.resource('s3').Bucket(bucket).objects.filter(Prefix=folder)
    return [summary.key for summary in matching_objects]
|
||
|
||
|
||
def get_current_date():
    """Return today's local date formatted as ``YYYY-MM-DD``."""
    return datetime.today().strftime('%Y-%m-%d')
|
||
|
||
|
||
def server_request(path):
    """Send an authenticated GET to ``<api_gateway_url>/<path>``.

    Returns the raw ``requests`` response, or an empty list when the API is not configured.
    """
    base_url = get_variable_from_json('api_gateway_url')
    if not has_config():
        return []

    # Make sure exactly one '/' separates the gateway URL from the path
    if not base_url.endswith('/'):
        base_url = f"{base_url}/"

    token = get_variable_from_json('api_token')
    return requests.get(f'{base_url}{path}', headers={"x-api-key": token, "Content-Type": "application/json"})
|
||
|
||
|
||
def datetime_to_short_form(datetime_str):
    """Reformat a job timestamp (as parsed by get_strptime) into ``YYYY-MM-DD-HH-MM-SS``."""
    return get_strptime(datetime_str).strftime("%Y-%m-%d-%H-%M-%S")
|
||
|
||
|
||
def query_page_inference_job_list(task_type: str, status: str, endpoint: str, checkpoint: str, is_img2img: bool,
                                  show_all: bool):
    """Store the job-list filters for the txt2img or img2img page and query matching jobs.

    Args:
        task_type, status, endpoint, checkpoint: filter values chosen in the UI.
        is_img2img: True to store into the ``img_*`` globals, False for the ``txt_*`` globals.
        show_all: stored in ``show_all_inference_job``; when True the query asks for all jobs
            instead of the latest 10.

    Returns:
        The dropdown update produced by query_inference_job_list.
    """
    global show_all_inference_job
    global img_task_type, img_status, img_endpoint, img_checkpoint
    global txt_task_type, txt_status, txt_endpoint, txt_checkpoint

    show_all_inference_job = show_all

    if is_img2img:
        img_task_type = task_type
        img_status = status
        img_endpoint = endpoint
        img_checkpoint = checkpoint
        opt_type = 'img2img'
        logger.debug(f"{opt_type} {img_task_type} {img_status} {img_endpoint} {img_checkpoint} {show_all}")
    else:
        txt_task_type = task_type
        # Store the newly selected status (this previously re-assigned txt_status to itself)
        txt_status = status
        txt_endpoint = endpoint
        txt_checkpoint = checkpoint
        opt_type = 'txt2img'
        logger.debug(f"{opt_type} {txt_task_type} {txt_status} {txt_endpoint} {txt_checkpoint} {show_all}")

    return query_inference_job_list(task_type, status, endpoint, checkpoint, opt_type)
|
||
|
||
|
||
def query_img_inference_job_list(task_type: str, status: str, endpoint: str, checkpoint: str):
    """Remember the img2img job-list filters in module state, then query matching jobs."""
    global img_task_type, img_status, img_endpoint, img_checkpoint
    img_task_type, img_status, img_endpoint, img_checkpoint = task_type, status, endpoint, checkpoint
    return query_inference_job_list(task_type, status, endpoint, checkpoint, 'img2img')
|
||
|
||
|
||
def query_txt_inference_job_list(task_type: str, status: str, endpoint: str, checkpoint: str):
    """Remember the txt2img job-list filters in module state, then query matching jobs."""
    global txt_task_type, txt_status, txt_endpoint, txt_checkpoint
    txt_task_type, txt_status, txt_endpoint, txt_checkpoint = task_type, status, endpoint, checkpoint
    return query_inference_job_list(task_type, status, endpoint, checkpoint, 'txt2img')
|
||
|
||
|
||
def query_inference_job_list(task_type: str = '', status: str = '',
                             endpoint: str = '', checkpoint: str = '', opt_type: str = ''):
    """Fetch inference jobs matching the filters and return a dropdown update listing them.

    Empty filters are not sent. ``opt_type`` ('txt2img' or 'img2img') selects which pair of
    module-level time pickers supplies ``start_time``/``end_time``. At most 10 jobs are requested
    unless ``show_all_inference_job`` is set.

    The module list ``txt2img_inference_job_ids`` is refilled in place with
    ``"<time>--><taskType>--><status>--><InferenceJobId>"`` labels, sorted ascending by time,
    where time is ``completeTime`` or, if absent, ``startTime``. Any error is logged and yields
    an empty dropdown.
    """
    logger.debug(
        f"query_inference_job_list start!!{status},{task_type},{endpoint},{checkpoint},{start_time_picker_txt_value},{end_time_picker_txt_value} {show_all_inference_job}")
    try:
        if opt_type == 'txt2img':
            range_start, range_end = start_time_picker_txt_value, end_time_picker_txt_value
        elif opt_type == 'img2img':
            range_start, range_end = start_time_picker_img_value, end_time_picker_img_value
        else:
            range_start, range_end = None, None

        simple_filters = (
            ('status', status),
            ('task_type', task_type),
            ('start_time', range_start),
            ('end_time', range_end),
        )
        query = {key: value for key, value in simple_filters if value}
        if endpoint:
            # Only the part before the first '+' of the endpoint choice is sent
            query['endpoint'] = endpoint.split("+")[0]
        if checkpoint:
            query['checkpoint'] = checkpoint
        query['limit'] = -1 if show_all_inference_job else 10

        response = server_request_get('inferences', query)
        jobs = response.json()['data']['inferences']
        logger.debug(jobs)

        if not jobs:
            logger.info("The API response is empty.")
            return gr.Dropdown.update(choices=[])

        # Clear the existing list before appending new values
        txt2img_inference_job_ids.clear()
        entries = []
        for job in jobs:
            finished_at = job.get('completeTime')
            if finished_at is None:
                finished_at = job.get('startTime')
            label = f"{finished_at}-->{job.get('taskType', 'txt2img')}-->{job.get('status')}-->{job.get('InferenceJobId')}"
            entries.append((finished_at, label))

        # Stable ascending sort on the time component
        entries.sort(key=lambda entry: entry[0])
        txt2img_inference_job_ids.extend(label for _, label in entries)
        return gr.Dropdown.update(choices=txt2img_inference_job_ids)
    except Exception as e:
        logger.error("Exception occurred when fetching inference_job_ids")
        logger.error(e)
        return gr.Dropdown.update(choices=[])
|
||
|
||
|
||
def get_inference_job(inference_job_id):
    """Fetch one inference job from the cloud API and return its ``data`` payload.

    The request and response are also recorded through ``ApiLogger`` (action ``inference``).

    Args:
        inference_job_id: ID of the job to fetch.

    Returns:
        The ``data`` object of the response body.

    Raises:
        Exception: carrying the parsed response body when it has no ``data`` key.
    """
    url = f'inferences/{inference_job_id}'
    response = server_request(url)
    logger.debug(f"get_inference_job response {response}")

    # Decode the body once and reuse it below
    body = response.json()
    infer_id = body['data']['InferenceJobId'] if 'data' in body else ""

    api_logger = ApiLogger(
        action='inference',
        append=True,
        infer_id=infer_id
    )
    # NOTE(review): these headers (including the API key) are handed to ApiLogger;
    # confirm it does not persist the key.
    headers = {
        "x-api-key": get_variable_from_json('api_token'),
        "Content-Type": "application/json"
    }
    api_gateway_url = get_variable_from_json('api_gateway_url')
    # Log the same URL server_request actually called (it inserts a '/' separator)
    if api_gateway_url and not api_gateway_url.endswith('/'):
        api_gateway_url += '/'
    api_logger.req_log(sub_action="GetInferenceJob",
                       method='GET',
                       path=f"{api_gateway_url}{url}",
                       headers=headers,
                       response=response,
                       desc=f"Get inference job detail from cloud by ID ({inference_job_id}), "
                            f"end request if data.status == succeed, "
                            f"ID from previous step: CreateInference -> data -> inference -> id")

    if 'data' not in body:
        raise Exception(body)

    return body['data']
|
||
|
||
|
||
def download_images(image_urls: list, local_directory: str):
    """Download each URL into ``local_directory`` (created if missing) and return the saved paths.

    The file name is the URL's last path segment with any query string removed. URLs that fail
    with a ``requests`` error are logged and skipped.
    """
    if not os.path.exists(local_directory):
        os.makedirs(local_directory)

    saved_paths = []
    for image_url in image_urls:
        try:
            reply = requests.get(image_url)
            reply.raise_for_status()
            file_name = os.path.basename(image_url).split('?')[0]
            destination = os.path.join(local_directory, file_name)
            with open(destination, 'wb') as handle:
                handle.write(reply.content)
            saved_paths.append(destination)
        except requests.exceptions.RequestException as err:
            logger.error(f"Error downloading image {image_url}: {err}")
    return saved_paths
|
||
|
||
|
||
def download_images_to_json(image_urls: list):
    """Fetch each URL as JSON and collect its ``info`` field; ``requests`` errors are logged and skipped."""
    infos = []
    for json_url in image_urls:
        try:
            reply = requests.get(json_url)
            reply.raise_for_status()
            infos.append(reply.json()['info'])
        except requests.exceptions.RequestException as err:
            logger.error(f"Error downloading image {json_url}: {err}")
    return infos
|
||
|
||
|
||
def download_images_to_pil(image_urls: list):
    """Download each URL and open it as a PIL image; ``requests`` errors are logged and skipped."""
    pil_images = []
    for image_url in image_urls:
        try:
            reply = requests.get(image_url)
            reply.raise_for_status()
            # Imported lazily so PIL is only required once an image has actually been fetched
            import io
            from PIL import Image
            pil_images.append(Image.open(io.BytesIO(reply.content)))
        except requests.exceptions.RequestException as err:
            logger.error(f"Error downloading image {image_url}: {err}")
    return pil_images
|
||
|
||
|
||
def get_model_list_by_type(model_type, username=""):
    """Fetch the names of active cloud checkpoints of the given type(s).

    As a side effect, records each checkpoint's S3 location in
    ``checkpoint_info[<type>][<name>]``.

    Args:
        model_type: a checkpoint type string (e.g. "Lora") or a list of them.
        username: sent as the ``username`` request header.

    Returns:
        De-duplicated checkpoint names in first-seen order. Returns ``[]`` when the API is not
        configured or the request fails.
    """
    api_gateway_url = get_variable_from_json('api_gateway_url')
    api_key = get_variable_from_json('api_token')

    # check if api_gateway_url and api_key are set
    if not has_config():
        logger.info("api_gateway_url or api_key is not set")
        return []

    # Check if api_url ends with '/', if not append it
    if not api_gateway_url.endswith('/'):
        api_gateway_url += '/'

    url = api_gateway_url + "checkpoints?status=Active"
    if isinstance(model_type, list):
        url += "&types=" + "&types=".join(model_type)
    else:
        url += f"&types={model_type}"

    try:
        response = requests.get(url=url, headers={
            'x-api-key': api_key,
            'username': username
        })
        response.raise_for_status()
        json_response = response.json()

        if "checkpoints" not in json_response:
            return []

        checkpoint_list = []
        for ckpt in json_response["checkpoints"]:
            # Skip records without a name list or a type instead of failing the whole listing
            if ckpt.get("name") is None or ckpt.get("type") is None:
                continue
            # setdefault: a type not pre-registered in checkpoint_info must not raise KeyError
            type_info = checkpoint_info.setdefault(ckpt["type"], {})
            for ckpt_name in ckpt["name"]:
                type_info[ckpt_name] = f"{ckpt['s3Location']}/{ckpt_name}"
                checkpoint_list.append(ckpt_name)

        # dict.fromkeys de-duplicates while keeping a deterministic (first-seen) order
        return list(dict.fromkeys(checkpoint_list))
    except Exception as e:
        logger.error(f"Error fetching model list: {e}")
        return []
|
||
|
||
|
||
def get_checkpoints_by_type(model_type):
    """Return ``[name, created]`` pairs for active checkpoints of the given type(s).

    Args:
        model_type: a checkpoint type string or a list of them.

    Returns:
        One ``[name, datetime]`` pair per distinct checkpoint name. ``created`` is converted
        with ``datetime.fromtimestamp``, and the first record seen for a name wins. Returns
        ``[]`` when the API is not configured or on any error (which is logged).
    """
    # server_request returns [] when unconfigured; bail out early instead of logging an AttributeError
    if not has_config():
        return []

    url = "checkpoints?status=Active"
    if isinstance(model_type, list):
        url += "&types=" + "&types=".join(model_type)
    else:
        url += f"&types={model_type}"

    try:
        response = server_request(url)
        response.raise_for_status()
        json_response = response.json()
        if "checkpoints" not in json_response.keys():
            return []

        checkpoint_dict = {}
        for ckpt in json_response["checkpoints"]:
            if ckpt.get("name") is None:
                continue
            created = datetime.fromtimestamp(ckpt['created'])
            for ckpt_name in ckpt["name"]:
                checkpoint_dict.setdefault(ckpt_name, [ckpt_name, created])

        return list(checkpoint_dict.values())
    except Exception as e:
        # Use the module logger (was the root `logging` logger)
        logger.error(f"Error fetching checkpoints list: {e}")
        return []
|
||
|
||
|
||
def update_sd_checkpoints(username):
    """Return names of active Stable-diffusion checkpoints, requested on behalf of ``username``.

    ``username`` is forwarded as the ``username`` request header; without it,
    get_model_list_by_type would send its default of "".
    """
    model_type = ["Stable-diffusion"]
    return get_model_list_by_type(model_type, username)
|
||
|
||
|
||
def get_texual_inversion_list():
    """Return names of active textual-inversion ("embeddings") checkpoints."""
    return get_model_list_by_type("embeddings")
|
||
|
||
|
||
def get_lora_list():
    """Return names of active LoRA checkpoints."""
    return get_model_list_by_type("Lora")
|
||
|
||
|
||
def get_hypernetwork_list():
    """Return names of active hypernetwork checkpoints."""
    return get_model_list_by_type("hypernetworks")
|
||
|
||
|
||
def get_controlnet_model_list():
    """Return names of active ControlNet checkpoints."""
    return get_model_list_by_type("ControlNet")
|
||
|
||
|
||
def refresh_all_models(username):
    """Rebuild ``checkpoint_info`` from the cloud for every type in ``checkpoint_type``.

    For each type, its entry is reset. Then every returned checkpoint record adds
    ``name -> "<s3Location>/<basename(name)>"`` under the record's own ``type``.

    Args:
        username: sent as the ``username`` request header.

    Returns:
        ``[]`` when the API is not configured, otherwise ``None``. Errors are logged, not raised.
    """
    api_gateway_url = get_variable_from_json('api_gateway_url')
    api_key = get_variable_from_json('api_token')

    if not has_config():
        return []

    # Same normalisation as the other request helpers so "<gateway>checkpoints" is well formed
    if not api_gateway_url.endswith('/'):
        api_gateway_url += '/'

    try:
        for rp in checkpoint_type:
            url = api_gateway_url + f"checkpoints?status=Active&types={rp}"
            response = requests.get(url=url, headers={
                'x-api-key': api_key,
                'username': username,
            })
            json_response = response.json()
            logger.debug(f"response url json for model {rp} is {json_response}")

            checkpoint_info[rp] = {}
            if "checkpoints" not in json_response.keys():
                continue

            for ckpt in json_response["checkpoints"]:
                if ckpt.get("name") is None:
                    continue
                # Accumulate into the type's dict; resetting it per record would keep only
                # the last checkpoint record of each type.
                type_info = checkpoint_info.setdefault(ckpt["type"], {})
                for ckpt_name in ckpt["name"]:
                    type_info[ckpt_name] = f"{ckpt['s3Location']}/{ckpt_name.split(os.sep)[-1]}"
    except Exception as e:
        logger.error(f"Error refresh all models: {e}")
|
||
|
||
|
||
def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path,
                              controlnet_model_path, vae_path, pr: gradio.Request):
    """Upload local model files to S3 through the cloud checkpoint API.

    Each path is paired positionally with ``checkpoint_type``. For every non-empty path:
      1. skip it if a model with the same file name is already registered for that type;
      2. create a checkpoint record via POST ``checkpoints``;
      3. pack the file into a local archive, together with a sibling ``<stem>.yaml`` config
         for Stable-diffusion models when one exists;
      4. upload the archive in parts through the returned presigned URLs;
      5. mark the checkpoint ``Active`` via PUT.

    Args:
        sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path,
        controlnet_model_path, vae_path: local file paths; empty/None entries are skipped.
        pr: Gradio request; ``pr.username`` is sent as creator / ``username`` header.

    Returns:
        A 7-tuple: a status message followed by six ``None`` values (presumably clearing the
        six path inputs in the UI — TODO confirm against the caller).
    """
    if not has_config():
        return "Please config api url and token", None, None, None, None, None, None

    log = "start upload model to s3:"

    local_paths = [sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path, controlnet_model_path,
                   vae_path]

    # check parameters
    if not any(local_paths):
        return "Please choose at least one model to upload.", None, None, None, None, None, None

    logger.info(f"Refresh checkpoints before upload to get rid of duplicate uploads...")
    refresh_all_models(pr.username)

    for lp, rp in zip(local_paths, checkpoint_type):
        if not lp:
            continue
        logger.debug(f"lp is {lp}")
        model_name = lp.split(os.sep)[-1]

        # checkpoint_info was just refreshed, so this detects names already on the cloud
        if model_name in checkpoint_info.get(rp, {}):
            logger.info(f"!!!skip to upload duplicate model {model_name}")
            continue

        # Upload in parts of at most 1000 MiB each.
        # NOTE(review): parts_number comes from the raw model size, but the archive uploaded
        # below is somewhat larger (archive headers, optional .yaml) — confirm the part count
        # cannot fall one short near a part boundary.
        part_size = 1000 * 1024 * 1024
        file_size = os.stat(lp)
        parts_number = math.ceil(file_size.st_size / part_size)
        logger.info(f'!!!!!!!!!!{file_size} {parts_number}')

        # local_tar_path = f'{model_name}.tar'
        local_tar_path = model_name
        payload = {
            "checkpoint_type": rp,
            "filenames": [{
                "filename": local_tar_path,
                "parts_number": parts_number
            }],
            "params": {
                "message": "placeholder for chkpts upload test",
                "creator": pr.username
            }
        }

        api_gateway_url = get_variable_from_json('api_gateway_url')
        # Check if api_url ends with '/', if not append it
        if not api_gateway_url.endswith('/'):
            api_gateway_url += '/'
        api_key = get_variable_from_json('api_token')
        logger.info(f'!!!!!!api_gateway_url {api_gateway_url}')

        url = str(api_gateway_url) + "checkpoints"

        logger.debug(f"Post request for upload s3 presign url: {url}")

        response = requests.post(url=url, json=payload, headers={'x-api-key': api_key, "username": pr.username})

        if response.status_code not in [201, 202]:
            logger.error(f"create_checkpoint: {response.json()}")
            return response.json()['message'], None, None, None, None, None, None

        try:
            json_response = response.json()['data']
            logger.debug(f"Response json {json_response}")
            s3_base = json_response["checkpoint"]["s3_location"]
            checkpoint_id = json_response["checkpoint"]["id"]
            logger.debug(f"Upload to S3 {s3_base}")
            logger.debug(f"Checkpoint ID: {checkpoint_id}")

            s3_signed_urls_resp = json_response["s3PresignUrl"][local_tar_path]
            # Upload src model to S3.
            # Embeddings live under "embeddings/", every other type under "models/<type>/".
            if rp != "embeddings":
                local_model_path_in_repo = os.sep.join(['models', rp, model_name])
            else:
                local_model_path_in_repo = os.sep.join([rp, model_name])
            logger.debug("Pack the model file.")
            cp(lp, local_model_path_in_repo, recursive=True)
            if rp == "Stable-diffusion":
                # Config is "<model stem>.yaml"; splitext keeps dots inside the stem
                # (e.g. "sd_xl_base_1.0.safetensors" -> "sd_xl_base_1.0.yaml").
                model_yaml_name = os.path.splitext(model_name)[0] + ".yaml"
                local_model_yaml_path = os.sep.join([*lp.split(os.sep)[:-1], model_yaml_name])
                local_model_yaml_path_in_repo = os.sep.join(["models", rp, model_yaml_name])
                if os.path.isfile(local_model_yaml_path):
                    cp(local_model_yaml_path, local_model_yaml_path_in_repo, recursive=True)
                    tar(mode='c', archive=local_tar_path,
                        sfiles=[local_model_path_in_repo, local_model_yaml_path_in_repo], verbose=True)
                else:
                    tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo], verbose=True)
            else:
                tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo], verbose=True)

            multiparts_tags = upload_multipart_files_to_s3_by_signed_url(
                local_tar_path,
                s3_signed_urls_resp,
                part_size
            )
            logger.debug(f"multiparts_tags {multiparts_tags}")

            payload = {
                "status": "Active",
                "multi_parts_tags": {local_tar_path: multiparts_tags}
            }
            # Start creating model on cloud.
            response = requests.put(url=f"{url}/{checkpoint_id}", json=payload, headers={'x-api-key': api_key})
            logger.debug(response)

            log = f"finish upload {local_tar_path} to {s3_base}"

            # Remove the local archive once uploaded
            rm(local_tar_path, recursive=True)
        except Exception as e:
            logger.error(f"fail to upload model {lp}, error: {e}")
            refresh_all_models(pr.username)
            return str(e), None, None, None, None, None, None

    logger.debug(f"Refresh checkpoints after upload...")
    refresh_all_models(pr.username)
    return log, None, None, None, None, None, None
|
||
|
||
|
||
def sagemaker_upload_model_s3_local():
    """Placeholder for local model upload; only returns the initial status text."""
    return "Start upload:"
|
||
|
||
|
||
def check_url(url: str):
    """Remove every newline from ``url`` and trim surrounding whitespace."""
    return "".join(url.split('\n')).strip()
|
||
|
||
|
||
def sagemaker_upload_model_s3_url(model_type: str, url_list: str, description: str, pr: gradio.Request):
    """Ask the cloud API to create a checkpoint from models hosted at the given URLs.

    Args:
        model_type: UI label, translated through ``modelTypeMap`` (e.g. "LoRA model" -> "Lora").
        url_list: comma-separated URLs. Newlines inside entries are removed, blank entries
            (e.g. from a trailing comma) are ignored, and duplicates are dropped in input order.
        description: optional message stored with the checkpoint.
        pr: Gradio request; its ``username`` is recorded as the creator.

    Returns:
        A status message for the UI: a validation error, or the API's ``message``.
    """
    if not url_list:
        return "Please fill the url."

    model_type = modelTypeMap.get(model_type)
    if not model_type:
        return "Please choose the model type."

    params_dict = {'message': description} if description else {}
    params_dict['creator'] = pr.username

    cleaned = (check_url(url) for url in url_list.split(','))
    unique_urls = list(dict.fromkeys(url for url in cleaned if url))
    if not unique_urls:
        return "Please fill the url."

    url_regex = re.compile(r'^(https?|ftp)://[^\s/$.?#].[^\s]*$')
    for url in unique_urls:
        if not url_regex.match(url):
            return f"{url} is not a valid url."

    data = {'checkpoint_type': model_type, 'urls': unique_urls, 'params': params_dict}
    api.set_username(pr.username)
    response = api.create_checkpoint(data=data)
    return response.json()['message']
|
||
|
||
|
||
def generate_on_cloud(sagemaker_endpoint):
    """Log the known checkpoints and the endpoint, then return a fixed failure message as HTML."""
    for message in (f"checkpoint_info {checkpoint_info}", f"sagemaker endpoint {sagemaker_endpoint}"):
        logger.info(message)
    return plaintext_to_html("failed to check endpoint")
|
||
|
||
|
||
# create a global event loop and apply the patch to allow nested event loops in single thread
try:
    loop = asyncio.get_event_loop()
except RuntimeError:
    # get_event_loop() raises when there is no current loop (e.g. in a non-main thread,
    # or on Python 3.14+); create and install one explicitly in that case.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
# Patch this specific loop so run_until_complete may be re-entered while it is running
nest_asyncio.apply(loop)

# Maximum number of unfinished tasks in `loop` before async_loop_wrapper waits
MAX_RUNNING_LIMIT = 10
|
||
|
||
|
||
def async_loop_wrapper(f):
    """Run the coroutine produced by ``f()`` on the module-level ``loop`` and return its result.

    If ``loop`` is already running (re-entrant use, relying on the nest_asyncio patch applied
    at import), first wait, re-checking once per second, while more than ``MAX_RUNNING_LIMIT``
    of its tasks are unfinished. If ``loop`` is idle but closed, a new loop replaces it and
    becomes the current event loop.
    """
    global loop

    if not loop.is_running():
        if loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            # log this event since it should never happen
            logger.debug('Event loop was closed, created a new one')
    else:
        # NOTE(review): time.sleep blocks this thread; if the loop's tasks are driven by this
        # same thread they cannot progress while we wait — confirm this cannot stall.
        while sum(1 for pending in asyncio.all_tasks(loop) if not pending.done()) > MAX_RUNNING_LIMIT:
            logger.debug(f'Waiting for {MAX_RUNNING_LIMIT} running tasks to complete')
            time.sleep(1)

    return loop.run_until_complete(f())
|
||
|
||
|
||
def process_result_by_inference_id(inference_id_or_data, endpoint_type):
    """Wait for an inference job to finish and return its results for the UI.

    Args:
        inference_id_or_data: for ``endpoint_type == 'Async'`` the inference job ID, which is
            polled via get_inference_job; otherwise the job payload itself (a dict with
            ``InferenceJobId``).
        endpoint_type: ``'Async'`` or any other value for an already-available payload.

    Returns:
        ``(image_list, info_text, infotexts_html, text)``. ``text`` is the plain info text,
        or the caption for a succeeded interrogate task.

    Raises:
        Exception: if the job payload has no ``taskType`` (the payload is attached).
    """
    logger.debug(f"process_result_by_inference_id {inference_id_or_data} {endpoint_type}")

    if endpoint_type == 'Async':
        resp = get_inference_job(inference_id_or_data)
        inference_id = inference_id_or_data
    else:
        resp = inference_id_or_data
        # Guard so a missing payload reaches the "resp is None" branch below instead of a TypeError
        inference_id = resp['InferenceJobId'] if resp is not None else None

    image_list = []
    info_text = ''
    infotexts = f"Inference id is {inference_id}, please check all historical inference result in 'Inference Job' dropdown list"

    if resp is None:
        logger.info(f"get_inference_job resp is null")
        return image_list, info_text, plaintext_to_html(infotexts), infotexts

    logger.debug(f"get_inference_job resp is {resp}")
    if 'taskType' not in resp:
        raise Exception(resp)

    if resp['taskType'] not in ['txt2img', 'img2img', 'interrogate_clip', 'interrogate_deepbooru']:
        return image_list, info_text, plaintext_to_html(infotexts), infotexts

    # Poll once per second while the job is still running
    while resp and resp['status'] == "inprogress":
        time.sleep(1)
        resp = get_inference_job(inference_id)

    if resp is None:
        logger.info(f"get_inference_job resp is null.")
        return image_list, info_text, plaintext_to_html(infotexts), infotexts

    if resp['status'] == "failed":
        if 'sagemakerRaw' in resp:
            infotexts = f"Inference job {inference_id} is failed, error message: {resp['sagemakerRaw']}"
        return image_list, info_text, plaintext_to_html(infotexts), infotexts

    if resp['status'] != "succeed":
        logger.debug(f"inference job status is {resp['status']}")
        return image_list, info_text, plaintext_to_html(infotexts), infotexts

    if resp['taskType'] in ['interrogate_clip', 'interrogate_deepbooru']:
        prompt_txt = resp['caption']
        # return with default value, including image_list, info_text, infotexts
        return image_list, info_text, plaintext_to_html(infotexts), prompt_txt

    # txt2img / img2img: fetch the images and the parameter JSON
    images = resp['img_presigned_urls']
    inference_param_json_list = resp['output_presigned_urls']
    image_list = download_images_to_pil(images)
    json_results = download_images_to_json(inference_param_json_list)
    # The parameter JSON may be missing; treat that like a failed download instead of IndexError
    json_file = json_results[0] if json_results else ""

    if json_file:
        info_text = json_file
        infotexts = f"Inference id is {inference_id}\n{get_infer_job_time(resp)}" + \
                    json.loads(info_text)["infotexts"][0]
    else:
        logger.debug(f"File {json_file} does not exist.")
        info_text = 'something wrong when trying to download the inference parameters'
        infotexts = info_text
    return image_list, info_text, plaintext_to_html(infotexts), infotexts
|
||
|
||
|
||
def update_prompt_with_embedding(selected_items, prompt, lora_and_hypernet_models_state):
    """Sync embedding names in ``prompt`` with ``selected_items``.

    When the state dict has an entry under ``MODEL_TYPE.EMBEDDING.value``, it is passed along
    so that previously selected (now deselected) embeddings can be removed from the prompt.
    """
    embedding_key = MODEL_TYPE.EMBEDDING.value
    if embedding_key not in lora_and_hypernet_models_state:
        return update_prompt_with_selected_model(selected_items, prompt, MODEL_TYPE.EMBEDDING)
    previous_embeddings = lora_and_hypernet_models_state[embedding_key]
    return update_prompt_with_selected_model(selected_items, prompt, MODEL_TYPE.EMBEDDING, previous_embeddings)
|
||
|
||
|
||
def update_prompt_with_hypernetwork(selected_items, prompt):
    """Sync ``<hypernet:...:1>`` tokens in ``prompt`` with the selected hypernetworks."""
    return update_prompt_with_selected_model(selected_items, prompt, type=MODEL_TYPE.HYPER_NETWORK)
|
||
|
||
|
||
def update_prompt_with_lora(selected_items, prompt):
    """Sync ``<lora:...:1>`` tokens in ``prompt`` with the selected LoRA models."""
    return update_prompt_with_selected_model(selected_items, prompt, type=MODEL_TYPE.LORA)
|
||
|
||
|
||
def update_prompt_with_selected_model(selected_value, original_prompt, type, state_value = None):
    """Update txt2img or img2img prompt with selected model name

    Appends a prompt token for each selected model that the original prompt does not already
    contain, then removes tokens of the same kind that are no longer selected.

    Args:
        selected_value (gr.Dropdown): the selected model file names; ``None`` is treated as empty
        original_prompt (gr.Textbox): the original prompt before updating; ``None`` is treated as ""
        type: the model type, embedding|lora|hypernetwork
        state_value: previously selected embeddings; used only for embeddings, to strip names
            that have been deselected

    Returns:
        gr.Textbox: The updated prompt
    """
    logger.info(f"Selected value is {selected_value}, original prompt is {original_prompt}, type is {type}")

    selected_value = selected_value or []
    original_prompt = original_prompt or ""

    def _model_stem(file_name):
        # Drop only the final extension so names like "model_v1.5.safetensors" keep their dots
        return os.path.splitext(file_name)[0]

    def _remove_embedding_prompt(state_value, selected_value, prompt_txt):
        # Remove embeddings that were selected before (per state) but are not selected now
        if state_value:
            for embedding in state_value:
                if embedding not in selected_value:
                    prompt_txt = prompt_txt.replace(_model_stem(embedding), "")
        return prompt_txt

    def _remove_prompt_by_regex(pattern, prompt_txt, keep):
        # Remove every token matching `pattern` that is not one of the selected tokens
        for match in re.findall(pattern, prompt_txt):
            if match not in keep:
                prompt_txt = prompt_txt.replace(match, "")
        return prompt_txt

    prompt_txt = original_prompt
    existed_item = []

    # Compose prompt for Embedding/Lora/Hypernetwork
    for item in selected_value:
        model_name = _model_stem(item)
        if MODEL_TYPE.LORA == type:
            model_prompt = f"<lora:{model_name}:1>"
        elif MODEL_TYPE.HYPER_NETWORK == type:
            model_prompt = f"<hypernet:{model_name}:1>"
        elif MODEL_TYPE.EMBEDDING == type:
            model_prompt = model_name
        else:
            logger.warning(f"The type {type} is not supported, skip it")
            continue

        if model_prompt in existed_item:
            # Two selections with the same stem (e.g. different extensions) add a single token
            continue
        existed_item.append(model_prompt)

        if model_prompt not in original_prompt:
            # Test the prompt built so far (not the original) so that, starting from an empty
            # prompt, every selected model is kept rather than only the last one.
            if prompt_txt.strip():
                prompt_txt += f" {model_prompt}"
            else:
                prompt_txt = model_prompt

    # Remove Embedding/Lora/Hypernetwork string which is not selected
    if MODEL_TYPE.LORA == type:
        prompt_txt = _remove_prompt_by_regex(r"<lora:[^>]*:1>", prompt_txt, existed_item)
    elif MODEL_TYPE.HYPER_NETWORK == type:
        prompt_txt = _remove_prompt_by_regex(r"<hypernet:[^>]*:1>", prompt_txt, existed_item)
    elif MODEL_TYPE.EMBEDDING == type:
        prompt_txt = _remove_embedding_prompt(state_value, selected_value, prompt_txt)
    else:
        logger.warning(f"The type {type} is not supported, skip it")

    return prompt_txt
|
||
|
||
|
||
def fake_gan(selected_value, original_prompt):
    """Load the results of an inference job picked from the job dropdown.

    Args:
        selected_value: a dropdown entry ``"<time>--><taskType>--><status>--><InferenceJobId>"``
            (as built by query_inference_job_list), or a "nothing selected" value.
        original_prompt: the current prompt. It is returned unchanged unless a succeeded
            interrogate job's caption replaces it.

    Returns:
        ``(image_list, info_text, infotexts_html, prompt_txt)``. For failed or unfinished jobs
        the first two items are empty lists and the HTML describes the status.
    """
    logger.debug(f"selected value is {selected_value}")
    logger.debug(f"original prompt is {original_prompt}")

    # Defaults: nothing to show and the prompt left as-is
    prompt_txt = original_prompt
    image_list = []
    info_text = ''
    infotexts = ''

    if not selected_value or selected_value in (None_Option_For_On_Cloud_Model, None_Option_For_Infer_Job):
        return image_list, info_text, plaintext_to_html(infotexts), prompt_txt

    parts = selected_value.split("-->")
    if len(parts) < 4:
        logger.warning(f"Unexpected inference job entry: {selected_value}")
        return image_list, info_text, plaintext_to_html(infotexts), prompt_txt

    inference_job_taskType = parts[1].strip()
    inference_job_status = parts[2].strip()
    # Extract the InferenceJobId value
    inference_job_id = parts[3].strip()

    if inference_job_status == 'failed':
        job = get_inference_job(inference_job_id)
        return [], [], plaintext_to_html(f"inference is failed: {job.get('sagemakerRaw')}"), original_prompt
    if inference_job_status != 'succeed':
        return [], [], plaintext_to_html(f'inference is {inference_job_status}'), original_prompt

    if inference_job_taskType in ["txt2img", "img2img"]:
        job = get_inference_job(inference_job_id)
        image_list = download_images_to_pil(job['img_presigned_urls'])
        images_to_json = download_images_to_json(job['output_presigned_urls'])
        # maybe param json was deleted
        json_file = images_to_json[0] if images_to_json else ""
        if json_file:
            info_text = json_file
            infotexts = f"Inference id is {inference_job_id}\n{get_infer_job_time(job)}" + \
                        json.loads(info_text)["infotexts"][0]
        else:
            logger.debug(f"File {json_file} does not exist.")
            info_text = 'something wrong when trying to download the inference parameters'
            infotexts = info_text
    elif inference_job_taskType in ["interrogate_clip", "interrogate_deepbooru"]:
        job_status = get_inference_job(inference_job_id)
        logger.debug(job_status)
        prompt_txt = job_status['caption']

    return image_list, info_text, plaintext_to_html(infotexts), prompt_txt
|
||
|
||
|
||
def get_strptime(time_str):
    """Parse a job timestamp string into a naive ``datetime``.

    Accepts ``YYYY-MM-DD[T ]HH:MM:SS`` with an optional ``.ffffff`` fractional
    part; a ``T`` date/time separator is normalised to a space first.

    Args:
        time_str: Timestamp string, e.g. ``"2023-10-01T12:34:56.789012"``.

    Returns:
        The parsed ``datetime`` (no timezone attached).

    Raises:
        ValueError: If ``time_str`` matches neither accepted layout.
    """
    normalized = time_str.replace("T", " ")
    # Try with fractional seconds first (the original single format), then
    # without: a whole-second timestamp may be serialised without ".ffffff",
    # which the fractional-only format rejects.
    for fmt in ('%Y-%m-%d %H:%M:%S.%f', '%Y-%m-%d %H:%M:%S'):
        try:
            return datetime.strptime(normalized, fmt)
        except ValueError:
            continue
    raise ValueError(f"unrecognised timestamp format: {time_str!r}")
|
||
|
||
|
||
def get_infer_job_time(job):
    """Build a human-readable timing / endpoint summary for an inference job.

    Args:
        job: Inference job record (dict). Optional keys read:
            ``inference_type``, ``createTime``, ``startTime``, ``completeTime``
            and ``params`` (itself optionally holding
            ``sagemaker_inference_endpoint_name`` and
            ``sagemaker_inference_instance_type``).

    Returns:
        The summary lines joined with newlines plus a trailing newline, or
        ``""`` when none of the fields above produced a line.
    """
    lines = []

    inference_type = ""
    if 'inference_type' in job:
        # Trailing space so it reads "<type> Inference" inside the messages.
        inference_type = job['inference_type'] + " "

    if 'createTime' in job and 'completeTime' in job:
        complete_time = get_strptime(job['completeTime'])
        create_time = get_strptime(job['createTime'])
        end_to_end = round((complete_time - create_time).total_seconds(), 2)
        summary = f"End-to-end API Duration: {end_to_end} seconds"
        # startTime used to be read unconditionally here, raising KeyError
        # for jobs that had create/complete times but no start time.
        if 'startTime' in job:
            start_time = get_strptime(job['startTime'])
            inference = round((complete_time - start_time).total_seconds(), 2)
            summary = f"{summary} (in which {inference_type}Inference: {inference} seconds)"
        lines.append(summary)
    elif 'startTime' in job and 'completeTime' in job:
        complete_time = get_strptime(job['completeTime'])
        start_time = get_strptime(job['startTime'])
        duration = round((complete_time - start_time).total_seconds(), 2)
        lines.append(f"{inference_type}Inference Time: {duration} seconds")

    # Tolerate a ``params`` value of None (previously a TypeError on ``in``).
    params = job.get('params') or {}
    if 'sagemaker_inference_endpoint_name' in params:
        infer_ep_name = f"Endpoint: {params['sagemaker_inference_endpoint_name']}"
        if 'sagemaker_inference_instance_type' in params:
            infer_ep_name += f" ({params['sagemaker_inference_instance_type']})"
        lines.append(infer_ep_name)

    if not lines:
        return ""

    return "\n".join(lines) + "\n"
|
||
|
||
|
||
def delete_inference_job(selected_value):
    """Delete the selected inference job in the cloud and its local report files.

    Args:
        selected_value: Entry from the inference-job dropdown, formatted as
            ``"<time>--><taskType>--><status>--><InferenceJobId>"``. The
            literal ``'cancelled'`` is ignored. NOTE(review): presumably
            sent by the ``delete_inference_job_confirm`` JS hook when the user
            declines the confirmation dialog — confirm.

    Raises:
        gr.Error: When the delete API does not answer with HTTP 204.
    """
    logger.debug(f"selected value is {selected_value}")
    if not selected_value or selected_value == None_Option_For_Infer_Job:
        gr.Warning('Please select a inference job to delete')
        return

    if selected_value == 'cancelled':
        return

    parts = selected_value.split("-->")
    # Placeholder entries such as None_Option_For_On_Cloud_Model (the first
    # choice produced by load_inference_job_list) carry no job id; indexing
    # parts[3] on them used to raise IndexError.
    if len(parts) < 4:
        gr.Warning('Please select a inference job to delete')
        return

    # Extract the InferenceJobId value
    inference_job_id = parts[3].strip()

    resp = api.delete_inferences(data={
        "inference_id_list": [inference_job_id],
    })

    if resp.status_code != 204:
        try:
            message = resp.json()['message']
        except (ValueError, KeyError, TypeError):
            # Body was not JSON, or not an object carrying 'message'.
            message = resp.text
        # gr.Error only reaches the user when raised; previously it was
        # constructed and discarded, and the success toast below still showed.
        raise gr.Error(f"Error deleting inference: {message}")

    gr.Info(f"{inference_job_id} deleted successfully")

    # The id comes from a client-supplied dropdown value; only touch local
    # files when it cannot point outside the outputs directory.
    if os.path.basename(inference_job_id) != inference_job_id:
        return

    outputs_dir = os.path.join(os.getcwd(), "outputs")
    for extension in ("md", "html"):
        file_path = os.path.join(outputs_dir, f"{inference_job_id}.{extension}")
        if os.path.exists(file_path):
            os.remove(file_path)
|
||
|
||
|
||
def init_refresh_resource_list_from_cloud(username):
    """Preload every cloud model list for ``username`` when that is allowed.

    Nothing is loaded when no ``api_gateway_url`` is stored in the local
    config, or when cloud authentication is enabled; both cases are only
    logged at debug level.

    Args:
        username: User whose model lists are refreshed.
    """
    logger.debug(f"start refreshing resource list from cloud")

    if get_variable_from_json('api_gateway_url') is None:
        logger.debug(f"there is no api-gateway url and token in local file,")
        return

    if cloud_auth_manager.enableAuth:
        logger.debug('auth enabled, not preload load any model')
        return

    refresh_all_models(username)
|
||
|
||
|
||
def on_txt_time_change(start_time, end_time):
    """Record the txt2img time-range filter and re-query the job list.

    Args:
        start_time: New start of the txt2img time-range filter.
        end_time: New end of the txt2img time-range filter.

    Returns:
        The result of ``query_inference_job_list`` for ``"txt2img"`` using the
        module-level txt filter values.
    """
    global start_time_picker_txt_value, end_time_picker_txt_value

    logger.debug(f"!!!!!!!!!on_txt_time_change!!!!!!{start_time},{end_time}")
    start_time_picker_txt_value, end_time_picker_txt_value = start_time, end_time

    return query_inference_job_list(txt_task_type, txt_status, txt_endpoint, txt_checkpoint, "txt2img")
|
||
|
||
|
||
def on_img_time_change(start_time, end_time):
    """Record the img2img time-range filter and re-query the job list.

    Args:
        start_time: New start of the img2img time-range filter.
        end_time: New end of the img2img time-range filter.

    Returns:
        The result of ``query_inference_job_list`` for ``"img2img"`` using the
        module-level img filter values.
    """
    global start_time_picker_img_value, end_time_picker_img_value

    logger.debug(f"!!!!!!!!!on_img_time_change!!!!!!{start_time},{end_time}")
    start_time_picker_img_value, end_time_picker_img_value = start_time, end_time

    return query_inference_job_list(img_task_type, img_status, img_endpoint, img_checkpoint, "img2img")
|
||
|
||
|
||
def load_inference_job_list(target_task_type, username, first_load="first"):
    """Return dropdown choices for the user's inference jobs of one task type.

    Args:
        target_task_type: Task type to keep, e.g. ``'txt2img'``.
        username: User whose jobs are listed.
        first_load: Paging hint forwarded to
            ``api_manager.list_all_inference_jobs_on_cloud`` (callers in this
            file pass ``"first"``, ``"previous"`` or ``"next"``).

    Returns:
        ``None_Option_For_On_Cloud_Model`` followed by one
        ``"<time>--><taskType>--><status>--><InferenceJobId>"`` string per
        matching job, in the order the API returned them (no sorting is
        applied). ``<time>`` is ``createTime``, falling back to ``startTime``.
    """
    jobs = api_manager.list_all_inference_jobs_on_cloud(target_task_type, username, first_load)

    def _label(job):
        # Format one job as the delimited string the dropdown displays.
        shown_time = job.get('createTime')
        if shown_time is None:
            shown_time = job.get('startTime')
        return f"{shown_time}-->{job.get('taskType', 'txt2img')}-->{job.get('status')}-->{job.get('InferenceJobId')}"

    # Filter locally as well: older API versions may ignore the type filter.
    matching = [_label(job) for job in jobs if job.get('taskType', 'txt2img') == target_task_type]
    return [None_Option_For_On_Cloud_Model] + matching
|
||
|
||
|
||
def load_model_list(username):
    """Return the Stable Diffusion checkpoint names available on the cloud.

    ``None_Option_For_On_Cloud_Model`` is prepended when
    ``sd_model_checkpoint`` is one of the WebUI quicksettings.

    Args:
        username: User whose models are listed.

    Returns:
        List of unique checkpoint names in the order the API returned them.
    """
    models_on_cloud = []
    if 'sd_model_checkpoint' in opts.quicksettings_list:
        models_on_cloud.append(None_Option_For_On_Cloud_Model)

    # dict.fromkeys de-duplicates while keeping first-seen order. The previous
    # list(set(...)) order depended on string hash randomisation, so the
    # dropdown order (and any default taken from index 0) varied between runs.
    models_on_cloud.extend(dict.fromkeys(model['name'] for model in api_manager.list_models_on_cloud(username)))
    return models_on_cloud
|
||
|
||
|
||
def load_lora_models(username):
    """Return unique Lora model names on the cloud for ``username``.

    Names are de-duplicated with ``dict.fromkeys`` so the API's order is kept
    (``set()`` ordering of strings varies between interpreter runs).
    """
    return list(dict.fromkeys(model['name'] for model in api_manager.list_models_on_cloud(username, types='Lora')))
|
||
|
||
|
||
def load_hypernetworks_models(username):
    """Return unique hypernetwork model names on the cloud for ``username``.

    Names are de-duplicated with ``dict.fromkeys`` so the API's order is kept
    (``set()`` ordering of strings varies between interpreter runs).
    """
    return list(dict.fromkeys(
        model['name'] for model in api_manager.list_models_on_cloud(username, types='hypernetworks')))
|
||
|
||
|
||
def load_vae_list(username):
    """Return VAE choices: ``'Automatic'``, ``'None'``, then cloud VAE names.

    Cloud names are de-duplicated with ``dict.fromkeys`` so the API's order is
    kept (``set()`` ordering of strings varies between interpreter runs).

    Args:
        username: User whose VAE models are listed.
    """
    vae_model_on_cloud = ['Automatic', 'None']
    vae_model_on_cloud.extend(dict.fromkeys(
        model['name'] for model in api_manager.list_models_on_cloud(username, types='VAE')))
    return vae_model_on_cloud
|
||
|
||
|
||
def load_controlnet_list(username):
    """Return ControlNet choices: ``'None'`` followed by cloud model names.

    Cloud names are de-duplicated with ``dict.fromkeys`` so the API's order is
    kept (``set()`` ordering of strings varies between interpreter runs).

    Args:
        username: User whose ControlNet models are listed.
    """
    controlnet_model_on_cloud = ['None']
    controlnet_model_on_cloud.extend(dict.fromkeys(
        model['name'] for model in api_manager.list_models_on_cloud(username, types='ControlNet')))
    return controlnet_model_on_cloud
|
||
|
||
|
||
def load_xyz_controlnet_list(username):
    """Return ``'None'`` followed by cloud ControlNet names without extension.

    The extension is stripped with ``os.path.splitext`` before de-duplication,
    so ``a.pth`` and ``a.safetensors`` collapse to one ``a`` entry. Order is
    first-seen API order (``set()`` ordering varied between runs).

    Args:
        username: User whose ControlNet models are listed.
    """
    controlnet_model_on_cloud = ['None']
    controlnet_model_on_cloud.extend(dict.fromkeys(
        os.path.splitext(model['name'])[0]
        for model in api_manager.list_models_on_cloud(username, types='ControlNet')))
    return controlnet_model_on_cloud
|
||
|
||
|
||
def load_embeddings_list(username):
    """Return unique textual-inversion embedding names on the cloud.

    Names are de-duplicated with ``dict.fromkeys`` so the API's order is kept
    (``set()`` ordering of strings varies between interpreter runs).

    Args:
        username: User whose embeddings are listed.
    """
    return list(dict.fromkeys(
        model['name'] for model in api_manager.list_models_on_cloud(username, types='embeddings')))
|
||
|
||
|
||
def create_ui(is_img2img):
    """Build the "Amazon SageMaker Inference" panel for the txt2img/img2img tab.

    Components are created inside Gradio context managers, so the order and
    nesting of the statements below define the on-screen layout.

    Args:
        is_img2img: True when building the panel for the img2img tab. It selects
            the task type whose job history is listed and shows a spacer row.

    Returns:
        Tuple of (sd_model_on_cloud_dropdown, infer_endpoint_dropdown,
        sd_vae_on_cloud_dropdown, inference_job_dropdown, primary_model_name,
        secondary_model_name, tertiary_model_name, modelmerger_merge_on_cloud,
        lora_and_hypernet_models_state).
    """
    # NOTE(review): neither these globals nor ``modules.ui`` are referenced in
    # this function body; left untouched in case the import matters elsewhere.
    global txt2img_gallery, txt2img_generation_info
    import modules.ui

    # init_refresh_resource_list_from_cloud()

    with gr.Blocks() as sagemaker_inference_tab:
        gr.HTML('<h3>Amazon SageMaker Inference</h3>')
        sagemaker_html_log = gr.HTML(elem_id=f'html_log_sagemaker')
        # Task type used when listing the inference-job history below.
        inference_task_type = 'txt2img' if not is_img2img else 'img2img'
        with gr.Column():
            with gr.Row():
                global lora_and_hypernet_models_state
                # Per-session dict of cloud model lists; filled on page load by
                # setup_inference_for_plugin (see the .load() call below).
                lora_and_hypernet_models_state = gr.State({})
                sd_model_on_cloud_dropdown = gr.Dropdown(choices=[], value=None_Option_For_On_Cloud_Model,
                                                         label='Stable Diffusion Checkpoint Used on Cloud')

                # Refresh button re-fetches the checkpoint list for the current user.
                create_refresh_button_by_user(sd_model_on_cloud_dropdown,
                                              lambda *args: None,
                                              lambda username: {
                                                  'choices': load_model_list(username)
                                              }, 'refresh_cloud_model_down')

                infer_endpoint_dropdown = gr.Dropdown(choices=["Async", "Real-time"],
                                                      value="Async",
                                                      label='Inference Endpoint Type')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        sd_vae_on_cloud_dropdown = gr.Dropdown(choices=[], value='Automatic',
                                                               label='SD Vae on Cloud')

                        create_refresh_button_by_user(sd_vae_on_cloud_dropdown,
                                                      lambda *args: None,
                                                      lambda username: {
                                                          'choices': load_vae_list(username)
                                                      }, 'refresh_cloud_vae_down')
                with gr.Column():
                    with gr.Row():
                        # Lora model
                        global lora_dropdown
                        lora_dropdown_local = gr.Dropdown(choices=[],
                                                          label="Lora model on cloud",
                                                          multiselect=True)
                        create_refresh_button_by_user(lora_dropdown_local,
                                                      lambda *args: None,
                                                      lambda username: {
                                                          'choices': load_lora_models(username)
                                                      }, 'refresh_lora_dropdown')
                        # Also published at module level (module global lora_dropdown).
                        lora_dropdown = lora_dropdown_local

            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        # Embedding model
                        global embedding_dropdown
                        embedding_dropdown_local = gr.Dropdown(choices=[],
                                                               label="Embedding on cloud",
                                                               multiselect=True)
                        create_refresh_button_by_user(embedding_dropdown_local,
                                                      lambda *args: None,
                                                      lambda username: {
                                                          'choices': load_embeddings_list(username)
                                                      }, 'refresh_embedding_dropdown')
                        embedding_dropdown = embedding_dropdown_local
                with gr.Column():
                    with gr.Row():
                        # Hypernetwork model
                        global hypernet_dropdown
                        hypernet_dropdown_local = gr.Dropdown(choices=[],
                                                              label="Hypernetwork on cloud",
                                                              multiselect=True)
                        create_refresh_button_by_user(hypernet_dropdown_local,
                                                      lambda *args: None,
                                                      lambda username: {
                                                          'choices': load_hypernetworks_models(username)
                                                      }, 'refresh_hypernet_dropdown')
                        hypernet_dropdown = hypernet_dropdown_local

            # Spacer row, only shown on the img2img tab.
            with gr.Row(visible=is_img2img):
                gr.HTML('<br/>')

            with gr.Row():
                global inference_job_dropdown

                # Entries are "<time>--><taskType>--><status>--><InferenceJobId>"
                # strings produced by load_inference_job_list.
                inference_job_dropdown = gr.Dropdown(choices=[], value=None_Option_For_Infer_Job,
                                                     label="Inference Job Histories: Time-Type-Status-UUID")
                # Three buttons drive paging: reload first page, previous page (←), next page (→).
                create_refresh_button_by_user(inference_job_dropdown,
                                              lambda *args: None,
                                              lambda username: {
                                                  'choices': load_inference_job_list(inference_task_type,
                                                                                     username, "first")
                                              }, 'refresh_inference_job_down')

                create_refresh_button_by_user(inference_job_dropdown,
                                              lambda *args: None,
                                              lambda username: {
                                                  'choices': load_inference_job_list(inference_task_type,
                                                                                     username, "previous")
                                              }, 'refresh_inference_job_down_previous', '←')

                create_refresh_button_by_user(inference_job_dropdown,
                                              lambda *args: None,
                                              lambda username: {
                                                  'choices': load_inference_job_list(inference_task_type,
                                                                                     username, "next")
                                              }, 'refresh_inference_job_down_next', '→')

                # Trash-can button. The JS hook runs first (presumably a confirm
                # dialog — verify); its result is passed to delete_inference_job.
                # NOTE(review): outputs=[] so the dropdown is not refreshed after a delete.
                delete_inference_job_button = ToolButton(value='\U0001F5D1', elem_id="delete_inference_job")
                delete_inference_job_button.click(
                    _js="delete_inference_job_confirm",
                    fn=delete_inference_job,
                    inputs=[inference_job_dropdown],
                    outputs=[]
                )

                # Client-side only (fn=None). The hidden button's value carries the
                # server's working directory into the JS handler.
                api_inference_job_button = ToolButton(value='API', elem_id="api_inference_job")
                api_inference_job_cwd = ToolButton(value=os.getcwd(), elem_id="api_inference_job_path", visible=False)
                api_inference_job_button.click(
                    _js="download_inference_job_api_call",
                    fn=None,
                    inputs=[api_inference_job_cwd, inference_job_dropdown],
                    outputs=[]
                )

            with gr.Row():

                def setup_inference_for_plugin(pr: gr.Request):
                    """Populate all cloud dropdowns for the requesting user on page load.

                    Returns one value per entry in the ``outputs`` list of the
                    ``sagemaker_inference_tab.load`` call below, in the same order.
                    """
                    models_on_cloud = load_model_list(pr.username)
                    vae_model_on_cloud = load_vae_list(pr.username)
                    lora_models_on_cloud = load_lora_models(username=pr.username)
                    hypernetworks_models_on_cloud = load_hypernetworks_models(pr.username)
                    embedding_model_on_cloud = load_embeddings_list(pr.username)
                    controlnet_list = load_controlnet_list(pr.username)
                    controlnet_xyz_list = load_xyz_controlnet_list(pr.username)

                    inference_jobs = load_inference_job_list(inference_task_type, pr.username)
                    # Stored in lora_and_hypernet_models_state.
                    lora_hypernets = {
                        "lora": lora_models_on_cloud,
                        "hypernet": hypernetworks_models_on_cloud,
                        "controlnet": controlnet_list,
                        "controlnet_xyz": controlnet_xyz_list,
                        "vae": vae_model_on_cloud,
                        "sd": models_on_cloud,
                        "embedding": embedding_model_on_cloud
                    }

                    # The checkpoint dropdown defaults to the first listed choice,
                    # or the "don't use" option when the list is empty.
                    return lora_hypernets, \
                        gr.update(choices=models_on_cloud, value=models_on_cloud[0] if models_on_cloud and len(models_on_cloud) > 0 else None_Option_For_On_Cloud_Model), \
                        gr.update(choices=inference_jobs), \
                        gr.update(choices=vae_model_on_cloud), \
                        gr.update(choices=lora_models_on_cloud), \
                        gr.update(choices=hypernetworks_models_on_cloud), \
                        gr.update(choices=embedding_model_on_cloud)

                sagemaker_inference_tab.load(fn=setup_inference_for_plugin, inputs=[],
                                             outputs=[
                                                 lora_and_hypernet_models_state,
                                                 sd_model_on_cloud_dropdown,
                                                 inference_job_dropdown,
                                                 sd_vae_on_cloud_dropdown,
                                                 lora_dropdown_local,
                                                 hypernet_dropdown_local,
                                                 embedding_dropdown_local
                                             ])
        with gr.Group():
            # Checkpoint-merge controls; the accordion is created hidden (visible=False).
            with gr.Accordion("Open for Checkpoint Merge in the Cloud!", visible=False, open=False):
                # NOTE(review): reuses elem_id 'html_log_sagemaker' from the HTML
                # component above (duplicate DOM id) — confirm intended.
                sagemaker_html_log = gr.HTML(elem_id=f'html_log_sagemaker')
                with FormRow(elem_id="modelmerger_models_in_the_cloud"):
                    global primary_model_name
                    primary_model_name = gr.Dropdown(elem_id="modelmerger_primary_model_name_in_the_cloud",
                                                     label="Primary model (A) in the cloud")
                    create_refresh_button_by_user(primary_model_name, lambda *args: None,
                                                  lambda username: {"choices": sorted(update_sd_checkpoints(username))},
                                                  "refresh_checkpoint_A_in_the_cloud")

                    global secondary_model_name
                    secondary_model_name = gr.Dropdown(elem_id="modelmerger_secondary_model_name_in_the_cloud",
                                                       label="Secondary model (B) in the cloud")
                    create_refresh_button_by_user(secondary_model_name, lambda *args: None,
                                                  lambda username: {"choices": sorted(update_sd_checkpoints(username))},
                                                  "refresh_checkpoint_B_in_the_cloud")

                    global tertiary_model_name
                    tertiary_model_name = gr.Dropdown(elem_id="modelmerger_tertiary_model_name_in_the_cloud",
                                                      label="Tertiary model (C) in the cloud")
                    create_refresh_button_by_user(tertiary_model_name, lambda *args: None,
                                                  lambda username: {"choices": sorted(update_sd_checkpoints(username))},
                                                  "refresh_checkpoint_C_in_the_cloud")
                with gr.Row():
                    global modelmerger_merge_on_cloud
                    # No click handler is attached in this function; NOTE(review):
                    # presumably wired by the caller via the returned handle.
                    modelmerger_merge_on_cloud = gr.Button(elem_id="modelmerger_merge_in_the_cloud", value="Merge on Cloud",
                                                           variant='primary')

    return sd_model_on_cloud_dropdown, infer_endpoint_dropdown, sd_vae_on_cloud_dropdown, inference_job_dropdown, primary_model_name, secondary_model_name, tertiary_model_name, modelmerger_merge_on_cloud, lora_and_hypernet_models_state
|