stable-diffusion-aws-extension/middleware_api/lambda/inference/app.py


import time
import logging
import logging.config
import os
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.exception_handlers import http_exception_handler
from mangum import Mangum
from common.response_wrapper import resp_err
from common.enum import MessageEnum
from common.constant import const
from common.exception_handler import biz_exception
from fastapi_pagination import add_pagination
from datetime import datetime
from typing import List
import boto3
from botocore.client import Config
from botocore.exceptions import BotoCoreError, ClientError
import json
import uuid
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
from boto3.dynamodb.conditions import Attr, Key
logging.config.fileConfig('logging.conf', disable_existing_loggers=False)
logger = logging.getLogger(const.LOGGER_API)
STEP_FUNCTION_ARN = os.environ.get('STEP_FUNCTION_ARN')
DDB_INFERENCE_TABLE_NAME = os.environ.get('DDB_INFERENCE_TABLE_NAME')
DDB_TRAINING_TABLE_NAME = os.environ.get('DDB_TRAINING_TABLE_NAME')
DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME = os.environ.get('DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME')
REGION_NAME = os.environ['AWS_REGION']
S3_BUCKET_NAME = os.environ.get('S3_BUCKET')
ddb_client = boto3.resource('dynamodb')
s3 = boto3.client('s3', region_name=REGION_NAME)
sagemaker = boto3.client('sagemaker')
inference_table = ddb_client.Table(DDB_INFERENCE_TABLE_NAME)
endpoint_deployment_table = ddb_client.Table(DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME)
# SageMaker endpoint name used by the utility functions
utils_endpoint_name = os.environ.get("SAGEMAKER_ENDPOINT_NAME")
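# Custom handler so HTTPException responses carry CORS headers that browser clients can read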
async def custom_exception_handler(request: Request, exc: HTTPException):
headers = {
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "OPTIONS,POST,GET"
}
return JSONResponse(
status_code=exc.status_code,
content={"detail": exc.detail},
headers=headers
)
app = FastAPI(
title="API List of SageMaker Inference",
version="0.9",
)
app.exception_handler(HTTPException)(custom_exception_handler)
def get_uuid():
uuid_str = str(uuid.uuid4())
return uuid_str
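# Return every record in the inference job table (full table scan)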
def getInferenceJobList():
response = inference_table.scan()
logger.info(f"inference job list response is {str(response)}")
return response['Items']
def getInferenceJob(inference_job_id):
if not inference_job_id:
logger.error("Invalid inference job id")
raise ValueError("Inference job id must not be None or empty")
try:
resp = inference_table.query(
KeyConditionExpression=Key('InferenceJobId').eq(inference_job_id)
)
logger.info(resp)
record_list = resp['Items']
if len(record_list) == 0:
logger.error(f"No inference job info item for id: {inference_job_id}")
raise ValueError(f"There is no inference job info item for id: {inference_job_id}")
return record_list[0]
except Exception as e:
logger.error(f"Exception occurred when trying to query inference job with id: {inference_job_id}, exception is {str(e)}")
raise
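# List endpoint deployment jobs, dropping DynamoDB records whose SageMaker endpoint no longer exists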
def getEndpointDeploymentJobList():
try:
sagemaker = boto3.client('sagemaker')
ddb = boto3.resource('dynamodb')
endpoint_deployment_table = ddb.Table(DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME)
response = endpoint_deployment_table.scan()
logger.info(f"endpoint deployment job list response is {str(response)}")
# Get the list of SageMaker endpoints
list_results = sagemaker.list_endpoints()
sagemaker_endpoints = [ep_info['EndpointName'] for ep_info in list_results['Endpoints']]
logger.info(str(sagemaker_endpoints))
# Filter the endpoint job list
filtered_endpoint_jobs = []
for job in response['Items']:
if 'endpoint_name' in job:
endpoint_name = job['endpoint_name']
deployment_job_id = job['EndpointDeploymentJobId']
if endpoint_name in sagemaker_endpoints:
filtered_endpoint_jobs.append(job)
else:
# Remove the job item from the DynamoDB table if the endpoint doesn't exist in SageMaker
endpoint_deployment_table.delete_item(Key={'EndpointDeploymentJobId': deployment_job_id})
else:
filtered_endpoint_jobs.append(job)
return filtered_endpoint_jobs
    except ClientError as e:
        logger.error(f"An AWS client error occurred: {e}")
        return []
    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")
        return []
def getEndpointDeployJob(endpoint_deploy_job_id):
    try:
        resp = endpoint_deployment_table.query(
            KeyConditionExpression=Key('EndpointDeploymentJobId').eq(endpoint_deploy_job_id)
        )
        logger.info(resp)
    except Exception as e:
        # return early so we never reference resp when the query failed
        logger.error(e)
        return {}
    record_list = resp['Items']
    if len(record_list) == 0:
        logger.error(f"There is no endpoint deployment job info item for id: {endpoint_deploy_job_id}")
        return {}
    return record_list[0]
def getEndpointDeployJob_with_endpoint_name(endpoint_name):
    try:
        resp = endpoint_deployment_table.scan(
            FilterExpression=Attr('endpoint_name').eq(endpoint_name)
        )
        logger.info(resp)
    except Exception as e:
        # return early so we never reference resp when the scan failed
        logger.error(e)
        return {}
    record_list = resp['Items']
    if len(record_list) == 0:
        logger.error(f"There is no endpoint deployment job info item with endpoint name: {endpoint_name}")
        return {}
    return record_list[0]
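# List object keys under the given S3 prefix, returned relative to that prefix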
def get_s3_objects(bucket_name, folder_name):
# Ensure the folder name ends with a slash
if not folder_name.endswith('/'):
folder_name += '/'
# List objects in the specified bucket and folder
response = s3.list_objects_v2(Bucket=bucket_name, Prefix=folder_name)
# Extract object names from the response
object_names = [obj['Key'][len(folder_name):] for obj in response.get('Contents', []) if obj['Key'] != folder_name]
return object_names
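# Download a JSON object from S3 and parse it into a dict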
def load_json_from_s3(bucket_name, key):
# Get the JSON file from the specified bucket and key
response = s3.get_object(Bucket=bucket_name, Key=key)
json_file = response['Body'].read().decode('utf-8')
# Load the JSON file into a dictionary
data = json.loads(json_file)
return data
def get_param_value(params_dict, key, defaultValue="false"):
    # Note: the default is the string "false", which is truthy; callers that branch on the
    # returned value should pass an explicit default when that matters.
    try:
        param_value = params_dict[key]
    except KeyError:
        logger.info(f"could not find {key}, using default value {defaultValue}")
        param_value = defaultValue
    return param_value
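# Build the SageMaker inference payload from the ui-config parameters (config/aigc.json) and the
# checkpoint info sent by the client. Handles the optional txt2img scripts (Prompt matrix,
# Prompts from file or textbox, X/Y/Z plot) and the ControlNet arguments.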
def json_convert_to_payload(params_dict, checkpoint_info):
# Need to generate the payload from data_dict here:
script_name = get_param_value(params_dict, 'script_list', defaultValue="None")
if script_name == "None":
script_name = ""
script_args = []
if script_name == 'Prompt matrix':
put_at_start = get_param_value(params_dict, 'script_txt2txt_prompt_matrix_put_at_start')
different_seeds = get_param_value(params_dict, 'script_txt2txt_prompt_matrix_different_seeds')
if get_param_value(params_dict, 'script_txt2txt_prompt_matrix_prompt_type_positive', defaultValue="positive"):
prompt_type = "positive"
else:
prompt_type = "negative"
if get_param_value(params_dict, 'script_txt2txt_prompt_matrix_variations_delimiter_comma', defaultValue="comma"):
variations_delimiter = "comma"
else:
variations_delimiter = "space"
margin_size = int(get_param_value(params_dict, 'script_txt2txt_prompt_matrix_margin_size', defaultValue=0))
script_args = [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size]
if script_name == 'Prompts from file or textbox':
checkbox_iterate = get_param_value(params_dict, 'script_txt2txt_checkbox_iterate_every_line')
checkbox_iterate_batch = get_param_value(params_dict, 'script_txt2txt_checkbox_iterate_all_lines')
list_prompt_inputs = get_param_value(params_dict, 'script_txt2txt_prompts_from_file_or_textbox_prompt_txt', defaultValue="")
lines = [x.strip() for x in list_prompt_inputs.split("\n")]
script_args = [checkbox_iterate, checkbox_iterate_batch, "\n".join(lines)]
if script_name == 'X/Y/Z plot':
type_dict = {'Nothing': 0,
'Seed': 1,
'Var. seed': 2,
'Var. strength': 3,
'Steps': 4,
'Hires stteps': 5,
'CFG Scale': 6,
'Prompt S/R': 7,
'Prompt order': 8,
'Sampler': 9,
'Checkpoint name': 10,
'Negative Guidance minimum sigma': 11,
'Sigma Churn': 12,
'Sigma min': 13,
'Sigma max': 14,
'Sigma noise': 15,
'Eta': 16,
'Clip skip': 17,
'Denoising': 18,
'Hires upscaler': 19,
'VAE': 20,
'Styles': 21,
'UniPC Order': 22,
'Face restore': 23,
'[ControlNet] Enabled': 24,
'[ControlNet] Model': 25,
'[ControlNet] Weight': 26,
'[ControlNet] Guidance Start': 27,
'[ControlNet] Guidance End': 28,
'[ControlNet] Resize Mode': 29,
'[ControlNet] Preprocessor': 30,
'[ControlNet] Pre Resolution': 31,
'[ControlNet] Pre Threshold A': 32,
'[ControlNet] Pre Threshold B': 33}
dropdown_index = [9, 10, 19, 20, 21, 24, 25, 29, 30]
x_type = type_dict[get_param_value(params_dict, 'script_txt2txt_xyz_plot_x_type', defaultValue="Nothing")]
x_values = get_param_value(params_dict, 'script_txt2txt_xyz_plot_x_values', defaultValue="")
x_values_dropdown = get_param_value(params_dict, 'script_txt2txt_xyz_plot_x_values', defaultValue="")
if x_type in dropdown_index:
if x_type == 10:
x_values_dropdown = get_param_value(params_dict, 'sagemaker_stable_diffusion_checkpoint', defaultValue="None")
elif x_type == 25:
x_values_dropdown = get_param_value(params_dict, 'sagemaker_controlnet_model', defaultValue="None")
x_values_dropdown = x_values_dropdown.split(":")
y_type = type_dict[get_param_value(params_dict, 'script_txt2txt_xyz_plot_y_type', defaultValue="Nothing")]
y_values = get_param_value(params_dict, 'script_txt2txt_xyz_plot_y_values', defaultValue="")
y_values_dropdown = get_param_value(params_dict, 'script_txt2txt_xyz_plot_y_values', defaultValue="")
if y_type in dropdown_index:
if y_type == 10:
y_values_dropdown = get_param_value(params_dict, 'sagemaker_stable_diffusion_checkpoint', defaultValue="None")
elif y_type == 25:
y_values_dropdown = get_param_value(params_dict, 'sagemaker_controlnet_model', defaultValue="None")
y_values_dropdown = y_values_dropdown.split(":")
z_type = type_dict[get_param_value(params_dict, 'script_txt2txt_xyz_plot_z_type', defaultValue="Nothing")]
z_values = get_param_value(params_dict, 'script_txt2txt_xyz_plot_z_values', defaultValue="")
z_values_dropdown = get_param_value(params_dict, 'script_txt2txt_xyz_plot_z_values', defaultValue="")
if z_type in dropdown_index:
if z_type == 10:
z_values_dropdown = get_param_value(params_dict, 'sagemaker_stable_diffusion_checkpoint', defaultValue="None")
elif z_type == 25:
z_values_dropdown = get_param_value(params_dict, 'sagemaker_controlnet_model', defaultValue="None")
z_values_dropdown = z_values_dropdown.split(":")
draw_legend = get_param_value(params_dict, 'script_txt2txt_xyz_plot_draw_legend')
include_lone_images = get_param_value(params_dict, 'script_txt2txt_xyz_plot_include_lone_images')
include_sub_grids = get_param_value(params_dict, 'script_txt2txt_xyz_plot_include_sub_grids')
no_fixed_seeds = get_param_value(params_dict, 'script_txt2txt_xyz_plot_no_fixed_seeds')
margin_size = int(get_param_value(params_dict, 'script_txt2txt_xyz_plot_margin_size', defaultValue=0))
script_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
# get all parameters from ui-config.json
prompt = get_param_value(params_dict, 'txt2img_prompt', defaultValue="")
negative_prompt = get_param_value(params_dict, 'txt2img_neg_prompt', defaultValue="")
enable_hr = get_param_value(params_dict, 'txt2img_enable_hr')
denoising_strength = float(get_param_value(params_dict, 'txt2img_denoising_strength', defaultValue=0.7))
hr_scale = float(get_param_value(params_dict, 'txt2img_hr_scale', defaultValue=2.0))
hr_upscaler = get_param_value(params_dict, 'txt2img_hr_upscaler', defaultValue="Latent")
hr_second_pass_steps = int(get_param_value(params_dict, 'txt2img_hires_steps', defaultValue=0))
firstphase_width = int(get_param_value(params_dict, 'txt2img_hr_resize_x', defaultValue=0))
firstphase_height = int(get_param_value(params_dict, 'txt2img_hr_resize_y', defaultValue=0))
hr_resize_x = int(get_param_value(params_dict, 'txt2img_hr_resize_x', defaultValue=0))
hr_resize_y = int(get_param_value(params_dict, 'txt2img_hr_resize_y', defaultValue=0))
styles = get_param_value(params_dict, 'txt2img_styles', defaultValue=["None", "None"])
if styles == "":
styles = []
seed = float(get_param_value(params_dict, 'txt2img_seed', defaultValue=-1.0))
subseed = float(get_param_value(params_dict, 'txt2img_subseed', defaultValue=-1.0))
subseed_strength = float(get_param_value(params_dict, 'txt2img_subseed_strength', defaultValue=0))
seed_resize_from_h = int(get_param_value(params_dict, 'txt2img_seed_resize_from_h', defaultValue=0))
seed_resize_from_w = int(get_param_value(params_dict, 'txt2img_seed_resize_from_w', defaultValue=0))
sampler_index = get_param_value(params_dict, 'txt2img_sampling_method', defaultValue="Euler a")
batch_size = int(get_param_value(params_dict, 'txt2img_batch_size', defaultValue=1))
n_iter = int(get_param_value(params_dict, 'txt2img_batch_count', defaultValue=1))
steps = int(get_param_value(params_dict, 'txt2img_steps', defaultValue=20))
cfg_scale = float(get_param_value(params_dict, 'txt2img_cfg_scale', defaultValue=7))
width = int(get_param_value(params_dict, 'txt2img_width', defaultValue=512))
height = int(get_param_value(params_dict, 'txt2img_height', defaultValue=512))
restore_faces = get_param_value(params_dict, 'txt2img_restore_faces')
tiling = get_param_value(params_dict, 'txt2img_tiling')
override_settings = {}
eta = 1
s_churn = 0
s_tmax = 1
s_tmin = 0
s_noise = 1
selected_sd_model = get_param_value(params_dict, 'sagemaker_stable_diffusion_checkpoint', defaultValue="")
selected_cn_model = get_param_value(params_dict, 'sagemaker_controlnet_model', defaultValue="")
selected_hypernets = get_param_value(params_dict, 'sagemaker_hypernetwork_model', defaultValue="")
selected_loras = get_param_value(params_dict, 'sagemaker_lora_model', defaultValue="")
selected_embeddings = get_param_value(params_dict, 'sagemaker_texual_inversion_model', defaultValue="")
if selected_sd_model == "":
selected_sd_model = ['v1-5-pruned-emaonly.safetensors']
else:
selected_sd_model = selected_sd_model.split(":")
if selected_cn_model == "":
selected_cn_model = []
else:
selected_cn_model = selected_cn_model.split(":")
if selected_hypernets == "":
selected_hypernets = []
else:
selected_hypernets = selected_hypernets.split(":")
if selected_loras == "":
selected_loras = []
else:
selected_loras = selected_loras.split(":")
if selected_embeddings == "":
selected_embeddings = []
else:
selected_embeddings = selected_embeddings.split(":")
for embedding in selected_embeddings:
if embedding not in prompt:
prompt = prompt + embedding
for hypernet in selected_hypernets:
hypernet_name = os.path.splitext(hypernet)[0]
if hypernet_name not in prompt:
prompt = prompt + f"<hypernet:{hypernet_name}:1>"
for lora in selected_loras:
lora_name = os.path.splitext(lora)[0]
if lora_name not in prompt:
prompt = prompt + f"<lora:{lora_name}:1>"
    controlnet_enable = get_param_value(params_dict, 'controlnet_enable')
    if controlnet_enable:
controlnet_module = get_param_value(params_dict, 'controlnet_preprocessor', defaultValue=None)
if len(selected_cn_model) < 1:
controlnet_model = "None"
else:
controlnet_model = os.path.splitext(selected_cn_model[0])[0]
        controlnet_image = get_param_value(params_dict, 'txt2img_controlnet_ControlNet_input_image', defaultValue=None)
        if controlnet_image and ',' in controlnet_image:
            # strip the data-URI prefix and keep only the base64 image payload
            controlnet_image = controlnet_image.split(',')[1]
weight = float(get_param_value(params_dict, 'controlnet_weight', defaultValue=1)) #1,
        # default to "Crop and Resize" so resize_mode is always defined
        resize_mode = "Crop and Resize"
        if get_param_value(params_dict, 'controlnet_resize_mode_just_resize'):
            resize_mode = "Just Resize"
        if get_param_value(params_dict, 'controlnet_resize_mode_Crop_and_Resize'):
            resize_mode = "Crop and Resize"
        if get_param_value(params_dict, 'controlnet_resize_mode_Resize_and_Fill'):
            resize_mode = "Resize and Fill"
lowvram = get_param_value(params_dict, 'controlnet_lowVRAM_enable') #: "False",
processor_res = int(get_param_value(params_dict, 'controlnet_preprocessor_resolution', defaultValue=512))
threshold_a = float(get_param_value(params_dict, 'controlnet_canny_low_threshold', defaultValue=0))
threshold_b = float(get_param_value(params_dict, 'controlnet_canny_high_threshold', defaultValue=1))
guidance_start = float(get_param_value(params_dict, 'controlnet_starting_control_step', defaultValue=0)) #: 0,
guidance_end = float(get_param_value(params_dict, 'controlnet_ending_control_step', defaultValue=1)) #: 1,
        # default to "Balanced" so guessmode is always defined
        guessmode = "Balanced"
        if get_param_value(params_dict, 'controlnet_control_mode_balanced'):
            guessmode = "Balanced"
        if get_param_value(params_dict, 'controlnet_control_mode_my_prompt_is_more_important'):
            guessmode = "My prompt is more important"
        if get_param_value(params_dict, 'controlnet_control_mode_controlnet_is_more_important'):
            guessmode = "Controlnet is more important"
pixel_perfect = get_param_value(params_dict, 'controlnet_pixel_perfect')
allow_preview = get_param_value(params_dict, 'controlnet_allow_preview')
loopback = get_param_value(params_dict, 'controlnet_loopback_automatically')
endpoint_name = checkpoint_info['sagemaker_endpoint'] #"infer-endpoint-ca0e"
    if controlnet_enable:
        logger.info("txt2img with controlnet enabled")
payload = {
"endpoint_name": endpoint_name,
"task": "text-to-image",
"username": "test",
"checkpoint_info":checkpoint_info,
"models":{
"space_free_size": 4e10,
"Stable-diffusion": selected_sd_model,
"ControlNet": selected_cn_model,
"hypernetworks": selected_hypernets,
"Lora": selected_loras,
"embeddings": selected_embeddings
},
"txt2img_payload":{
"enable_hr": enable_hr,
"denoising_strength": denoising_strength,
"firstphase_width": firstphase_width,
"firstphase_height": firstphase_height,
"hr_scale": hr_scale,
"hr_upscaler": hr_upscaler,
"hr_second_pass_steps": hr_second_pass_steps,
"hr_resize_x": hr_resize_x,
"hr_resize_y": hr_resize_y,
"prompt": prompt,
"styles": styles,
"seed": seed,
"subseed": subseed,
"subseed_strength": subseed_strength,
"seed_resize_from_h": seed_resize_from_h,
"seed_resize_from_w": seed_resize_from_w,
"sampler_index": sampler_index,
"batch_size": batch_size,
"n_iter": n_iter,
"steps": steps,
"cfg_scale": cfg_scale,
"width": width,
"height": height,
"restore_faces": restore_faces,
"tiling": tiling,
"negative_prompt": negative_prompt,
"eta": eta,
"s_churn": s_churn,
"s_tmax": s_tmax,
"s_tmin": s_tmin,
"s_noise": s_noise,
"override_settings": override_settings,
"script_name": script_name,
"script_args": script_args,
"alwayson_scripts":{
"controlnet":{
"args":[
{
"input_image": controlnet_image,
"mask": "",
"module": controlnet_module,
"model": controlnet_model,
"loopback": loopback,
"weight": weight,
"resize_mode": resize_mode,
"lowvram": lowvram,
"processor_res": processor_res,
"threshold_a": threshold_a,
"threshold_b": threshold_b,
"guidance_start": guidance_start,
"guidance_end": guidance_end,
"guessmode": guessmode,
"pixel_perfect": pixel_perfect
}
]
}
}
}
}
    else:
        logger.info("txt2img without controlnet")
        # construct payload
payload = {
"endpoint_name": endpoint_name,
"task": "text-to-image",
"username": "test",
"checkpoint_info":checkpoint_info,
"models":{
"space_free_size": 2e10,
"Stable-diffusion": selected_sd_model,
"ControlNet": [],
"hypernetworks": selected_hypernets,
"Lora": selected_loras,
"embeddings": selected_embeddings
},
"txt2img_payload": {
"enable_hr": enable_hr,
"denoising_strength": denoising_strength,
"firstphase_width": firstphase_width,
"firstphase_height": firstphase_height,
"hr_scale": hr_scale,
"hr_upscaler": hr_upscaler,
"hr_second_pass_steps": hr_second_pass_steps,
"hr_resize_x": hr_resize_x,
"hr_resize_y": hr_resize_y,
"prompt": prompt,
"styles": styles,
"seed": seed,
"subseed": subseed,
"subseed_strength": subseed_strength,
"seed_resize_from_h": seed_resize_from_h,
"seed_resize_from_w": seed_resize_from_w,
"sampler_index": sampler_index,
"batch_size": batch_size,
"n_iter": n_iter,
"steps": steps,
"cfg_scale": cfg_scale,
"width": width,
"height": height,
"restore_faces": restore_faces,
"tiling": tiling,
"negative_prompt": negative_prompt,
"eta": eta,
"s_churn": s_churn,
"s_tmax": s_tmax,
"s_tmin": s_tmin,
"s_noise": s_noise,
"override_settings": override_settings,
"script_name": script_name,
"script_args": script_args},
}
return payload
# Step Functions client used to trigger the endpoint deployment state machine
stepf_client = boto3.client('stepfunctions')
@app.get("/inference")
def root():
return {"message": const.SOLUTION_NAME}
@app.post("/inference/run-sagemaker-inference")
async def run_sagemaker_inference(request: Request):
    # initialize these before the try block so the except handler can reference them safely
    inference_id = get_uuid()
    task_type = None
    try:
        logger.info('entering the run_sagemaker_inference function!')
        payload_checkpoint_info = await request.json()
        logger.info(f"input in json format {payload_checkpoint_info}")
        task_type = payload_checkpoint_info.get('task_type')
        logger.info(f"Task Type: {task_type}")
params_dict = load_json_from_s3(S3_BUCKET_NAME, 'config/aigc.json')
logger.info(json.dumps(params_dict))
payload = json_convert_to_payload(params_dict, payload_checkpoint_info)
print(f"input in json format {payload}")
endpoint_name = payload["endpoint_name"]
predictor = Predictor(endpoint_name)
        # set the async invocation timeout to 1 hour
        initial_args = {"InvocationTimeoutSeconds": 3600}
predictor = AsyncPredictor(predictor, name=endpoint_name)
predictor.serializer = JSONSerializer()
predictor.deserializer = JSONDeserializer()
prediction = predictor.predict_async(data=payload, initial_args=initial_args, inference_id=inference_id)
output_path = prediction.output_path
#put the item to inference DDB for later check status
current_time = str(datetime.now())
response = inference_table.put_item(
Item={
'InferenceJobId': inference_id,
'startTime': current_time,
'status': 'inprogress',
'taskType': task_type
})
print(f"output_path is {output_path}")
headers = {
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "OPTIONS,POST,GET"
}
response = JSONResponse(content={"inference_id": inference_id, "status": "inprogress", "endpoint_name": endpoint_name, "output_path": output_path}, headers=headers)
return response
except Exception as e:
logger.error(f"Error occurred: {str(e)}")
# raise HTTPException(status_code=500, detail=f"An error occurred during processing.{str(e)}")
headers = {
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "OPTIONS,POST,GET"
}
current_time = str(datetime.now())
response = inference_table.put_item(
Item={
'InferenceJobId': inference_id,
'startTime': current_time,
'completeTime': current_time,
'status': 'failure',
'taskType': task_type or "unknown",
'error': f"error info {str(e)}"}
)
response = JSONResponse(content={"inference_id": inference_id, "status":"failure", "error": f"error info {str(e)}"}, headers=headers)
return response
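# Start the endpoint deployment Step Functions state machine and record the deployment job in DynamoDB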
@app.post("/inference/deploy-sagemaker-endpoint")
async def deploy_sagemaker_endpoint(request: Request):
logger.info("entering the deploy_sagemaker_endpoint function!")
endpoint_deployment_id = get_uuid()
try:
payload = await request.json()
logger.info(f"input in json format {payload}")
payload['endpoint_deployment_id'] = endpoint_deployment_id
resp = stepf_client.start_execution(
stateMachineArn=STEP_FUNCTION_ARN,
input=json.dumps(payload)
)
#put the item to inference DDB for later check status
current_time = str(datetime.now())
response = endpoint_deployment_table.put_item(
Item={
'EndpointDeploymentJobId': endpoint_deployment_id,
'startTime': current_time,
'status': 'inprogress'
})
logger.info("trigger step-function with following response")
logger.info(f"finish trigger step function for deployment with output {resp}")
return 0
except Exception as e:
logger.error(f"error calling run-sagemaker-inference with exception: {e}")
#put the item to inference DDB for later check status
current_time = str(datetime.now())
response = endpoint_deployment_table.put_item(
Item={
'EndpointDeploymentJobId': endpoint_deployment_id,
'startTime': current_time,
'status': 'failed',
'completeTime': current_time,
'error': str(e)
})
return 0
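# Delete the given SageMaker endpoints and mark the matching DynamoDB deployment records as 'deleted'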
@app.post("/inference/delete-sagemaker-endpoint")
async def delete_sagemaker_endpoint(request: Request):
logger.info("entering the delete_sagemaker_endpoint function!")
try:
payload = await request.json()
delete_endpoint_list = payload.get('delete_endpoint_list', [])
logger.info(f"delete endpoint list: {delete_endpoint_list}")
# delete sagemaker endpoints and update DynamoDB in the same loop
for endpoint in delete_endpoint_list:
try:
# check if endpoint exists
try:
response = sagemaker.describe_endpoint(EndpointName=endpoint)
                    logger.info(response)
                    logger.info(f"Deleting endpoint: {endpoint}")
                    # The endpoint exists, so delete it
                    sagemaker.delete_endpoint(EndpointName=endpoint)
                except ClientError as error:
                    if error.response['Error']['Code'] in ('ResourceNotFound', 'ValidationException'):
                        logger.info("Endpoint not found, no need to delete.")
                    else:
                        # Log other client errors and continue with the DynamoDB update
                        logger.error(error)
                except BotoCoreError as error:
                    logger.error(error)
# update DynamoDB status
resp = getEndpointDeployJob(endpoint)
if resp:
endpoint_deployment_job_id = resp['EndpointDeploymentJobId']
logger.info(f"Updating DynamoDB status for: {endpoint_deployment_job_id}")
endpoint_deployment_table.update_item(
Key={
'EndpointDeploymentJobId': endpoint_deployment_job_id
},
UpdateExpression="SET #s = :s",
ExpressionAttributeNames={
'#s': 'status'
},
ExpressionAttributeValues={
':s': 'deleted'
}
)
else:
resp = getEndpointDeployJob_with_endpoint_name(endpoint)
if resp:
endpoint_deployment_job_id = resp['EndpointDeploymentJobId']
logger.info(f"Updating DynamoDB status for: {endpoint_deployment_job_id}")
endpoint_deployment_table.update_item(
Key={
'EndpointDeploymentJobId': endpoint_deployment_job_id
},
UpdateExpression="SET #s = :s",
ExpressionAttributeNames={
'#s': 'status'
},
ExpressionAttributeValues={
':s': 'deleted'
}
)
else:
logger.error(f"No matching DynamoDB record found for endpoint: {endpoint}")
except ClientError as e:
if e.response['Error']['Code'] in ['ValidationException', 'ResourceNotFoundException']:
logger.error(f"Endpoint or DynamoDB item {endpoint} does not exist, skipping")
else:
raise
logger.info("Successfully processed endpoint deletions and status updates")
return 0
except Exception as e:
logger.error(f"error deleting sagemaker endpoint with exception: {e}")
return 0
@app.get("/inference/list-endpoint-deployment-jobs")
async def list_endpoint_deployment_jobs():
logger.info(f"entering list_endpoint_deployment_jobs")
return getEndpointDeploymentJobList()
@app.get("/inference/list-inference-jobs")
async def list_inference_jobs():
logger.info(f"entering list_endpoint_deployment_jobs")
return getInferenceJobList()
@app.get("/inference/get-endpoint-deployment-job")
async def get_endpoint_deployment_job(jobID: str = None):
logger.info(f"entering get_endpoint_deployment_job function ")
# endpoint_deployment_jobId = request.query_params
endpoint_deployment_jobId = jobID
logger.info(f"endpoint_deployment_jobId is {str(endpoint_deployment_jobId)}")
return getEndpointDeployJob(endpoint_deployment_jobId)
@app.get("/inference/get-inference-job")
async def get_inference_job(jobID: str = None):
inference_jobId = jobID
logger.info(f"entering get_inference_job function with jobId: {inference_jobId}")
try:
return getInferenceJob(inference_jobId)
except Exception as e:
logger.error(f"Error getting inference job: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/inference/get-inference-job-image-output")
async def get_inference_job_image_output(jobID: str = None) -> List[str]:
inference_jobId = jobID
if inference_jobId is None or inference_jobId.strip() == "":
logger.info(f"jobId is empty string or None, just return empty string list")
return []
logger.info(f"Entering get_inference_job_image_output function with jobId: {inference_jobId}")
try:
job_record = getInferenceJob(inference_jobId)
except Exception as e:
logger.error(f"Error getting inference job: {str(e)}")
return []
# Assuming the job_record contains a list of image names
image_names = job_record.get("image_names", [])
presigned_urls = []
for image_name in image_names:
try:
image_key = f"out/{inference_jobId}/result/{image_name}"
presigned_url = generate_presigned_url(S3_BUCKET_NAME, image_key)
presigned_urls.append(presigned_url)
except Exception as e:
logger.error(f"Error generating presigned URL for image {image_name}: {str(e)}")
# Continue with the next image if this one fails
continue
return presigned_urls
@app.get("/inference/get-inference-job-param-output")
async def get_inference_job_param_output(jobID: str = None) -> List[str]:
inference_jobId = jobID
if inference_jobId is None or inference_jobId.strip() == "":
logger.info(f"jobId is empty string or None, just return empty string list")
return []
logger.info(f"Entering get_inference_job_param_output function with jobId: {inference_jobId}")
try:
job_record = getInferenceJob(inference_jobId)
except Exception as e:
logger.error(f"Error getting inference job: {str(e)}")
return []
presigned_url = ""
try:
json_key = f"out/{inference_jobId}/result/{inference_jobId}_param.json"
presigned_url = generate_presigned_url(S3_BUCKET_NAME, json_key)
except Exception as e:
logger.error(f"Error generating presigned URL: {str(e)}")
return []
return [presigned_url]
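# Generate a presigned GET URL for an S3 object (default expiry: one hour)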
def generate_presigned_url(bucket_name: str, key: str, expiration=3600) -> str:
try:
response = s3.generate_presigned_url(
'get_object',
Params={'Bucket': bucket_name, 'Key': key},
ExpiresIn=expiration
)
except Exception as e:
logger.error(f"Error generating presigned URL: {e}")
raise
return response
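# Issue a presigned PUT URL so clients can upload a file directly to S3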
@app.get("/inference/generate-s3-presigned-url-for-uploading")
async def generate_s3_presigned_url_for_uploading(s3_bucket_name: str = None, key: str = None):
if not s3_bucket_name:
s3_bucket_name = S3_BUCKET_NAME
if not key:
raise HTTPException(status_code=400, detail="Key parameter is required")
try:
presigned_url = s3.generate_presigned_url(
'put_object',
Params={
'Bucket': s3_bucket_name,
'Key': key,
'ContentType': 'text/plain;charset=UTF-8'
},
ExpiresIn=3600,
HttpMethod='PUT'
)
except Exception as e:
headers = {
"Access-Control-Allow-Headers": "*",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "OPTIONS,POST,GET,PUT"
}
return JSONResponse(content=str(e), status_code=500, headers=headers)
headers = {
"Access-Control-Allow-Headers": "*",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "OPTIONS,POST,GET,PUT"
}
response = JSONResponse(content=presigned_url, headers=headers)
return response
@app.get("/inference/get-texual-inversion-list")
async def get_texual_inversion_list():
logger.info(f"entering get_texual_inversion_list()")
return get_s3_objects(S3_BUCKET_NAME,'texual_inversion')
@app.get("/inference/get-lora-list")
async def get_lora_list():
return get_s3_objects(S3_BUCKET_NAME,'lora')
@app.get("/inference/get-hypernetwork-list")
async def get_hypernetwork_list():
return get_s3_objects(S3_BUCKET_NAME,'hypernetwork')
@app.get("/inference/get-controlnet-model-list")
async def get_controlnet_model_list():
return get_s3_objects(S3_BUCKET_NAME,'controlnet')
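# Run a model-merge request through the same async SageMaker invocation path as inference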
@app.post("/inference/run-model-merge")
async def run_model_merge(request: Request):
    # TODO: add logic for inference id
    merge_id = get_uuid()
    try:
        logger.info('entering the run_model_merge function!')
payload_checkpoint_info = await request.json()
print(f"!!!!!!!!!!input in json format {payload_checkpoint_info}")
params_dict = load_json_from_s3(S3_BUCKET_NAME, 'config/aigc.json')
logger.info(json.dumps(params_dict))
payload = json_convert_to_payload(params_dict, payload_checkpoint_info)
print(f"input in json format {payload}")
endpoint_name = payload["endpoint_name"]
predictor = Predictor(endpoint_name)
predictor = AsyncPredictor(predictor, name=endpoint_name)
predictor.serializer = JSONSerializer()
predictor.deserializer = JSONDeserializer()
        prediction = predictor.predict_async(data=payload, inference_id=merge_id)
output_path = prediction.output_path
#put the item to inference DDB for later check status
current_time = str(datetime.now())
response = inference_table.put_item(
Item={
                'InferenceJobId': merge_id,
'startTime': current_time,
'status': 'inprogress'
})
print(f"output_path is {output_path}")
headers = {
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "OPTIONS,POST,GET"
}
response = JSONResponse(content={"inference_id": inference_id, "status": "inprogress", "endpoint_name": endpoint_name, "output_path": output_path}, headers=headers)
#response = JSONResponse(content={"inference_id": '6fa743f0-cb7a-496f-8205-dbd67df08be2', "status": "succeed", "output_path": ""}, headers=headers)
return response
except Exception as e:
logger.error(f"Error occurred: {str(e)}")
# raise HTTPException(status_code=500, detail=f"An error occurred during processing.{str(e)}")
headers = {
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "OPTIONS,POST,GET"
}
response = JSONResponse(content={"inference_id": inference_id, "status":"failure", "error": f"error info {str(e)}"}, headers=headers)
return response
#app.include_router(search) TODO: adding sub router for future
handler = Mangum(app)
add_pagination(app)