stable-diffusion-aws-extension/aws_extension/sagemaker_ui.py

import logging
import os
import html
import boto3
import time
import json
import gradio
import requests
import base64
import gradio as gr
import utils
from aws_extension.auth_service.simple_cloud_auth import cloud_auth_manager
from aws_extension.cloud_api_manager.api_manager import api_manager
from aws_extension.sagemaker_ui_utils import create_refresh_button_by_user
from modules.shared import opts
from modules.ui_components import FormRow
from utils import get_variable_from_json, upload_multipart_files_to_s3_by_signed_url
from requests.exceptions import JSONDecodeError
from datetime import datetime
import math
import re
import asyncio
import nest_asyncio
from utils import cp, tar, rm
logger = logging.getLogger(__name__)
logger.setLevel(utils.LOGGING_LEVEL)
None_Option_For_On_Cloud_Model = "don't use on cloud inference"
inference_job_dropdown = None
textual_inversion_dropdown = None
hyperNetwork_dropdown = None
lora_dropdown = None
# sagemaker_endpoint = None
modelmerger_merge_on_cloud = None
interrogate_clip_on_cloud_button = None
interrogate_deep_booru_on_cloud_button = None
primary_model_name = None
secondary_model_name = None
tertiary_model_name = None
# TODO: convert to dynamically init the following variables
txt2img_inference_job_ids = []
sd_checkpoints = []
textual_inversion_list = []
lora_list = []
hyperNetwork_list = []
ControlNet_model_list = []
# Initial checkpoints information
checkpoint_info = {}
checkpoint_type = ["Stable-diffusion", "embeddings", "Lora", "hypernetworks", "ControlNet", "VAE"]
checkpoint_name = ["stable_diffusion", "embeddings", "lora", "hypernetworks", "controlnet", "VAE"]
stable_diffusion_list = []
embeddings_list = []
hypernetworks_list = []
controlnet_list = []
for ckpt_type, ckpt_name in zip(checkpoint_type, checkpoint_name):
checkpoint_info[ckpt_type] = {}
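# checkpoint_info maps each checkpoint type (e.g. 'Stable-diffusion', 'Lora') to a dict of {checkpoint_name: s3_location};
# it is populated by refresh_all_models() and get_model_list_by_type().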
# get api_gateway_url
api_gateway_url = get_variable_from_json('api_gateway_url')
api_key = get_variable_from_json('api_token')
start_time_picker_img_value = None
end_time_picker_img_value = None
start_time_picker_txt_value = None
end_time_picker_txt_value = None
txt_task_type = None
txt_status = None
txt_endpoint = None
txt_checkpoint = None
img_task_type = None
img_status = None
img_endpoint = None
img_checkpoint = None
show_all_inference_job = False
modelTypeMap = {
'SD Checkpoints': 'Stable-diffusion',
'Textual Inversion': 'embeddings',
'LoRA model': 'Lora',
'ControlNet model': 'ControlNet',
'Hypernetwork': 'hypernetworks',
'VAE': 'VAE'
}
def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
return text
def server_request_post(path, params):
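    """POST ``params`` as a JSON body to ``{api_gateway_url}{path}``, authenticating with the stored API token via the ``x-api-key`` header."""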
api_gateway_url = get_variable_from_json('api_gateway_url')
# Check if api_url ends with '/', if not append it
if not api_gateway_url.endswith('/'):
api_gateway_url += '/'
api_key = get_variable_from_json('api_token')
headers = {
"x-api-key": api_key,
"Content-Type": "application/json"
}
list_endpoint_url = f'{api_gateway_url}{path}'
response = requests.post(list_endpoint_url, json=params, headers=headers)
return response
def get_s3_file_names(bucket, folder):
"""Get a list of file names from an S3 bucket and folder."""
s3 = boto3.resource('s3')
bucket = s3.Bucket(bucket)
objects = bucket.objects.filter(Prefix=folder)
names = [obj.key for obj in objects]
return names
def get_current_date():
today = datetime.today()
formatted_date = today.strftime('%Y-%m-%d')
return formatted_date
def server_request(path):
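    """GET ``{api_gateway_url}{path}`` using the stored API token in the ``x-api-key`` header."""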
api_gateway_url = get_variable_from_json('api_gateway_url')
# Check if api_url ends with '/', if not append it
if not api_gateway_url.endswith('/'):
api_gateway_url += '/'
api_key = get_variable_from_json('api_token')
headers = {
"x-api-key": api_key,
"Content-Type": "application/json"
}
list_endpoint_url = f'{api_gateway_url}{path}'
response = requests.get(list_endpoint_url, headers=headers)
return response
def datetime_to_short_form(datetime_str):
dt = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S.%f")
short_form = dt.strftime("%Y-%m-%d-%H-%M-%S")
return short_form
def query_page_inference_job_list(task_type: str, status: str, endpoint: str, checkpoint: str, is_img2img: bool,
show_all: bool):
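    """Remember the current filter values for the txt2img or img2img tab in module-level globals, then delegate to query_inference_job_list."""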
global show_all_inference_job
show_all_inference_job = show_all
if is_img2img:
global img_task_type
img_task_type = task_type
global img_status
img_status = status
global img_endpoint
img_endpoint = endpoint
global img_checkpoint
img_checkpoint = checkpoint
opt_type = 'img2img'
logger.debug(f"{opt_type} {img_task_type} {img_status} {img_endpoint} {img_checkpoint} {show_all}")
return query_inference_job_list(task_type, status, endpoint, checkpoint, opt_type)
else:
global txt_task_type
txt_task_type = task_type
global txt_status
        txt_status = status
global txt_endpoint
txt_endpoint = endpoint
global txt_checkpoint
txt_checkpoint = checkpoint
opt_type = 'txt2img'
logger.debug(f"{opt_type} {txt_task_type} {txt_status} {txt_endpoint} {txt_checkpoint} {show_all}")
return query_inference_job_list(task_type, status, endpoint, checkpoint, opt_type)
def query_img_inference_job_list(task_type: str, status: str, endpoint: str, checkpoint: str):
opt_type = 'img2img'
global img_task_type
img_task_type = task_type
global img_status
img_status = status
global img_endpoint
img_endpoint = endpoint
global img_checkpoint
img_checkpoint = checkpoint
return query_inference_job_list(task_type, status, endpoint, checkpoint, opt_type)
def query_txt_inference_job_list(task_type: str, status: str, endpoint: str, checkpoint: str):
opt_type = 'txt2img'
global txt_task_type
txt_task_type = task_type
global txt_status
txt_status = status
global txt_endpoint
txt_endpoint = endpoint
global txt_checkpoint
txt_checkpoint = checkpoint
return query_inference_job_list(task_type, status, endpoint, checkpoint, opt_type)
def query_inference_job_list(task_type: str = '', status: str = '',
endpoint: str = '', checkpoint: str = '', opt_type: str = ''):
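    """Query inference jobs matching the given filters and return a Dropdown update whose choices are
    '{completeTime}-->{taskType}-->{status}-->{jobId}' strings sorted by completion time (at most 10 unless 'show all' is enabled)."""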
    logger.debug(
        f"query_inference_job_list start: {status}, {task_type}, {endpoint}, {checkpoint}, "
        f"{start_time_picker_txt_value}, {end_time_picker_txt_value}, {show_all_inference_job}")
try:
body_params = {}
if status:
body_params['status'] = status
if task_type:
body_params['task_type'] = task_type
if opt_type == 'txt2img':
if start_time_picker_txt_value:
body_params['start_time'] = start_time_picker_txt_value
if end_time_picker_txt_value:
body_params['end_time'] = end_time_picker_txt_value
elif opt_type == 'img2img':
if start_time_picker_img_value:
body_params['start_time'] = start_time_picker_img_value
if end_time_picker_img_value:
body_params['end_time'] = end_time_picker_img_value
if endpoint:
endpoint_name_array = endpoint.split("+")
if len(endpoint_name_array) > 0:
body_params['endpoint'] = endpoint_name_array[0]
if checkpoint:
body_params['checkpoint'] = checkpoint
body_params['limit'] = -1 if show_all_inference_job else 10
response = server_request_post(f'inference/query-inference-jobs', body_params)
r = response.json()
logger.debug(r)
if r:
txt2img_inference_job_ids.clear() # Clear the existing list before appending new values
temp_list = []
for obj in r:
if obj.get('completeTime') is None:
complete_time = obj.get('startTime')
else:
complete_time = obj.get('completeTime')
status = obj.get('status')
task_type = obj.get('taskType', 'txt2img')
inference_job_id = obj.get('InferenceJobId')
combined_string = f"{complete_time}-->{task_type}-->{status}-->{inference_job_id}"
temp_list.append((complete_time, combined_string))
# Sort the list based on completeTime in ascending order
sorted_list = sorted(temp_list, key=lambda x: x[0], reverse=False)
# Append the sorted combined strings to the txt2img_inference_job_ids list
for item in sorted_list:
txt2img_inference_job_ids.append(item[1])
# inference_job_dropdown.update(choices=txt2img_inference_job_ids)
return gr.Dropdown.update(choices=txt2img_inference_job_ids)
else:
logger.info("The API response is empty.")
return gr.Dropdown.update(choices=[])
except Exception as e:
logger.error("Exception occurred when fetching inference_job_ids")
logger.error(e)
return gr.Dropdown.update(choices=[])
def get_inference_job(inference_job_id):
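    """Fetch a single inference job record by its job ID."""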
response = server_request(f'inference/get-inference-job?jobID={inference_job_id}')
return response.json()
def get_inference_job_image_output(inference_job_id):
try:
response = server_request(f'inference/get-inference-job-image-output?jobID={inference_job_id}')
r = response.json()
txt2img_inference_job_image_list = []
for obj in r:
obj_value = str(obj)
txt2img_inference_job_image_list.append(obj_value)
return txt2img_inference_job_image_list
except Exception as e:
logger.error(f"An error occurred while getting inference job image output: {e}")
return []
def get_inference_job_param_output(inference_job_id):
try:
response = server_request(f'inference/get-inference-job-param-output?jobID={inference_job_id}')
r = response.json()
txt2img_inference_job_param_list = []
for obj in r:
obj_value = str(obj)
txt2img_inference_job_param_list.append(obj_value)
return txt2img_inference_job_param_list
except Exception as e:
logger.error(f"An error occurred while getting inference job param output: {e}")
return []
def download_images(image_urls: list, local_directory: str):
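    """Download every URL into ``local_directory`` and return the local file paths; failed downloads are logged and skipped."""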
if not os.path.exists(local_directory):
os.makedirs(local_directory)
image_list = []
for url in image_urls:
try:
response = requests.get(url)
response.raise_for_status()
image_name = os.path.basename(url).split('?')[0]
local_path = os.path.join(local_directory, image_name)
with open(local_path, 'wb') as f:
f.write(response.content)
image_list.append(local_path)
except requests.exceptions.RequestException as e:
logger.error(f"Error downloading image {url}: {e}")
return image_list
def download_images_to_json(image_urls: list):
results = []
for url in image_urls:
try:
response = requests.get(url)
response.raise_for_status()
json_resp = response.json()
results.append(json_resp['info'])
except requests.exceptions.RequestException as e:
logger.error(f"Error downloading image {url}: {e}")
return results
def download_images_to_pil(image_urls: list):
image_list = []
for url in image_urls:
try:
response = requests.get(url)
response.raise_for_status()
from PIL import Image
import io
pil_image = Image.open(io.BytesIO(response.content))
image_list.append(pil_image)
except requests.exceptions.RequestException as e:
logger.error(f"Error downloading image {url}: {e}")
return image_list
def get_model_list_by_type(model_type, username=""):
api_gateway_url = get_variable_from_json('api_gateway_url')
api_key = get_variable_from_json('api_token')
# check if api_gateway_url and api_key are set
if api_gateway_url is None or api_key is None:
logger.info("api_gateway_url or api_key is not set")
return []
# Check if api_url ends with '/', if not append it
if not api_gateway_url.endswith('/'):
api_gateway_url += '/'
url = api_gateway_url + f"checkpoints?status=Active"
if isinstance(model_type, list):
url += "&types=" + "&types=".join(model_type)
else:
url += f"&types={model_type}"
try:
encode_type = "utf-8"
response = requests.get(url=url, headers={
'x-api-key': api_key,
'Authorization': f'Bearer {base64.b16encode(username.encode(encode_type)).decode(encode_type)}'
})
response.raise_for_status()
json_response = response.json()
if "checkpoints" not in json_response.keys():
return []
checkpoint_list = []
for ckpt in json_response["checkpoints"]:
if "name" not in ckpt:
continue
if ckpt["name"] is None:
continue
if ckpt["type"] is None:
continue
ckpt_type = ckpt["type"]
for ckpt_name in ckpt["name"]:
ckpt_s3_pos = f"{ckpt['s3Location']}/{ckpt_name}"
checkpoint_info[ckpt_type][ckpt_name] = ckpt_s3_pos
checkpoint_list.append(ckpt_name)
unique_list = list(set(checkpoint_list))
return unique_list
except Exception as e:
logger.error(f"Error fetching model list: {e}")
return []
def get_checkpoints_by_type(model_type):
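    """Return ``[name, created_datetime]`` pairs for Active checkpoints of the given type(s), de-duplicated by name."""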
url = "checkpoints?status=Active"
if isinstance(model_type, list):
url += "&types=" + "&types=".join(model_type)
else:
url += f"&types={model_type}"
try:
response = server_request(url)
response.raise_for_status()
json_response = response.json()
if "checkpoints" not in json_response.keys():
return []
checkpoint_dict = {}
for ckpt in json_response["checkpoints"]:
if "name" not in ckpt:
continue
if ckpt["name"] is None:
continue
create_time = ckpt['created']
created = datetime.fromtimestamp(create_time)
for ckpt_name in ckpt["name"]:
checkpoint = [ckpt_name, created]
if ckpt_name not in checkpoint_dict:
checkpoint_dict[ckpt_name] = checkpoint
checkpoint_list = list(checkpoint_dict.values())
return checkpoint_list
except Exception as e:
logging.error(f"Error fetching checkpoints list: {e}")
return []
def update_sd_checkpoints(username):
model_type = ["Stable-diffusion"]
    return get_model_list_by_type(model_type, username)
def get_texual_inversion_list():
model_type = "embeddings"
return get_model_list_by_type(model_type)
def get_lora_list():
model_type = "Lora"
return get_model_list_by_type(model_type)
def get_hypernetwork_list():
model_type = "hypernetworks"
return get_model_list_by_type(model_type)
def get_controlnet_model_list():
model_type = "ControlNet"
return get_model_list_by_type(model_type)
def refresh_all_models(username):
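    """Rebuild ``checkpoint_info`` for every checkpoint type from the Active checkpoints visible to ``username``."""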
api_gateway_url = get_variable_from_json('api_gateway_url')
api_key = get_variable_from_json('api_token')
encode_type = "utf-8"
try:
for rp, name in zip(checkpoint_type, checkpoint_name):
url = api_gateway_url + f"checkpoints?status=Active&types={rp}"
response = requests.get(url=url, headers={
'x-api-key': api_key,
'Authorization': f'Bearer {base64.b16encode(username.encode(encode_type)).decode(encode_type)}',
})
json_response = response.json()
logger.debug(f"response url json for model {rp} is {json_response}")
checkpoint_info[rp] = {}
if "checkpoints" not in json_response.keys():
continue
for ckpt in json_response["checkpoints"]:
if "name" not in ckpt:
continue
if ckpt["name"] is None:
continue
ckpt_type = ckpt["type"]
checkpoint_info[ckpt_type] = {}
for ckpt_name in ckpt["name"]:
ckpt_s3_pos = f"{ckpt['s3Location']}/{ckpt_name.split(os.sep)[-1]}"
checkpoint_info[ckpt_type][ckpt_name] = ckpt_s3_pos
except Exception as e:
logger.error(f"Error refresh all models: {e}")
def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path,
controlnet_model_path, vae_path, pr: gradio.Request):
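    """Upload the non-empty local model paths to S3 via presigned multipart URLs, skipping names that already exist on the cloud.
    Each model is packed into a tar archive (plus the .yaml sidecar for Stable-diffusion checkpoints) and its checkpoint record is then marked Active."""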
log = "start upload model to s3:"
local_paths = [sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path, controlnet_model_path,
vae_path]
logger.info(f"Refresh checkpoints before upload to get rid of duplicate uploads...")
refresh_all_models(pr.username)
for lp, rp in zip(local_paths, checkpoint_type):
if lp == "" or not lp:
continue
logger.debug(f"lp is {lp}")
model_name = lp.split(os.sep)[-1]
exist_model_list = list(checkpoint_info[rp].keys())
if model_name in exist_model_list:
logger.info(f"!!!skip to upload duplicate model {model_name}")
continue
part_size = 1000 * 1024 * 1024
file_size = os.stat(lp)
parts_number = math.ceil(file_size.st_size / part_size)
        logger.info(f'File size is {file_size.st_size} bytes, uploading in {parts_number} part(s)')
# local_tar_path = f'{model_name}.tar'
local_tar_path = model_name
payload = {
"checkpoint_type": rp,
"filenames": [{
"filename": local_tar_path,
"parts_number": parts_number
}],
"params": {
"message": "placeholder for chkpts upload test",
"creator": pr.username
}
}
api_gateway_url = get_variable_from_json('api_gateway_url')
# Check if api_url ends with '/', if not append it
if not api_gateway_url.endswith('/'):
api_gateway_url += '/'
api_key = get_variable_from_json('api_token')
        logger.info(f'api_gateway_url: {api_gateway_url}')
url = str(api_gateway_url) + "checkpoints"
logger.debug(f"Post request for upload s3 presign url: {url}")
response = requests.post(url=url, json=payload, headers={'x-api-key': api_key})
try:
response.raise_for_status()
json_response = response.json()['data']
logger.debug(f"Response json {json_response}")
s3_base = json_response["checkpoint"]["s3_location"]
checkpoint_id = json_response["checkpoint"]["id"]
logger.debug(f"Upload to S3 {s3_base}")
logger.debug(f"Checkpoint ID: {checkpoint_id}")
s3_signed_urls_resp = json_response["s3PresignUrl"][local_tar_path]
# Upload src model to S3.
if rp != "embeddings":
local_model_path_in_repo = os.sep.join(['models', rp, model_name])
else:
local_model_path_in_repo = os.sep.join([rp, model_name])
logger.debug("Pack the model file.")
cp(lp, local_model_path_in_repo, recursive=True)
if rp == "Stable-diffusion":
model_yaml_name = model_name.split('.')[0] + ".yaml"
local_model_yaml_path = os.sep.join([*lp.split(os.sep)[:-1], model_yaml_name])
local_model_yaml_path_in_repo = os.sep.join(["models", rp, model_yaml_name])
if os.path.isfile(local_model_yaml_path):
cp(local_model_yaml_path, local_model_yaml_path_in_repo, recursive=True)
tar(mode='c', archive=local_tar_path,
sfiles=[local_model_path_in_repo, local_model_yaml_path_in_repo], verbose=True)
else:
tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo], verbose=True)
else:
tar(mode='c', archive=local_tar_path, sfiles=[local_model_path_in_repo], verbose=True)
multiparts_tags = upload_multipart_files_to_s3_by_signed_url(
local_tar_path,
s3_signed_urls_resp,
part_size
)
payload = {
"status": "Active",
"multi_parts_tags": {local_tar_path: multiparts_tags}
}
# Start creating model on cloud.
response = requests.put(url=f"{url}/{checkpoint_id}", json=payload, headers={'x-api-key': api_key})
s3_input_path = s3_base
logger.debug(response)
log = f"\n finish upload {local_tar_path} to {s3_base}"
# os.system(f"rm {local_tar_path}")
rm(local_tar_path, recursive=True)
except Exception as e:
logger.error(f"fail to upload model {lp}, error: {e}")
logger.debug(f"Refresh checkpoints after upload...")
refresh_all_models(pr.username)
return log, None, None, None, None, None, None
def sagemaker_upload_model_s3_local():
log = "Start upload:"
return log
def sagemaker_upload_model_s3_url(model_type: str, url_list: str, params: str, pr: gradio.Request):
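    """Ask the backend to fetch model files from a comma-separated URL list and register them as checkpoints of ``model_type``;
    ``params`` is an optional JSON string merged into the request parameters."""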
model_type = modelTypeMap.get(model_type)
if not model_type:
return "Please choose the model type."
url_pattern = r'(https?|ftp)://[^\s/$.?#].[^\s]*'
if re.match(f'^{url_pattern}$', url_list):
url_list = url_list.split(',')
else:
return "Please fill in right url list."
if params:
params_dict = json.loads(params)
else:
params_dict = {}
params_dict['creator'] = pr.username
body_params = {'checkpointType': model_type, 'modelUrl': url_list, 'params': params_dict}
response = server_request_post('upload_checkpoint', body_params)
response_data = response.json()
logging.info(f"sagemaker_upload_model_s3_url response:{response_data}")
log = "uploading……"
if 'checkpoint' in response_data:
if response_data['checkpoint'].get('status') == 'Active':
log = "upload success!"
return log
def generate_on_cloud(sagemaker_endpoint):
logger.info(f"checkpiont_info {checkpoint_info}")
logger.info(f"sagemaker endpoint {sagemaker_endpoint}")
text = "failed to check endpoint"
return plaintext_to_html(text)
# create a global event loop and apply the patch to allow nested event loops in single thread
loop = asyncio.get_event_loop()
nest_asyncio.apply()
MAX_RUNNING_LIMIT = 10
def async_loop_wrapper(f):
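    """Run the coroutine returned by ``f()`` on the shared event loop (nest_asyncio allows nesting); if the loop is busy,
    wait until fewer than MAX_RUNNING_LIMIT tasks are pending, and recreate the loop if it has been closed."""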
global loop
# check if there are any running or queued tasks inside the event loop
if loop.is_running():
# Calculate the number of running tasks
while len([task for task in asyncio.all_tasks(loop) if not task.done()]) > MAX_RUNNING_LIMIT:
logger.debug(f'Waiting for {MAX_RUNNING_LIMIT} running tasks to complete')
time.sleep(1)
else:
# check if loop is closed and create a new one
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# log this event since it should never happen
logger.debug('Event loop was closed, created a new one')
# Add new task to the event loop
result = loop.run_until_complete(f())
return result
def async_loop_wrapper_with_input(sagemaker_endpoint, type):
global loop
# check if there are any running or queued tasks inside the event loop
if loop.is_running():
# Calculate the number of running tasks
while len([task for task in asyncio.all_tasks(loop) if not task.done()]) > MAX_RUNNING_LIMIT:
logger.debug(f'Waiting for {MAX_RUNNING_LIMIT} running tasks to complete')
time.sleep(1)
else:
# check if loop is closed and create a new one
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# log this event since it should never happen
logger.debug('Event loop was closed, created a new one')
# Add new task to the event loop
result = loop.run_until_complete(call_remote_inference(sagemaker_endpoint, type))
return result
def call_interrogate_clip(sagemaker_endpoint, init_img, sketch, init_img_with_mask, inpaint_color_sketch,
init_img_inpaint, init_mask_inpaint):
return async_loop_wrapper_with_input(sagemaker_endpoint, 'interrogate_clip')
def call_interrogate_deepbooru(sagemaker_endpoint, init_img, sketch, init_img_with_mask, inpaint_color_sketch,
init_img_inpaint, init_mask_inpaint):
return async_loop_wrapper_with_input(sagemaker_endpoint, 'interrogate_deepbooru')
async def call_remote_inference(sagemaker_endpoint, type):
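    """Validate the selected endpoint (formatted 'name+status'), POST an inference request built from ``checkpoint_info``
    to the run-sagemaker-inference API, then poll for the result via process_result_by_inference_id."""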
logger.debug(f"chosen ep {sagemaker_endpoint}")
logger.debug(f"inference type is {type}")
if sagemaker_endpoint == '':
image_list = [] # Return an empty list if selected_value is None
info_text = ''
infotexts = "Failed! Please choose the endpoint in 'InService' states "
return image_list, info_text, plaintext_to_html(infotexts)
elif sagemaker_endpoint == 'FAILURE':
image_list = [] # Return an empty list if selected_value is None
info_text = ''
infotexts = "Failed upload the config to cloud "
return image_list, info_text, plaintext_to_html(infotexts)
sagemaker_endpoint_status = sagemaker_endpoint.split("+")[1]
if sagemaker_endpoint_status != "InService":
image_list = [] # Return an empty list if selected_value is None
info_text = ''
infotexts = "Failed! Please choose the endpoint in 'InService' states "
return image_list, info_text, plaintext_to_html(infotexts)
api_gateway_url = get_variable_from_json('api_gateway_url')
# Check if api_url ends with '/', if not append it
if not api_gateway_url.endswith('/'):
api_gateway_url += '/'
api_key = get_variable_from_json('api_token')
# stage 2: inference using endpoint_name
headers = {
"x-api-key": api_key,
"Content-Type": "application/json"
}
checkpoint_info['sagemaker_endpoint'] = sagemaker_endpoint.split("+")[0]
payload = checkpoint_info
payload['task_type'] = type
logger.info(f"checkpointinfo is {payload}")
inference_url = f"{api_gateway_url}inference/run-sagemaker-inference"
response = requests.post(inference_url, json=payload, headers=headers)
logger.info(f"Raw server response: {response.text}")
try:
r = response.json()
except JSONDecodeError as e:
logger.error(f"Failed to decode JSON response: {e}")
logger.error(f"Raw server response: {response.text}")
else:
inference_id = r.get('inference_id') # Assuming the response contains 'inference_id' field
try:
return process_result_by_inference_id(inference_id)
except Exception as e:
logger.error(f"Failed to get inference job {inference_id}, error: {e}")
def process_result_by_inference_id(inference_id):
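    """Poll get_inference_job until the job leaves 'inprogress', then return (image_list, info_text, html_info, infotexts)
    for txt2img/img2img jobs, or the generated caption as the last element for interrogate jobs."""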
image_list = [] # Return an empty list if selected_value is None
info_text = ''
infotexts = f"Inference id is {inference_id}, please check all historical inference result in 'Inference Job' dropdown list"
json_list = []
prompt_txt = ''
resp = get_inference_job(inference_id)
if resp is None:
logger.info(f"get_inference_job resp is null")
return image_list, info_text, plaintext_to_html(infotexts), infotexts
else:
logger.debug(f"get_inference_job resp is {resp}")
if 'taskType' not in resp:
raise Exception(resp)
if resp['taskType'] in ['txt2img', 'img2img', 'interrogate_clip', 'interrogate_deepbooru']:
while resp and resp['status'] == "inprogress":
time.sleep(3)
resp = get_inference_job(inference_id)
if resp is None:
logger.info(f"get_inference_job resp is null.")
return image_list, info_text, plaintext_to_html(infotexts), infotexts
if resp['status'] == "failed":
infotexts = f"Inference job {inference_id} is failed, error message: {resp['sagemakerRaw']}"
return image_list, info_text, plaintext_to_html(infotexts), infotexts
elif resp['status'] == "succeed":
if resp['taskType'] in ['interrogate_clip', 'interrogate_deepbooru']:
prompt_txt = resp['caption']
# return with default value, including image_list, info_text, infotexts
return image_list, info_text, plaintext_to_html(infotexts), prompt_txt
images = get_inference_job_image_output(inference_id.strip())
inference_param_json_list = get_inference_job_param_output(inference_id)
            # TODO: these are no longer needed
if resp['taskType'] in ['txt2img', 'img2img']:
image_list = download_images_to_pil(images)
json_file = download_images_to_json(inference_param_json_list)[0]
if json_file:
info_text = json_file
infotexts = f"Inference id is {inference_id}\n" + json.loads(info_text)["infotexts"][0]
else:
logger.debug(f"File {json_file} does not exist.")
info_text = 'something wrong when trying to download the inference parameters'
infotexts = info_text
return image_list, info_text, plaintext_to_html(infotexts), infotexts
else:
logger.debug(f"inference job status is {resp['status']}")
return image_list, info_text, plaintext_to_html(infotexts), infotexts
else:
return image_list, info_text, plaintext_to_html(infotexts), infotexts
def modelmerger_on_cloud_func(primary_model_name, secondary_model_name, tertiary_model_name):
logger.debug(f"function under development, current checkpoint_info is {checkpoint_info}")
api_gateway_url = get_variable_from_json('api_gateway_url')
# Check if api_url ends with '/', if not append it
if not api_gateway_url.endswith('/'):
api_gateway_url += '/'
api_key = get_variable_from_json('api_token')
if api_gateway_url is None:
logger.debug(f"modelmerger: failed to get the api-gateway url, can not fetch remote data")
return []
modelmerge_url = f"{api_gateway_url}inference/run-model-merge"
payload = {
"primary_model_name": primary_model_name,
"secondary_model_name": secondary_model_name,
"tertiary_model_name": teritary_model_name
}
headers = {
"x-api-key": api_key,
"Content-Type": "application/json"
}
response = requests.post(modelmerge_url, json=payload, headers=headers)
try:
r = response.json()
except JSONDecodeError as e:
logger.error(f"Failed to decode JSON response: {e}")
logger.error(f"Raw server response: {response.text}")
else:
logger.debug(f"response for rest api {r}")
# def txt2img_config_save():
# # placeholder for saving txt2img config
# pass
def displayEndpointInfo(input_string: str):
logger.debug(f"selected value is {input_string}")
if not input_string:
return
parts = input_string.split('+')
if len(parts) < 2:
return plaintext_to_html("")
endpoint_job_id, status = parts[0], parts[1]
if status == 'failed':
response = server_request(f'inference/get-endpoint-deployment-job?jobID={endpoint_job_id}')
# Do something with the response
r = response.json()
if "error" in r:
return plaintext_to_html(r["error"])
else:
return plaintext_to_html(r["EndpointDeploymentJobId"])
else:
return plaintext_to_html("")
def update_txt2imgPrompt_from_TextualInversion(selected_items, txt2img_prompt):
return update_txt2imgPrompt_from_model_select(selected_items, txt2img_prompt, 'embeddings', False)
def update_txt2imgPrompt_from_Hypernetworks(selected_items, txt2img_prompt):
return update_txt2imgPrompt_from_model_select(selected_items, txt2img_prompt, 'hypernetworks', True)
def update_txt2imgPrompt_from_Lora(selected_items, txt2img_prompt):
return update_txt2imgPrompt_from_model_select(selected_items, txt2img_prompt, 'Lora', True)
def update_txt2imgPrompt_from_model_select(selected_items, txt2img_prompt, model_name='embeddings',
with_angle_brackets=False):
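    """Strip previously injected tokens of this model type from the prompt, then append one token per selected item
    (e.g. '<lora:name:1>' when ``with_angle_brackets`` is True, or the bare name otherwise)."""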
logger.debug(selected_items) # example ['FastNegativeV2.pt']
logger.debug(txt2img_prompt)
logger.debug(get_model_list_by_type('embeddings'))
full_dropdown_items = get_model_list_by_type(model_name) # example ['FastNegativeV2.pt', 'okuryl3nko.pt']
# Remove extensions from selected_items and full_dropdown_items
selected_items = [item.split('.')[0] for item in selected_items]
full_dropdown_items = [item.split('.')[0] for item in full_dropdown_items]
# Loop over each item in full_dropdown_items and remove it from txt2img_prompt
type_str = ''
if model_name == 'Lora':
type_str = 'lora:'
elif model_name == 'hypernetworks':
type_str = 'hypernet:'
for item in full_dropdown_items:
if with_angle_brackets:
            txt2img_prompt = re.sub(rf'<{type_str}{re.escape(item)}:\d+>', "", txt2img_prompt).strip()
else:
txt2img_prompt = txt2img_prompt.replace(item, "").strip()
# Loop over each item in selected_items and append it to txt2img_prompt
for item in selected_items:
if with_angle_brackets:
txt2img_prompt += ' ' + '<' + type_str + item + ':1>'
else:
txt2img_prompt += ' ' + item
# Remove any leading or trailing whitespace
txt2img_prompt = txt2img_prompt.strip()
return txt2img_prompt
def fake_gan(selected_value, original_prompt):
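    """Given a selected inference-job entry ('time-->type-->status-->id'), load its images and parameters (txt2img/img2img)
    or its caption (interrogate jobs) and return (image_list, info_text, html_info, prompt_txt) for display."""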
logger.debug(f"selected value is {selected_value}")
logger.debug(f"original prompt is {original_prompt}")
if selected_value is not None:
delimiter = "-->"
parts = selected_value.split(delimiter)
# Extract the InferenceJobId value
inference_job_id = parts[3].strip()
inference_job_status = parts[2].strip()
inference_job_taskType = parts[1].strip()
if inference_job_status == 'inprogress':
return [], [], plaintext_to_html('inference still in progress')
if inference_job_taskType in ["txt2img", "img2img"]:
prompt_txt = original_prompt
# output directory mapping to task type
images = get_inference_job_image_output(inference_job_id.strip())
inference_param_json_list = get_inference_job_param_output(inference_job_id)
image_list = download_images_to_pil(images)
json_file = download_images_to_json(inference_param_json_list)[0]
if json_file:
info_text = json_file
infotexts = f"Inference id is {inference_job_id}\n" + json.loads(info_text)["infotexts"][0]
else:
logger.debug(f"File {json_file} does not exist.")
info_text = 'something wrong when trying to download the inference parameters'
infotexts = info_text
elif inference_job_taskType in ["interrogate_clip", "interrogate_deepbooru"]:
job_status = get_inference_job(inference_job_id)
logger.debug(job_status)
caption = job_status['caption']
prompt_txt = caption
image_list = [] # Return an empty list if selected_value is None
json_list = []
info_text = ''
infotexts = ''
else:
prompt_txt = original_prompt
image_list = [] # Return an empty list if selected_value is None
json_list = []
info_text = ''
infotexts = ''
return image_list, info_text, plaintext_to_html(infotexts), prompt_txt
def init_refresh_resource_list_from_cloud(username):
logger.debug(f"start refreshing resource list from cloud")
if get_variable_from_json('api_gateway_url') is not None:
if not cloud_auth_manager.enableAuth:
refresh_all_models(username)
else:
logger.debug('auth enabled, not preload load any model')
else:
logger.debug(f"there is no api-gateway url and token in local file,")
def on_txt_time_change(start_time, end_time):
logger.debug(f"!!!!!!!!!on_txt_time_change!!!!!!{start_time},{end_time}")
global start_time_picker_txt_value
global end_time_picker_txt_value
start_time_picker_txt_value = start_time
end_time_picker_txt_value = end_time
return query_inference_job_list(txt_task_type, txt_status, txt_endpoint, txt_checkpoint, "txt2img")
def on_img_time_change(start_time, end_time):
logger.debug(f"!!!!!!!!!on_img_time_change!!!!!!{start_time},{end_time}")
global start_time_picker_img_value
global end_time_picker_img_value
start_time_picker_img_value = start_time
end_time_picker_img_value = end_time
return query_inference_job_list(img_task_type, img_status, img_endpoint, img_checkpoint, "img2img")
def load_inference_job_list(target_task_type, username, usertoken):
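    """Return dropdown entries for the user's inference jobs of ``target_task_type``, sorted by completion time,
    with the 'don't use on cloud inference' sentinel as the first choice."""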
inference_jobs = [None_Option_For_On_Cloud_Model]
inferences_jobs_list = api_manager.list_all_inference_jobs_on_cloud(username, usertoken)
temp_list = []
for obj in inferences_jobs_list:
if obj.get('completeTime') is None:
complete_time = obj.get('startTime')
else:
complete_time = obj.get('completeTime')
status = obj.get('status')
task_type = obj.get('taskType', 'txt2img')
inference_job_id = obj.get('InferenceJobId')
# if filter_checkbox and task_type not in selected_types:
# continue
if target_task_type == task_type:
temp_list.append((complete_time, f"{complete_time}-->{task_type}-->{status}-->{inference_job_id}"))
# Sort the list based on completeTime in ascending order
sorted_list = sorted(temp_list, key=lambda x: x[0], reverse=False)
# Append the sorted combined strings to the txt2img_inference_job_ids list
for item in sorted_list:
inference_jobs.append(item[1])
return inference_jobs
def load_model_list(username, user_token):
models_on_cloud = []
if 'sd_model_checkpoint' in opts.quicksettings_list:
models_on_cloud = [None_Option_For_On_Cloud_Model]
models_on_cloud += list(set([model['name'] for model in api_manager.list_models_on_cloud(username, user_token)]))
return models_on_cloud
def load_lora_models(username, user_token):
return list(set([model['name'] for model in api_manager.list_models_on_cloud(username, user_token, types='Lora')]))
def load_hypernetworks_models(username, user_token):
return list(set([model['name'] for model in api_manager.list_models_on_cloud(username, user_token, types='hypernetworks')]))
def load_vae_list(username, user_token):
vae_model_on_cloud = ['Automatic', 'None']
vae_model_on_cloud += list(set([model['name'] for model in api_manager.list_models_on_cloud(username, user_token, types='VAE')]))
return vae_model_on_cloud
def load_controlnet_list(username, user_token):
controlnet_model_on_cloud = ['None']
controlnet_model_on_cloud += list(set([model['name'] for model in api_manager.list_models_on_cloud(username, user_token, types='ControlNet')]))
return controlnet_model_on_cloud
def load_xyz_controlnet_list(username, user_token):
controlnet_model_on_cloud = ['None']
controlnet_model_on_cloud += list(
set([os.path.splitext(model['name'])[0] for model in api_manager.list_models_on_cloud(username, user_token, types='ControlNet')]))
return controlnet_model_on_cloud
def load_embeddings_list(username, user_token):
    # embeddings_on_cloud = ['None']
    embeddings_on_cloud = list(set([model['name'] for model in api_manager.list_models_on_cloud(username, user_token, types='embeddings')]))
    return embeddings_on_cloud
def create_ui(is_img2img):
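    """Build the 'Amazon SageMaker Inference' panel for the txt2img or img2img tab and return the Gradio components
    (cloud model/VAE/inference-job dropdowns, model-merge controls and the lora/hypernet state) wired up by the caller."""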
global txt2img_gallery, txt2img_generation_info
import modules.ui
# init_refresh_resource_list_from_cloud()
with gr.Blocks() as sagemaker_inference_tab:
gr.HTML('<h3>Amazon SageMaker Inference</h3>')
sagemaker_html_log = gr.HTML(elem_id=f'html_log_sagemaker')
inference_task_type = 'txt2img' if not is_img2img else 'img2img'
with gr.Column():
with gr.Row():
lora_and_hypernet_models_state = gr.State({})
sd_model_on_cloud_dropdown = gr.Dropdown(choices=[], value=None_Option_For_On_Cloud_Model,
label='Stable Diffusion Checkpoint Used on Cloud')
create_refresh_button_by_user(sd_model_on_cloud_dropdown,
lambda *args: None,
lambda username: {
'choices': load_model_list(username, username)
}, 'refresh_cloud_model_down')
with gr.Row():
sd_vae_on_cloud_dropdown = gr.Dropdown(choices=[], value='Automatic',
label='SD Vae on Cloud')
create_refresh_button_by_user(sd_vae_on_cloud_dropdown,
lambda *args: None,
lambda username: {
'choices': load_vae_list(username, username)
}, 'refresh_cloud_vae_down')
with gr.Row(visible=is_img2img):
gr.HTML('<br/>')
# with gr.Row(visible=is_img2img):
# global generate_on_cloud_button_with_js
# # if not is_img2img:
# # generate_on_cloud_button_with_js = gr.Button(value="Generate on Cloud", variant='primary', elem_id="generate_on_cloud_with_cloud_config_button",queue=True, show_progress=True)
# global generate_on_cloud_button_with_js_img2img
# global interrogate_clip_on_cloud_button
# global interrogate_deep_booru_on_cloud_button
#
# interrogate_clip_on_cloud_button = gr.Button(value="Interrogate CLIP", variant='primary',
# elem_id="interrogate_clip_on_cloud_button", visible=False)
# interrogate_deep_booru_on_cloud_button = gr.Button(value="Interrogte DeepBooru", variant='primary',
# elem_id="interrogate_deep_booru_on_cloud_button", visible=False)
# with gr.Row():
# gr.HTML('<br/>')
with gr.Row():
global inference_job_dropdown
# global txt2img_inference_job_ids
inference_job_dropdown = gr.Dropdown(choices=[], value=None_Option_For_On_Cloud_Model,
label="Inference Job: Time-Type-Status-Uuid")
create_refresh_button_by_user(inference_job_dropdown,
lambda *args: None,
lambda username: {
'choices': load_inference_job_list(inference_task_type, username, username)
}, 'refresh_inference_job_down')
# inference_job_dropdown = gr.Dropdown(choices=txt2img_inference_job_ids,
# label="Inference Job: Time-Type-Status-Uuid",
# elem_id="txt2img_inference_job_ids_dropdown"
# )
# txt2img_inference_job_ids_refresh_button = create_refresh_button(inference_job_dropdown,
# query_inference_job_list,
# lambda: {
# "choices": txt2img_inference_job_ids,
# "value": None},
# "refresh_txt2img_inference_job_ids")
# fixme: inference filters need to be fixed
# with gr.Row():
# inference_job_filter = gr.Checkbox(
# label="Advanced Inference Job filter", value=False, visible=True
# )
# inference_job_page = gr.Checkbox(label="Show All(unchecked: max 10 items)",
# elem_id="inference_job_page_checkbox", value=False)
# with gr.Row(variant='panel', visible=False) as filter_row:
# with gr.Column(scale=1):
# gr.HTML(value="Inference Job type filters")
# with gr.Column(scale=2):
# with gr.Row():
# task_type_choices = ["txt2img", "img2img", "interrogate_clip", "interrogate_deepbooru"]
# task_type_dropdown = gr.Dropdown(label="Task Type", choices=task_type_choices,
# elem_id="task_type_ids_dropdown")
# status_choices = ["succeed", "inprogress", "failed"]
# status_dropdown = gr.Dropdown(label="Status", choices=status_choices,
# elem_id="task_status_dropdown")
# # with gr.Row():
# # sagemaker_endpoint_filter = gr.Dropdown(api_manager.list_all_sagemaker_endpoints(),
# # label="SageMaker Endpoint",
# # elem_id="sagemaker_endpoint_dropdown")
# # modules.ui.create_refresh_button(sagemaker_endpoint_filter, lambda: None,
# # lambda: {"choices": api_manager.list_all_sagemaker_endpoints()},
# # "refresh_sagemaker_endpoints")
#
# with gr.Row():
# sd_checkpoint_filter = gr.Dropdown(label="Checkpoint", choices=sorted(update_sd_checkpoints()),
# elem_id="stable_diffusion_checkpoint_dropdown")
# modules.ui.create_refresh_button(sd_checkpoint_filter, update_sd_checkpoints,
# lambda: {"choices": sorted(update_sd_checkpoints())},
# "refresh_sd_checkpoints")
# if is_img2img:
# with gr.Row():
# start_time_picker_img = gr.HTML(elem_id="start_timepicker_img_e",
# value="""
# <span class="svelte-1ed2p3z" style="color: #6B7280">
# Start Time
# <input type="date"
# lang="en"
# id="start_timepicker_img"
# min="2023-01-01"
# max="2033-12-31"
# class="wrap svelte-aqlk7e"
# style="color: #6B7280"
# onchange="inference_job_timepicker_img_change()" />
# </span>
# """)
# end_time_picker_img = gr.HTML(elem_id="end_timepicker_img_e",
# value="""
# <span class="svelte-1ed2p3z" style="color: #6B7280">
# End Time
# <input type="date"
# lang="en"
# id="end_timepicker_img"
# min="2023-01-01" max="2033-12-31"
# class="wrap svelte-aqlk7e"
# style="color: #6B7280"
# onchange="inference_job_timepicker_img_change()">
# </span>
# """)
# start_time_picker_img_hidden = gr.Button(elem_id="start_time_picker_img_hidden",
# visible=True)
# end_time_picker_img_hidden = gr.Button(elem_id="end_time_picker_img_hidden",
# visible=True)
# start_time_picker_img_hidden.click(fn=on_img_time_change,
# _js='get_time_img_value',
# inputs=[start_time_picker_img, end_time_picker_img],
# outputs=inference_job_dropdown)
# end_time_picker_img_hidden.click(fn=on_img_time_change,
# _js='get_time_img_value',
# inputs=[start_time_picker_img, end_time_picker_img],
# outputs=inference_job_dropdown
# )
# task_type_dropdown.change(fn=query_img_inference_job_list,
# inputs=[task_type_dropdown, status_dropdown,
# sagemaker_endpoint_filter,
# sd_checkpoint_filter], outputs=inference_job_dropdown)
# status_dropdown.change(fn=query_img_inference_job_list,
# inputs=[task_type_dropdown, status_dropdown, sagemaker_endpoint_filter,
# sd_checkpoint_filter], outputs=inference_job_dropdown)
#
# sagemaker_endpoint_filter.change(fn=query_img_inference_job_list,
# inputs=[task_type_dropdown, status_dropdown,
# sagemaker_endpoint_filter,
# sd_checkpoint_filter], outputs=inference_job_dropdown)
#
# sd_checkpoint_filter.change(fn=query_img_inference_job_list,
# inputs=[task_type_dropdown, status_dropdown,
# sagemaker_endpoint_filter,
# sd_checkpoint_filter], outputs=inference_job_dropdown)
# else:
# with gr.Row():
# start_time_picker_text = gr.HTML(elem_id="start_timepicker_text_e",
# value="""
# <span class="svelte-1ed2p3z" style="color: #6B7280">
# Start Time
# <input type="date" lang="en" id="start_timepicker_text"
# min="2023-01-01" max="2033-12-31"
# class="wrap svelte-aqlk7e" style="color: #6B7280"
# onchange="inference_job_timepicker_text_change()">
# </span>
# """)
# end_time_picker_text = gr.HTML(elem_id="end_timepicker_text_e",
# value="""
# <span class="svelte-1ed2p3z" style="color: #6B7280">
# End Time
# <input type="date" lang="en" id="end_timepicker_text"
# min="2023-01-01" max="2033-12-31"
# class="wrap svelte-aqlk7e"
# style="color: #6B7280"
# onchange="inference_job_timepicker_text_change()">
# </span>
# """)
# start_time_picker_button_hidden = gr.Button(elem_id="start_time_picker_button_hidden",
# visible=False)
# end_time_picker_button_hidden = gr.Button(elem_id="end_time_picker_button_hidden",
# visible=False)
# start_time_picker_button_hidden.click(fn=on_txt_time_change,
# _js='get_time_button_value',
# inputs=[start_time_picker_text, end_time_picker_text],
# outputs=inference_job_dropdown)
# end_time_picker_button_hidden.click(fn=on_txt_time_change,
# _js='get_time_button_value',
# inputs=[start_time_picker_text, end_time_picker_text],
# outputs=inference_job_dropdown)
# task_type_dropdown.change(fn=query_txt_inference_job_list,
# inputs=[task_type_dropdown, status_dropdown,
# sagemaker_endpoint_filter, sd_checkpoint_filter],
# outputs=inference_job_dropdown)
# status_dropdown.change(fn=query_txt_inference_job_list,
# inputs=[task_type_dropdown, status_dropdown,
# sagemaker_endpoint_filter,
# sd_checkpoint_filter], outputs=inference_job_dropdown)
# sagemaker_endpoint_filter.change(fn=query_txt_inference_job_list,
# inputs=[task_type_dropdown, status_dropdown,
# sagemaker_endpoint_filter,
# sd_checkpoint_filter], outputs=inference_job_dropdown)
# sd_checkpoint_filter.change(fn=query_txt_inference_job_list,
# inputs=[task_type_dropdown, status_dropdown,
# sagemaker_endpoint_filter, sd_checkpoint_filter],
# outputs=inference_job_dropdown)
#
# def toggle_new_rows(create_from):
# global start_time_picker_txt_value
# start_time_picker_txt_value = None
# global end_time_picker_txt_value
# end_time_picker_txt_value = None
# return [gr.update(visible=create_from), None, None, None, None]
#
# inference_job_filter.change(
# fn=toggle_new_rows,
# inputs=[inference_job_filter],
# outputs=[filter_row, task_type_dropdown, status_dropdown, sagemaker_endpoint_filter,
# sd_checkpoint_filter],
# )
# hidden_check_type = gr.Textbox(elem_id="hidden_check_type", value=is_img2img, visible=False)
# inference_job_page.change(fn=query_page_inference_job_list,
# inputs=[task_type_dropdown, status_dropdown,
# sagemaker_endpoint_filter,
# sd_checkpoint_filter, hidden_check_type,
# inference_job_page],
#
# outputs=inference_job_dropdown)
def setup_inference_for_plugin(pr: gr.Request):
models_on_cloud = load_model_list(pr.username, pr.username)
vae_model_on_cloud = load_vae_list(pr.username, pr.username)
lora_models_on_cloud = load_lora_models(username=pr.username, user_token=pr.username)
hypernetworks_models_on_cloud = load_hypernetworks_models(pr.username, pr.username)
controlnet_list = load_controlnet_list(pr.username, pr.username)
controlnet_xyz_list = load_xyz_controlnet_list(pr.username, pr.username)
inference_jobs = load_inference_job_list(inference_task_type, pr.username, pr.username)
lora_hypernets = {
'lora': lora_models_on_cloud,
'hypernet': hypernetworks_models_on_cloud,
'controlnet': controlnet_list,
'controlnet_xyz': controlnet_xyz_list,
'vae': vae_model_on_cloud,
'sd': models_on_cloud,
}
return lora_hypernets, \
gr.update(choices=models_on_cloud, value=models_on_cloud[0] if models_on_cloud and len(models_on_cloud) > 0 else None_Option_For_On_Cloud_Model), \
gr.update(choices=inference_jobs), \
gr.update(choices=vae_model_on_cloud)
sagemaker_inference_tab.load(fn=setup_inference_for_plugin, inputs=[],
outputs=[
lora_and_hypernet_models_state,
sd_model_on_cloud_dropdown,
inference_job_dropdown,
sd_vae_on_cloud_dropdown
])
with gr.Group():
with gr.Accordion("Open for Checkpoint Merge in the Cloud!", visible=False, open=False):
sagemaker_html_log = gr.HTML(elem_id=f'html_log_sagemaker')
with FormRow(elem_id="modelmerger_models_in_the_cloud"):
global primary_model_name
primary_model_name = gr.Dropdown(elem_id="modelmerger_primary_model_name_in_the_cloud",
label="Primary model (A) in the cloud")
create_refresh_button_by_user(primary_model_name, lambda *args: None,
lambda username: {"choices": sorted(update_sd_checkpoints(username))},
"refresh_checkpoint_A_in_the_cloud")
global secondary_model_name
secondary_model_name = gr.Dropdown(elem_id="modelmerger_secondary_model_name_in_the_cloud",
label="Secondary model (B) in the cloud")
create_refresh_button_by_user(secondary_model_name, lambda *args: None,
lambda username: {"choices": sorted(update_sd_checkpoints(username))},
"refresh_checkpoint_B_in_the_cloud")
global tertiary_model_name
tertiary_model_name = gr.Dropdown(elem_id="modelmerger_tertiary_model_name_in_the_cloud",
label="Tertiary model (C) in the cloud")
create_refresh_button_by_user(tertiary_model_name, lambda *args: None,
lambda username: {"choices": sorted(update_sd_checkpoints(username))},
"refresh_checkpoint_C_in_the_cloud")
with gr.Row():
global modelmerger_merge_on_cloud
modelmerger_merge_on_cloud = gr.Button(elem_id="modelmerger_merge_in_the_cloud", value="Merge on Cloud",
variant='primary')
return sd_model_on_cloud_dropdown, sd_vae_on_cloud_dropdown, inference_job_dropdown, primary_model_name, secondary_model_name, tertiary_model_name, modelmerger_merge_on_cloud, lora_and_hypernet_models_state