# stable-diffusion-aws-extension/middleware_api/lambda/inferences/result_notification/app.py
# (302 lines, 11 KiB, Python)
import base64
import io
import json
import logging
import os
from datetime import datetime
import boto3
from PIL import Image
from boto3.dynamodb.conditions import Key
from botocore.exceptions import ClientError
# AWS clients/resources shared by all handlers in this module.
s3_resource = boto3.resource('s3')
s3_client = boto3.client('s3')
# DynamoDB table names and the S3 bucket come from the Lambda environment.
INFERENCE_JOB_TABLE = os.environ.get('INFERENCE_JOB_TABLE')
DDB_TRAINING_TABLE_NAME = os.environ.get('DDB_TRAINING_TABLE_NAME')
DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME = os.environ.get('DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME')
S3_BUCKET_NAME = os.environ.get('S3_BUCKET')
# NOTICE_SNS_TOPIC is required: a missing value fails fast at import time.
SNS_TOPIC = os.environ['NOTICE_SNS_TOPIC']
ddb_client = boto3.resource('dynamodb')
inference_table = ddb_client.Table(INFERENCE_JOB_TABLE)
endpoint_deployment_table = ddb_client.Table(DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME)
sns = boto3.client('sns')
logger = logging.getLogger()
# Default to ERROR when LOG_LEVEL is unset or empty.
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
def get_bucket_and_key(s3uri):
    """Split an ``s3://bucket/key`` URI into its (bucket, key) parts."""
    # Locate the first '/' after the 5-character "s3://" scheme prefix.
    separator = s3uri.find('/', 5)
    return s3uri[5:separator], s3uri[separator + 1:]
def decode_base64_to_image(encoding):
    """Decode a base64-encoded image (plain or data-URI form) into a PIL Image."""
    payload = encoding
    if payload.startswith("data:image/"):
        # Data URI: "data:image/png;base64,<payload>" -> keep only <payload>.
        payload = payload.split(";")[1].split(",")[1]
    return Image.open(io.BytesIO(base64.b64decode(payload)))
def update_inference_job_table(inference_id, key, value):
    """Set attribute ``key`` to ``value`` on the inference job DDB record.

    Raises an Exception with a descriptive message when no record exists
    for ``inference_id``.
    """
    # Use .get() so a missing item triggers our descriptive error below
    # instead of a bare KeyError from response['Item'].
    response = inference_table.get_item(
        Key={
            "InferenceJobId": inference_id,
        })
    inference_resp = response.get('Item')
    if not inference_resp:
        raise Exception(f"Failed to get the inference job item with inference id: {inference_id}")
    inference_table.update_item(
        Key={
            "InferenceJobId": inference_id,
        },
        UpdateExpression="set #k = :r",
        ExpressionAttributeNames={'#k': key},
        ExpressionAttributeValues={':r': value},
        ReturnValues="UPDATED_NEW"
    )
def getInferenceJob(inference_job_id):
    """Return the first inference job record matching ``inference_job_id``.

    Raises ValueError when the id is falsy or no record exists; any query
    exception is logged and re-raised.
    """
    if not inference_job_id:
        logger.error("Invalid inference job id")
        raise ValueError("Inference job id must not be None or empty")
    try:
        resp = inference_table.query(
            KeyConditionExpression=Key('InferenceJobId').eq(inference_job_id)
        )
        # (removed a leftover no-op debug expression statement here)
        record_list = resp['Items']
        if len(record_list) == 0:
            logger.error(f"No inference job info item for id: {inference_job_id}")
            raise ValueError(f"There is no inference job info item for id: {inference_job_id}")
        return record_list[0]
    except Exception as e:
        logger.error(
            f"Exception occurred when trying to query inference job with id: {inference_job_id}, exception is {str(e)}")
        raise
def upload_file_to_s3(file_name, bucket, directory=None, object_name=None):
    """Upload a local file to S3; return True on success, False on failure."""
    # The S3 key defaults to the local file name, optionally under a directory.
    target_key = file_name if object_name is None else object_name
    if directory:
        target_key = f"{directory}/{target_key}"
    try:
        s3_client.upload_file(file_name, bucket, target_key)
        print(f"File {file_name} uploaded to {bucket}/{target_key}")
    except Exception as e:
        # Best-effort: report the failure and let the caller decide.
        print(f"Error occurred while uploading {file_name} to {bucket}/{target_key}: {e}")
        return False
    return True
def get_topic_arn(sns, topic_name):
    """Return the ARN of the SNS topic named ``topic_name``, or None.

    list_topics returns at most ~100 topics per call, so this follows
    NextToken pagination until the topic is found or pages run out.
    """
    kwargs = {}
    while True:
        response = sns.list_topics(**kwargs)
        for topic in response['Topics']:
            # The topic name is the last ':'-separated segment of the ARN.
            if topic['TopicArn'].split(':')[-1] == topic_name:
                return topic['TopicArn']
        next_token = response.get('NextToken')
        if not next_token:
            return None
        kwargs = {'NextToken': next_token}
def send_message_to_sns(message_json):
    """Publish ``message_json`` to the configured notice SNS topic.

    Returns an HTTP-style dict: 200 on success, 404 when the topic cannot
    be found, 500 (with ``error``) on a publish failure.
    """
    sns_topic_name = SNS_TOPIC
    try:
        sns_topic_arn = get_topic_arn(sns, sns_topic_name)
        if sns_topic_arn is None:
            print(f"No topic found with name {sns_topic_name}")
            return {
                'statusCode': 404,
                'body': json.dumps('No topic found')
            }
        sns.publish(
            TopicArn=sns_topic_arn,
            Message=json.dumps(message_json),
            Subject='Inference Error occurred Information',
        )
        print(f"Message sent to SNS topic: {sns_topic_name}")
        return {
            'statusCode': 200,
            'body': json.dumps('Message sent successfully')
        }
    except ClientError as e:
        print(f"Error sending message to SNS topic: {sns_topic_name}")
        return {
            'statusCode': 500,
            'body': json.dumps('Error sending message'),
            'error': str(e)
        }
def _save_inference_parameters(inference_id, inference_parameters):
    """Write the parameter dict to /tmp, upload it to S3, record its name."""
    json_file_name = f"/tmp/{inference_id}_param.json"
    with open(json_file_name, "w") as outfile:
        json.dump(inference_parameters, outfile)
    upload_file_to_s3(json_file_name, S3_BUCKET_NAME, f"out/{inference_id}/result",
                      f"{inference_id}_param.json")
    update_inference_job_table(inference_id, 'inference_info_name', json_file_name)
    print(f"Complete inference parameters {inference_parameters}")


def _record_image(inference_id, image_bytes, image_name):
    """Upload one result image to S3 and append its name to the job record."""
    s3_client.put_object(
        Body=image_bytes,
        Bucket=S3_BUCKET_NAME,
        Key=f"out/{inference_id}/result/{image_name}"
    )
    inference_table.update_item(
        Key={
            'InferenceJobId': inference_id
        },
        UpdateExpression='SET image_names = list_append(if_not_exists(image_names, :empty_list), :new_image)',
        ExpressionAttributeValues={
            ':new_image': [image_name],
            ':empty_list': []
        }
    )


def _handle_interrogate(inference_id, json_body):
    """Store the caption produced by an interrogate task on the job record."""
    inference_table.update_item(
        Key={
            'InferenceJobId': inference_id
        },
        UpdateExpression='SET caption=:f',
        ExpressionAttributeValues={
            ':f': json_body['caption'],
        }
    )


def _handle_generated_images(inference_id, endpoint_name, message, json_body):
    """Persist txt2img/img2img result images and the run parameters."""
    for count, b64image in enumerate(json_body["images"]):
        image = decode_base64_to_image(b64image).convert("RGB")
        output = io.BytesIO()
        image.save(output, format="PNG")
        _record_image(inference_id, output.getvalue(), f"image_{count}.png")
    inference_parameters = {
        "parameters": json_body["parameters"],
        "info": json_body["info"],
        # NOTE(review): "endpont_name" spelling kept — the key is persisted to
        # S3 and may be read by downstream consumers; confirm before fixing.
        "endpont_name": endpoint_name,
        "inference_id": inference_id,
        "sns_info": message,
    }
    _save_inference_parameters(inference_id, inference_parameters)


def _handle_extra_image(inference_id, endpoint_name, message, json_body, taskType):
    """Persist the single image produced by extra-single-image / rembg tasks."""
    if 'image' not in json_body:
        raise Exception(json_body)
    # Keep the original mode (no RGB conversion) — rembg output may carry alpha.
    image = Image.open(io.BytesIO(base64.b64decode(json_body["image"])))
    output = io.BytesIO()
    image.save(output, format="PNG")
    _record_image(inference_id, output.getvalue(), "image.png")
    inference_parameters = {}
    if taskType == "extra-single-image":
        inference_parameters["html_info"] = json_body["html_info"]
    inference_parameters["endpont_name"] = endpoint_name
    inference_parameters["inference_id"] = inference_id
    inference_parameters["sns_info"] = message
    _save_inference_parameters(inference_id, inference_parameters)


def lambda_handler(event, context):
    """Handle the SageMaker async-inference completion notification from SNS.

    Updates the inference job record in DynamoDB, stores result images and
    parameters in S3 on success, and forwards failures to the notice SNS
    topic. Returns the parsed SNS message (None for unrecognized messages).
    """
    print(json.dumps(event))
    message = event['Records'][0]['Sns']['Message']
    print("From SNS: " + str(message))
    message = json.loads(message)
    if 'invocationStatus' not in message:
        # maybe a message from SNS for test
        print("Not a valid message")
        return
    invocation_status = message["invocationStatus"]
    inference_id = message["inferenceId"]
    update_inference_job_table(inference_id, 'completeTime', str(datetime.now()))
    if invocation_status != "Completed":
        # Invocation failed on the SageMaker side: mark the job and notify.
        update_inference_job_table(inference_id, 'status', 'failed')
        update_inference_job_table(inference_id, 'sagemakerRaw', str(message))
        print(f"Not complete invocation!")
        send_message_to_sns(event['Records'][0]['Sns']['Message'])
        return message
    # Initialized before the try so the except handler can always reference
    # it (previously a failure before json.loads raised NameError here).
    json_body = None
    try:
        print(f"Complete invocation!")
        endpoint_name = message["requestParameters"]["endpointName"]
        output_location = message["responseParameters"]["outputLocation"]
        bucket, key = get_bucket_and_key(output_location)
        obj = s3_resource.Object(bucket, key)
        body = obj.get()['Body'].read().decode('utf-8')
        print(f"Sagemaker Out Body: {body}")
        json_body = json.loads(body)
        if json_body is None:
            update_inference_job_table(inference_id, 'status', 'failed')
            send_message_to_sns({
                'InferenceJobId': inference_id,
                'status': "failed",
                'reason': "Sagemaker inference invocation completed, but the sagemaker output failed to be parsed as json"
            })
            raise ValueError("body contains invalid JSON")
        # Route on the task type stored with the job record.
        job = getInferenceJob(inference_id)
        taskType = job.get('taskType', 'txt2img')
        if taskType in ["interrogate_clip", "interrogate_deepbooru"]:
            _handle_interrogate(inference_id, json_body)
        elif taskType in ["txt2img", "img2img"]:
            _handle_generated_images(inference_id, endpoint_name, message, json_body)
        elif taskType in ["extra-single-image", "rembg"]:
            _handle_extra_image(inference_id, endpoint_name, message, json_body, taskType)
        update_inference_job_table(inference_id, 'status', 'succeed')
        update_inference_job_table(inference_id, 'sagemakerRaw', str(message))
    except Exception as e:
        print(f"Error occurred: {str(e)}")
        update_inference_job_table(inference_id, 'status', 'failed')
        # json_body is None when the failure happened before parsing.
        update_inference_job_table(inference_id, 'sagemakerRaw', json.dumps(json_body))
        raise e
    return message