# stable-diffusion-aws-extension/middleware_api/lambda/inferences/inference_async_events.py
#
# Listing metadata from the original code viewer: 67 lines, 2.2 KiB, Python

import json
import logging
import os
import boto3
from aws_lambda_powertools import Tracer
from common.sns_util import send_message_to_sns
from inference_libs import parse_sagemaker_result, get_bucket_and_key, get_inference_job
from start_inference_job import update_inference_job_table
# Powertools Tracer instance; its capture_lambda_handler decorator wraps handler() below.
tracer = Tracer()
# boto3 S3 resource created once at module import and shared by every handler call.
s3_resource = boto3.resource('s3')
# SNS topic that failure notices are published to. Read with [] on purpose:
# a missing NOTICE_SNS_TOPIC raises KeyError at import time instead of at first use.
SNS_TOPIC = os.environ['NOTICE_SNS_TOPIC']
logger = logging.getLogger(__name__)
# An unset or empty LOG_LEVEL falls back to ERROR, so logger.info() calls in this
# module are suppressed unless LOG_LEVEL is lowered explicitly.
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
@tracer.capture_lambda_handler
def handler(event, context):
    """Handle an SNS-delivered inference result notification.

    The first SNS record's ``Message`` is decoded as JSON. Messages that are
    not a JSON object carrying ``invocationStatus`` are logged and ignored.
    For any status other than ``"Completed"`` the inference job is marked
    ``failed``, the raw message is stored under ``sagemakerRaw`` and the
    message is forwarded to the notice SNS topic. For completed invocations
    the output object referenced by ``responseParameters.outputLocation`` is
    read from S3, decoded as JSON and handed to ``parse_sagemaker_result``.

    Args:
        event: Lambda event; only ``event['Records'][0]['Sns']['Message']``
            is read (and the whole event is logged at INFO level).
        context: Lambda context object; unused.

    Returns:
        ``None`` when the message is ignored or processed successfully; the
        decoded message dict when the invocation did not complete.

    Raises:
        ValueError: the S3 output is not valid JSON, or decodes to ``null``.
            The job is marked ``failed`` and a notice is sent before raising.
    """
    logger.info(json.dumps(event))
    message = json.loads(event['Records'][0]['Sns']['Message'])

    # A JSON string/list/number would either crash the membership test or
    # slip through it and fail later on the subscript; treat every
    # non-object payload as "not ours".
    if not isinstance(message, dict) or 'invocationStatus' not in message:
        # maybe a message from SNS for test
        logger.error("Not a valid sagemaker inference result message")
        return None

    invocation_status = message["invocationStatus"]
    inference_id = message["inferenceId"]

    if invocation_status != "Completed":
        update_inference_job_table(inference_id, 'status', 'failed')
        update_inference_job_table(inference_id, 'sagemakerRaw', str(message))
        print("Not complete invocation!")
        send_message_to_sns(message, SNS_TOPIC)
        return message

    endpoint_name = message["requestParameters"]["endpointName"]
    output_location = message["responseParameters"]["outputLocation"]

    bucket, key = get_bucket_and_key(output_location)
    obj = s3_resource.Object(bucket, key)
    body = obj.get()['Body'].read().decode('utf-8')
    print(f"Sagemaker Out Body: {body}")

    # json.loads() raises on malformed input rather than returning None, so
    # the failure bookkeeping below was previously skipped for exactly the
    # case it describes. Capture the error and route it through that path.
    parse_error = None
    try:
        sagemaker_out = json.loads(body)
    except json.JSONDecodeError as err:
        sagemaker_out = None
        parse_error = err

    if sagemaker_out is None:
        update_inference_job_table(inference_id, 'status', 'failed')
        message_json = {
            'InferenceJobId': inference_id,
            'status': "failed",
            'reason': "Sagemaker inference invocation completed, but the sagemaker output failed to be parsed as json"
        }
        send_message_to_sns(message_json, SNS_TOPIC)
        # Chains the decoder error when there is one; for a literal `null`
        # body parse_error is None and no cause is attached.
        raise ValueError("body contains invalid JSON") from parse_error

    # Get the task type
    job = get_inference_job(inference_id)
    # NOTE(review): assumes get_inference_job returns a dict-like object for
    # every known inference id — confirm it cannot return None here.
    task_type = job.get('taskType', 'txt2img')

    parse_sagemaker_result(sagemaker_out, inference_id, task_type, endpoint_name)