"""Handle SNS notifications that carry SageMaker inference invocation results."""
import json
import logging
import os

import boto3

from inference_libs import (
    get_bucket_and_key,
    get_inference_job,
    parse_sagemaker_result,
    send_message_to_sns,
)
from start_inference_job import update_inference_job_table

s3_resource = boto3.resource('s3')

# Topic that receives failure notices; the module fails at import time if the
# variable is not set.
SNS_TOPIC = os.environ['NOTICE_SNS_TOPIC']

logger = logging.getLogger(__name__)
# Fall back to ERROR when LOG_LEVEL is unset or empty.
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)


def handler(event, context):
    """Process one SNS record describing a SageMaker inference invocation.

    The SNS message body is decoded as JSON. Messages without an
    ``invocationStatus`` key are logged and ignored. When the status is not
    ``"Completed"`` the inference job is marked ``failed``, the raw message is
    stored on the job, and the message is forwarded to ``SNS_TOPIC``.
    Otherwise the object at ``responseParameters.outputLocation`` is read from
    S3, decoded as JSON, and handed to ``parse_sagemaker_result``.

    Args:
        event: Event dict; only ``event['Records'][0]['Sns']['Message']`` is
            read, so any additional records are ignored.
        context: Unused.

    Returns:
        The decoded message when the invocation did not complete, otherwise
        ``None``.

    Raises:
        ValueError: If the S3 output is not valid JSON or decodes to ``null``.
            The job is marked ``failed`` and a notice is sent to ``SNS_TOPIC``
            before raising.
    """
    logger.info(json.dumps(event))

    message = json.loads(event['Records'][0]['Sns']['Message'])

    if 'invocationStatus' not in message:
        # maybe a message from SNS for test
        logger.error("Not a valid sagemaker inference result message")
        return

    invocation_status = message["invocationStatus"]
    inference_id = message["inferenceId"]

    if invocation_status != "Completed":
        update_inference_job_table(inference_id, 'status', 'failed')
        update_inference_job_table(inference_id, 'sagemakerRaw', str(message))
        print("Not complete invocation!")
        send_message_to_sns(message, SNS_TOPIC)
        return message

    endpoint_name = message["requestParameters"]["endpointName"]
    output_location = message["responseParameters"]["outputLocation"]

    bucket, key = get_bucket_and_key(output_location)
    obj = s3_resource.Object(bucket, key)
    body = obj.get()['Body'].read().decode('utf-8')
    print(f"Sagemaker Out Body: {body}")

    # json.loads raises on malformed input rather than returning None, so the
    # decode error must be caught for the failure bookkeeping below to run.
    parse_error = None
    try:
        sagemaker_out = json.loads(body)
    except json.JSONDecodeError as err:
        sagemaker_out = None
        parse_error = err

    if sagemaker_out is None:
        update_inference_job_table(inference_id, 'status', 'failed')
        message_json = {
            'InferenceJobId': inference_id,
            'status': "failed",
            'reason': "Sagemaker inference invocation completed, but the sagemaker output failed to be parsed as json"
        }
        send_message_to_sns(message_json, SNS_TOPIC)
        raise ValueError("body contains invalid JSON") from parse_error

    # Get the task type
    job = get_inference_job(inference_id)
    task_type = job.get('taskType', 'txt2img')

    parse_sagemaker_result(sagemaker_out, inference_id, task_type, endpoint_name)