# stable-diffusion-aws-extension/middleware_api/lambda/inference_result_notification/app.py
# (188 lines, 6.6 KiB, Python)
import json
import io
import boto3
import base64
import os
from PIL import Image
from datetime import datetime
from botocore.exceptions import ClientError

# Module-level AWS clients — created once per Lambda container and reused
# across invocations.
s3_resource = boto3.resource('s3')
s3_client = boto3.client('s3')

# Configuration from the Lambda environment.
# NOTE: the DDB table names and bucket use .get() and may be None if unset;
# NOTICE_SNS_TOPIC is required and raises KeyError at import time if missing.
DDB_INFERENCE_TABLE_NAME = os.environ.get('DDB_INFERENCE_TABLE_NAME')
DDB_TRAINING_TABLE_NAME = os.environ.get('DDB_TRAINING_TABLE_NAME')
DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME = os.environ.get('DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME')
S3_BUCKET_NAME = os.environ.get('S3_BUCKET')
SNS_TOPIC = os.environ['NOTICE_SNS_TOPIC']

# DynamoDB table handles for inference jobs and endpoint deployments.
ddb_client = boto3.resource('dynamodb')
inference_table = ddb_client.Table(DDB_INFERENCE_TABLE_NAME)
endpoint_deployment_table = ddb_client.Table(DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME)
sns = boto3.client('sns')
def get_bucket_and_key(s3uri):
    """Split an ``s3://bucket/key`` URI into a ``(bucket, key)`` tuple.

    Assumes the URI starts with the 5-character ``s3://`` scheme prefix.
    """
    # First '/' after the scheme separates the bucket from the object key.
    sep = s3uri.find('/', 5)
    return s3uri[5:sep], s3uri[sep + 1:]
def decode_base64_to_image(encoding):
    """Decode a base64 string (optionally a ``data:image/...`` data-URL) to a PIL Image."""
    if encoding.startswith("data:image/"):
        # Strip the "data:image/<type>;base64," header, keeping only the payload.
        encoding = encoding.split(";")[1].split(",")[1]
    raw = base64.b64decode(encoding)
    return Image.open(io.BytesIO(raw))
def update_inference_job_table(inference_id, key, value):
    """Set attribute ``key`` to ``value`` on the inference job DDB record.

    Raises:
        Exception: if no job row exists for ``inference_id``.
    """
    # Verify the job row exists before updating it.
    response = inference_table.get_item(
        Key={
            "InferenceJobId": inference_id,
        })
    # get_item omits the 'Item' key entirely when the row is missing, so use
    # .get() — a plain response['Item'] would raise KeyError before the
    # intended existence check below could run.
    inference_resp = response.get('Item')
    if not inference_resp:
        raise Exception(f"Failed to get the inference job item with inference id: {inference_id}")
    inference_table.update_item(
        Key={
            "InferenceJobId": inference_id,
        },
        # '#k' aliases the attribute name so reserved words (e.g. "status")
        # are accepted by DynamoDB.
        UpdateExpression="set #k = :r",
        ExpressionAttributeNames={'#k': key},
        ExpressionAttributeValues={':r': value},
        ReturnValues="UPDATED_NEW"
    )
def get_curent_time():
    """Return the current local time formatted as ``YYYY-MM-DD-HH-MM-SS``."""
    return datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
def upload_file_to_s3(file_name, bucket, directory=None, object_name=None):
    """Upload a local file to S3.

    The S3 key defaults to ``file_name`` when ``object_name`` is omitted, and
    is prefixed with ``directory`` when that is given.

    Returns:
        bool: True on success, False if the upload raised.
    """
    target_key = file_name if object_name is None else object_name
    if directory:
        target_key = f"{directory}/{target_key}"
    try:
        s3_client.upload_file(file_name, bucket, target_key)
    except Exception as e:
        # Best-effort upload: log and report failure rather than propagating.
        print(f"Error occurred while uploading {file_name} to {bucket}/{target_key}: {e}")
        return False
    print(f"File {file_name} uploaded to {bucket}/{target_key}")
    return True
def get_topic_arn(sns, topic_name):
    """Return the ARN of the SNS topic named ``topic_name``, or None.

    list_topics returns at most 100 topics per call, so follow NextToken
    pagination — the original single call silently missed topics on accounts
    with more than one page.

    Args:
        sns: a boto3 SNS client (or compatible object).
        topic_name: the bare topic name (last ':'-separated ARN segment).
    """
    kwargs = {}
    while True:
        response = sns.list_topics(**kwargs)
        for topic in response['Topics']:
            if topic['TopicArn'].split(':')[-1] == topic_name:
                return topic['TopicArn']
        next_token = response.get('NextToken')
        if not next_token:
            return None
        kwargs['NextToken'] = next_token
def send_message_to_sns(message_json):
    """Publish ``message_json`` to the configured notification SNS topic.

    Returns:
        dict: an HTTP-style response — 200 on success, 404 when the topic
        cannot be found, 500 when publishing raises a ClientError.
    """
    sns_topic_name = SNS_TOPIC
    try:
        sns_topic_arn = get_topic_arn(sns, sns_topic_name)
        if sns_topic_arn is None:
            print(f"No topic found with name {sns_topic_name}")
            return {'statusCode': 404, 'body': json.dumps('No topic found')}
        sns.publish(
            TopicArn=sns_topic_arn,
            Message=json.dumps(message_json),
            Subject='Inference Error occurred Information',
        )
        print(f"Message sent to SNS topic: {sns_topic_name}")
        return {'statusCode': 200, 'body': json.dumps('Message sent successfully')}
    except ClientError as e:
        print(f"Error sending message to SNS topic: {sns_topic_name}")
        return {
            'statusCode': 500,
            'body': json.dumps('Error sending message'),
            'error': str(e),
        }
def lambda_handler(event, context):
    """Process a SageMaker async-inference completion notification from SNS.

    On "Completed": marks the job succeeded in DynamoDB, downloads the model
    output JSON from S3, saves each base64 image as a JPEG under
    out/<inference_id>/result/, appends the image names to the job record,
    and uploads a parameters JSON alongside the images.
    On any other status: marks the job failed.
    In both cases the raw SNS message is forwarded to the notice topic.
    """
    print("Received event: " + json.dumps(event, indent=2))
    # SNS delivers the payload as a JSON string inside the first record.
    message = event['Records'][0]['Sns']['Message']
    print("From SNS: " + str(message))
    message = json.loads(message)
    invocation_status = message["invocationStatus"]
    inference_id = message["inferenceId"]
    if invocation_status == "Completed":
        print(f"Complete invocation!")
        endpoint_name = message["requestParameters"]["endpointName"]
        # Record success state and the raw SageMaker message on the job row.
        update_inference_job_table(inference_id, 'status', 'succeed')
        update_inference_job_table(inference_id, 'completeTime', get_curent_time())
        update_inference_job_table(inference_id, 'sagemakerRaw', str(message))
        # Fetch the async-inference output object written by SageMaker.
        output_location = message["responseParameters"]["outputLocation"]
        bucket, key = get_bucket_and_key(output_location)
        obj = s3_resource.Object(bucket, key)
        body = obj.get()['Body'].read().decode('utf-8')
        # assumes the output JSON has "images" (list of base64 strings),
        # "parameters" and "info" keys — TODO confirm against the endpoint code
        json_body = json.loads(body)
        # save images
        for count, b64image in enumerate(json_body["images"]):
            image = decode_base64_to_image(b64image).convert("RGB")
            output = io.BytesIO()
            image.save(output, format="JPEG")
            # Upload the image to the S3 bucket
            s3_client.put_object(
                Body=output.getvalue(),
                Bucket=S3_BUCKET_NAME,
                Key=f"out/{inference_id}/result/image_{count}.jpg"
            )
            # Update the DynamoDB table
            # list_append with if_not_exists creates image_names on first use,
            # then appends one name per image.
            inference_table.update_item(
                Key={
                    'InferenceJobId': inference_id
                },
                UpdateExpression='SET image_names = list_append(if_not_exists(image_names, :empty_list), :new_image)',
                ExpressionAttributeValues={
                    ':new_image': [f"image_{count}.jpg"],
                    ':empty_list': []
                }
            )
        # save parameters
        inference_parameters = {}
        inference_parameters["parameters"] = json_body["parameters"]
        inference_parameters["info"] = json_body["info"]
        inference_parameters["endpont_name"] = endpoint_name
        inference_parameters["inference_id"] = inference_id
        inference_parameters["sns_info"] = message
        # Write the parameters file to Lambda's /tmp scratch space, then copy
        # it next to the result images in S3.
        json_file_name = f"/tmp/{inference_id}_param.json"
        with open(json_file_name, "w") as outfile:
            json.dump(inference_parameters, outfile)
        upload_file_to_s3(json_file_name, S3_BUCKET_NAME, f"out/{inference_id}/result",f"{inference_id}_param.json")
        # NOTE(review): this stores the local /tmp path, not the S3 key — confirm
        # downstream readers expect that.
        update_inference_job_table(inference_id, 'inference_info_name', json_file_name)
        print(f"Complete inference parameters {inference_parameters}")
    else:
        # Any non-Completed status is treated as a failure.
        update_inference_job_table(inference_id, 'status', 'failed')
        update_inference_job_table(inference_id, 'sagemakerRaw', str(message))
        print(f"Not complete invocation!")
    # Forward the raw (string) SNS message to the notice topic for both
    # success and failure paths.
    send_message_to_sns(event['Records'][0]['Sns']['Message'])
    return message