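"""FastAPI service for SageMaker inference: submit async inference and
model-merge jobs, query job status, list endpoint deployments, and issue S3
presigned URLs. Deployed on AWS Lambda through the Mangum adapter."""
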
import json
import logging.config
import os
import traceback
import uuid
from datetime import datetime

import boto3
from boto3.dynamodb.conditions import Attr, Key
from botocore.exceptions import ClientError
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi_pagination import add_pagination
from mangum import Mangum
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer

from common.constant import const
from parse.parameter_parser import json_convert_to_payload

logging.config.fileConfig('logging.conf', disable_existing_loggers=False)
logger = logging.getLogger(const.LOGGER_API)

INFERENCE_JOB_TABLE = os.environ.get('INFERENCE_JOB_TABLE')
DDB_TRAINING_TABLE_NAME = os.environ.get('DDB_TRAINING_TABLE_NAME')
DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME = os.environ.get('DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME')
REGION_NAME = os.environ['AWS_REGION']
S3_BUCKET_NAME = os.environ.get('S3_BUCKET')

ddb_client = boto3.resource('dynamodb')
s3 = boto3.client('s3', region_name=REGION_NAME)
sagemaker = boto3.client('sagemaker')
inference_table = ddb_client.Table(INFERENCE_JOB_TABLE)
endpoint_deployment_table = ddb_client.Table(DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME)


async def custom_exception_handler(request: Request, exc: HTTPException):
    """Serialize HTTPExceptions as JSON responses with permissive CORS headers."""
    headers = {
        "Access-Control-Allow-Headers": "Content-Type",
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "OPTIONS,POST,GET"
    }
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        headers=headers
    )


app = FastAPI(
    title="API List of SageMaker Inference",
    version="0.9",
)
app.exception_handler(HTTPException)(custom_exception_handler)


def get_uuid():
    """Generate a random UUID4 string used as a job id."""
    return str(uuid.uuid4())


def getInferenceJobList():
    """Return every item in the inference job table via a full table scan."""
    response = inference_table.scan()
    logger.info(f"inference job list response is {str(response)}")
    return response['Items']


def build_filter_expression(end_time, endpoint, start_time, status, task_type):
    """Combine the optional query filters into a single DynamoDB FilterExpression.

    Returns None when no filter was supplied.
    """
    conditions = []
    if status:
        conditions.append(Attr('status').eq(status))
    if task_type:
        conditions.append(Attr('taskType').eq(task_type))
    if start_time:
        conditions.append(Attr('startTime').gte(start_time))
    if end_time:
        conditions.append(Attr('startTime').lte(end_time))
    if endpoint:
        conditions.append(Attr('params.sagemaker_inference_endpoint_name').eq(endpoint))

    filter_expression = None
    for condition in conditions:
        filter_expression = condition if filter_expression is None else filter_expression & condition
    return filter_expression
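
# For example (hypothetical values), status='succeed' together with
# task_type='txt2img' yields the combined expression
# Attr('status').eq('succeed') & Attr('taskType').eq('txt2img'),
# which DynamoDB evaluates against each scanned item.

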
def query_inference_job_list(status: str, task_type: str, start_time: str, end_time: str,
                             endpoint: str, checkpoint: str, limit: int):
    """Scan the inference job table with the given filters, then sort and truncate.

    Returns an empty string on invalid input or error, for backward
    compatibility with existing callers.
    """
    logger.info(f"query_inference_job_list params are: {status},{task_type},{start_time},{end_time},{checkpoint},{endpoint}")
    try:
        if limit != const.PAGE_LIMIT_ALL and limit <= 0:
            logger.error(f"query inference job list error because limit <= 0: {limit}")
            return ""
        filter_expression = build_filter_expression(end_time, endpoint, start_time, status, task_type)
        if filter_expression:
            # Note: scan returns a single page (up to 1 MB); larger tables
            # would need LastEvaluatedKey pagination.
            response = inference_table.scan(
                FilterExpression=filter_expression
            )
        else:
            response = inference_table.scan()
        logger.info(f"query inference job list response is {str(response)}")
        if response:
            return filter_checkpoint_items(limit, checkpoint, response['Items'])
        return response
    except Exception as e:
        logger.error("query inference job list error")
        logger.error(e)
        return ""


def sort_by_start_time(item):
    """Sort key: the item's startTime, or the empty string when absent."""
    return item.get("startTime", "")


def filter_checkpoint_items(limit, checkpoint, items):
    """Sort items by startTime (newest first), optionally keep only jobs that
    used the given checkpoint, and truncate to the requested limit."""
    items = sorted(items, key=sort_by_start_time, reverse=True)
    if checkpoint:
        filtered_data = []
        for item in items:
            if "params" in item and "used_models" in item["params"]:
                used_models = item["params"]["used_models"].get("Stable-diffusion", [])
                for model in used_models:
                    if model.get("model_name") == checkpoint:
                        filtered_data.append(item)
                        break  # avoid appending the same item more than once
        items = filtered_data
    if limit == const.PAGE_LIMIT_ALL:
        return items
    return items[:limit]


def getInferenceJob(inference_job_id):
    """Fetch a single inference job record by id; raise if it does not exist."""
    if not inference_job_id:
        logger.error("Invalid inference job id")
        raise ValueError("Inference job id must not be None or empty")

    try:
        resp = inference_table.query(
            KeyConditionExpression=Key('InferenceJobId').eq(inference_job_id)
        )
        record_list = resp['Items']
        if len(record_list) == 0:
            logger.error(f"No inference job info item for id: {inference_job_id}")
            raise ValueError(f"There is no inference job info item for id: {inference_job_id}")
        return record_list[0]
    except Exception as e:
        logger.error(
            f"Exception occurred when trying to query inference job with id: {inference_job_id}, exception is {str(e)}")
        raise


def getEndpointDeploymentJobList():
    """List endpoint deployment records, pruning rows whose SageMaker endpoint no longer exists."""
    try:
        response = endpoint_deployment_table.scan()
        logger.info(f"endpoint deployment job list response is {str(response)}")

        # Get the list of SageMaker endpoints (first page only; further pages
        # would require NextToken pagination)
        list_results = sagemaker.list_endpoints()
        sagemaker_endpoints = [ep_info['EndpointName'] for ep_info in list_results['Endpoints']]
        logger.info(str(sagemaker_endpoints))

        # Filter the endpoint job list
        filtered_endpoint_jobs = []
        for job in response['Items']:
            if 'endpoint_name' in job:
                endpoint_name = job['endpoint_name']
                deployment_job_id = job['EndpointDeploymentJobId']

                if endpoint_name in sagemaker_endpoints:
                    filtered_endpoint_jobs.append(job)
                else:
                    # Remove the job item from the DynamoDB table if the endpoint doesn't exist in SageMaker
                    endpoint_deployment_table.delete_item(Key={'EndpointDeploymentJobId': deployment_job_id})
            else:
                filtered_endpoint_jobs.append(job)

        return filtered_endpoint_jobs

    except ClientError as e:
        logger.error(f"An error occurred: {e}")
        return []

    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")
        return []


def getEndpointDeployJob(endpoint_deploy_job_id):
    """Fetch one endpoint deployment record by job id; return {} when absent."""
    try:
        resp = endpoint_deployment_table.query(
            KeyConditionExpression=Key('EndpointDeploymentJobId').eq(endpoint_deploy_job_id)
        )
        logger.info(resp)
    except Exception as e:
        # Return early so resp is never read while unbound.
        logger.error(e)
        return {}

    record_list = resp['Items']
    if len(record_list) == 0:
        logger.error("There is no endpoint deployment job info item for id:" + endpoint_deploy_job_id)
        return {}
    return record_list[0]


def getEndpointDeployJob_with_endpoint_name(endpoint_name):
    """Fetch one endpoint deployment record by endpoint name; return {} when absent."""
    try:
        resp = endpoint_deployment_table.scan(
            FilterExpression=Attr('endpoint_name').eq(endpoint_name)
        )
        logger.info(resp)
    except Exception as e:
        # Return early so resp is never read while unbound.
        logger.error(e)
        return {}

    record_list = resp['Items']
    if len(record_list) == 0:
        logger.error("There is no endpoint deployment job info item with endpoint name:" + endpoint_name)
        return {}
    return record_list[0]


def get_s3_objects(bucket_name, folder_name):
    """List object names under the given folder, with the folder prefix stripped."""
    # Ensure the folder name ends with a slash
    if not folder_name.endswith('/'):
        folder_name += '/'

    # List objects in the specified bucket and folder (first page only;
    # list_objects_v2 returns at most 1000 keys per call)
    response = s3.list_objects_v2(Bucket=bucket_name, Prefix=folder_name)

    # Extract object names from the response, skipping the folder placeholder itself
    object_names = [obj['Key'][len(folder_name):] for obj in response.get('Contents', []) if obj['Key'] != folder_name]

    return object_names


def load_json_from_s3(bucket_name, key):
    """Download a JSON object from S3 and parse it into a dictionary."""
    # Get the JSON file from the specified bucket and key
    response = s3.get_object(Bucket=bucket_name, Key=key)
    json_file = response['Body'].read().decode('utf-8')

    # Load the JSON file into a dictionary
    data = json.loads(json_file)

    return data


stepf_client = boto3.client('stepfunctions')
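

# The request body is expected to carry at least 'task_type' ('txt2img' or
# 'img2img'); the remaining fields are merged into the payload template by
# json_convert_to_payload (see parse.parameter_parser for the full schema).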
@app.post("/inference/run-sagemaker-inference")
|
|
@app.post("/inference-api/inference")
|
|
async def run_sagemaker_inference(request: Request):
|
|
try:
|
|
logger.info('entering the run_sage_maker_inference function!')
|
|
|
|
inference_id = get_uuid()
|
|
|
|
payload_checkpoint_info = await request.json()
|
|
print(f"!!!!!!!!!!input in json format {payload_checkpoint_info}")
|
|
task_type = payload_checkpoint_info.get('task_type')
|
|
print(f"Task Type: {task_type}")
|
|
path = request.url.path
|
|
logger.info(f'Path: {path}')
|
|
if path == '/inference-api/inference':
|
|
# Invoke by API
|
|
logger.info('invoked by api')
|
|
params_dict = load_json_from_s3(S3_BUCKET_NAME, 'template/inferenceTemplate.json')
|
|
else:
|
|
# Invoke by UI
|
|
params_dict = load_json_from_s3(S3_BUCKET_NAME, 'config/aigc.json')
|
|
# logger.info(json.dumps(params_dict))
|
|
payload = json_convert_to_payload(params_dict, payload_checkpoint_info, task_type)
|
|
print(f"input in json format:")
|
|
checkpoint_name = None
|
|
if task_type == 'img2img':
|
|
checkpoint_name = params_dict['img2img_sagemaker_stable_diffusion_checkpoint']
|
|
elif task_type == 'txt2img':
|
|
checkpoint_name = params_dict['txt2img_sagemaker_stable_diffusion_checkpoint']
|
|
|
|
def show_slim_dict(payload):
|
|
pay_type = type(payload)
|
|
if pay_type is dict:
|
|
for k, v in payload.items():
|
|
print(f"{k}")
|
|
show_slim_dict(v)
|
|
elif pay_type is list:
|
|
for v in payload:
|
|
print(f"list")
|
|
show_slim_dict(v)
|
|
elif pay_type is str:
|
|
if len(payload) > 100:
|
|
print(f" : {len(payload)} contents")
|
|
else:
|
|
print(f" : {payload}")
|
|
else:
|
|
print(f" : {payload}")
|
|
|
|
show_slim_dict(payload)
|
|
|
|
endpoint_name = payload["endpoint_name"]
|
|
|
|
predictor = Predictor(endpoint_name)
|
|
|
|
# adjust time out time to 1 hour
|
|
initial_args = {"InvocationTimeoutSeconds": 3600}
|
|
|
|
predictor = AsyncPredictor(predictor, name=endpoint_name)
|
|
predictor.serializer = JSONSerializer()
|
|
predictor.deserializer = JSONDeserializer()
|
|
prediction = predictor.predict_async(data=payload, initial_args=initial_args, inference_id=inference_id)
|
|
output_path = prediction.output_path
|
|
|
|
# put the item to inference DDB for later check status
|
|
current_time = str(datetime.now())
|
|
response = inference_table.put_item(
|
|
Item={
|
|
'InferenceJobId': inference_id,
|
|
'startTime': current_time,
|
|
'status': 'inprogress',
|
|
'endpoint': endpoint_name,
|
|
'checkpoint': checkpoint_name,
|
|
'taskType': task_type
|
|
})
|
|
print(f"output_path is {output_path}")
|
|
|
|
headers = {
|
|
"Access-Control-Allow-Headers": "Content-Type",
|
|
"Access-Control-Allow-Origin": "*",
|
|
"Access-Control-Allow-Methods": "OPTIONS,POST,GET"
|
|
}
|
|
|
|
response = JSONResponse(
|
|
content={"inference_id": inference_id, "status": "inprogress", "endpoint_name": endpoint_name,
|
|
"output_path": output_path}, headers=headers)
|
|
return response
|
|
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
logger.error(f"Error occurred: {str(e)}")
|
|
|
|
# raise HTTPException(status_code=500, detail=f"An error occurred during processing.{str(e)}")
|
|
headers = {
|
|
"Access-Control-Allow-Headers": "Content-Type",
|
|
"Access-Control-Allow-Origin": "*",
|
|
"Access-Control-Allow-Methods": "OPTIONS,POST,GET"
|
|
}
|
|
|
|
current_time = str(datetime.now())
|
|
response = inference_table.put_item(
|
|
Item={
|
|
'InferenceJobId': inference_id,
|
|
'startTime': current_time,
|
|
'completeTime': current_time,
|
|
'status': 'failure',
|
|
'endpoint': endpoint_name,
|
|
'checkpoint': checkpoint_name,
|
|
'taskType': task_type or "unknown",
|
|
'error': f"error info {str(e)}"}
|
|
)
|
|
|
|
response = JSONResponse(
|
|
content={"inference_id": inference_id, "status": "failure", "error": f"error info {str(e)}"},
|
|
headers=headers)
|
|
return response
|
|
|
|
|
|
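

# Sample request body for /inference/query-inference-jobs (hypothetical
# values; every field is optional):
# {
#     "status": "succeed",
#     "task_type": "txt2img",
#     "start_time": "2023-05-01 00:00:00",
#     "end_time": "2023-05-31 23:59:59",
#     "endpoint": "my-sd-endpoint",
#     "checkpoint": "v1-5-pruned.ckpt",
#     "limit": 10
# }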
# TODO: this route will be removed
@app.post("/inference/query-inference-jobs")
async def query_inference_jobs(request: Request):
    """Query inference jobs with the optional filters supplied in the request body."""
    logger.info("entering query-inference-jobs")
    query_params = await request.json()
    logger.info(query_params)
    status = query_params.get('status')
    task_type = query_params.get('task_type')
    start_time = query_params.get('start_time')
    end_time = query_params.get('end_time')
    endpoint = query_params.get('endpoint')
    checkpoint = query_params.get('checkpoint')
    limit = query_params.get("limit") or const.PAGE_LIMIT_ALL
    logger.info(
        f"entering query-inference-jobs {status},{task_type},{start_time},{end_time},{checkpoint},{endpoint},{limit}")
    return query_inference_job_list(status, task_type, start_time, end_time, endpoint, checkpoint, limit)


def generate_presigned_url(bucket_name: str, key: str, expiration=3600) -> str:
    """Generate a presigned GET URL for the given S3 object."""
    try:
        response = s3.generate_presigned_url(
            'get_object',
            Params={'Bucket': bucket_name, 'Key': key},
            ExpiresIn=expiration
        )
    except Exception as e:
        logger.error(f"Error generating presigned URL: {e}")
        raise

    return response


@app.get("/inference/generate-s3-presigned-url-for-uploading")
|
|
async def generate_s3_presigned_url_for_uploading(s3_bucket_name: str = None, key: str = None):
|
|
if not s3_bucket_name:
|
|
s3_bucket_name = S3_BUCKET_NAME
|
|
|
|
if not key:
|
|
raise HTTPException(status_code=400, detail="Key parameter is required")
|
|
|
|
try:
|
|
presigned_url = s3.generate_presigned_url(
|
|
'put_object',
|
|
Params={
|
|
'Bucket': s3_bucket_name,
|
|
'Key': key,
|
|
'ContentType': 'text/plain;charset=UTF-8'
|
|
},
|
|
ExpiresIn=3600,
|
|
HttpMethod='PUT'
|
|
)
|
|
except Exception as e:
|
|
headers = {
|
|
"Access-Control-Allow-Headers": "*",
|
|
"Access-Control-Allow-Origin": "*",
|
|
"Access-Control-Allow-Methods": "OPTIONS,POST,GET,PUT"
|
|
}
|
|
return JSONResponse(content=str(e), status_code=500, headers=headers)
|
|
|
|
headers = {
|
|
"Access-Control-Allow-Headers": "*",
|
|
"Access-Control-Allow-Origin": "*",
|
|
"Access-Control-Allow-Methods": "OPTIONS,POST,GET,PUT"
|
|
}
|
|
|
|
response = JSONResponse(content=presigned_url, headers=headers)
|
|
|
|
return response
|
|
|
|
|
|
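
# Typical client flow (illustrative): call this route with ?key=<object key>,
# then upload with an HTTP PUT that repeats the signed Content-Type, e.g.
#   curl -X PUT -H 'Content-Type: text/plain;charset=UTF-8' \
#        --data-binary @file.txt "<presigned_url>"

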
@app.post("/inference/run-model-merge")
|
|
async def run_model_merge(request: Request):
|
|
try:
|
|
logger.info('entering the run_model_merge function!')
|
|
|
|
# TODO: add logic for inference id
|
|
merge_id = get_uuid()
|
|
|
|
payload_checkpoint_info = await request.json()
|
|
print(f"!!!!!!!!!!input in json format {payload_checkpoint_info}")
|
|
|
|
params_dict = load_json_from_s3(S3_BUCKET_NAME, 'config/aigc.json')
|
|
|
|
logger.info(json.dumps(params_dict))
|
|
payload = json_convert_to_payload(params_dict, payload_checkpoint_info)
|
|
print(f"input in json format {payload}")
|
|
task_type = payload_checkpoint_info.get('task_type')
|
|
endpoint_name = payload["endpoint_name"]
|
|
checkpoint_name = None
|
|
if task_type == 'img2img':
|
|
checkpoint_name = params_dict['img2img_sagemaker_stable_diffusion_checkpoint']
|
|
elif task_type == 'txt2img':
|
|
checkpoint_name = params_dict['txt2img_sagemaker_stable_diffusion_checkpoint']
|
|
predictor = Predictor(endpoint_name)
|
|
|
|
predictor = AsyncPredictor(predictor, name=endpoint_name)
|
|
predictor.serializer = JSONSerializer()
|
|
predictor.deserializer = JSONDeserializer()
|
|
prediction = predictor.predict_async(data=payload, inference_id=inference_id)
|
|
output_path = prediction.output_path
|
|
|
|
# put the item to inference DDB for later check status
|
|
current_time = str(datetime.now())
|
|
response = inference_table.put_item(
|
|
Item={
|
|
'InferenceJobId': inference_id,
|
|
'startTime': current_time,
|
|
'status': 'inprogress',
|
|
'endpoint': endpoint_name,
|
|
'checkpoint': checkpoint_name,
|
|
})
|
|
print(f"output_path is {output_path}")
|
|
|
|
headers = {
|
|
"Access-Control-Allow-Headers": "Content-Type",
|
|
"Access-Control-Allow-Origin": "*",
|
|
"Access-Control-Allow-Methods": "OPTIONS,POST,GET"
|
|
}
|
|
|
|
response = JSONResponse(
|
|
content={"inference_id": inference_id, "status": "inprogress", "endpoint_name": endpoint_name,
|
|
"output_path": output_path}, headers=headers)
|
|
# response = JSONResponse(content={"inference_id": '6fa743f0-cb7a-496f-8205-dbd67df08be2', "status": "succeed", "output_path": ""}, headers=headers)
|
|
return response
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error occurred: {str(e)}")
|
|
|
|
# raise HTTPException(status_code=500, detail=f"An error occurred during processing.{str(e)}")
|
|
headers = {
|
|
"Access-Control-Allow-Headers": "Content-Type",
|
|
"Access-Control-Allow-Origin": "*",
|
|
"Access-Control-Allow-Methods": "OPTIONS,POST,GET"
|
|
}
|
|
|
|
response = JSONResponse(
|
|
content={"inference_id": inference_id, "status": "failure", "error": f"error info {str(e)}"},
|
|
headers=headers)
|
|
return response
|
|
|
|
|
|
# app.include_router(search)  TODO: add sub-routers in the future

# Lambda entry point: Mangum translates API Gateway events into ASGI requests.
handler = Mangum(app)
add_pagination(app)
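
# Local development note (assumption: uvicorn is installed and this module is
# importable): the API can also be served without Lambda via
#   uvicorn <module_name>:app --reload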