# stable-diffusion-aws-extension/middleware_api/lambda/comfy/execute_async_events.py
import json
import logging
import os
from datetime import datetime
import boto3
from aws_lambda_powertools import Tracer
from common.ddb_service.client import DynamoDbUtilsService
from common.util import s3_scan_files, load_json_from_s3
from libs.comfy_data_types import InferenceResult
tracer = Tracer()
s3_resource = boto3.resource('s3')
sns_topic = os.environ['NOTICE_SNS_TOPIC']
bucket_name = os.environ['S3_BUCKET_NAME']
job_table = os.environ['INFERENCE_JOB_TABLE']
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
ddb_service = DynamoDbUtilsService(logger=logger)
ddb_client = boto3.resource('dynamodb')
inference_table = ddb_client.Table(job_table)
@tracer.capture_lambda_handler
def handler(event, context):
logger.info(json.dumps(event))
message = event['Records'][0]['Sns']['Message']
message = json.loads(message)
logger.info(message)
if 'invocationStatus' not in message:
# maybe a message from SNS for test
logger.error("Not a valid sagemaker inference result message")
return
result = load_json_from_s3(message['responseParameters']['outputLocation'])
logger.info(result)
result = {
"prompt_id": '11111111-1111-1111',
"instance_id": 'esd-real-time-test-rgihbd',
"status": 'success',
"output_path": f's3://{bucket_name}/images/',
"temp_path": f's3://{bucket_name}/images/'
}
result = InferenceResult(**result)
result = s3_scan_files(result)
if message["invocationStatus"] != "Completed":
result.status = "failed"
logger.info(result)
update_inference_job_table(prompt_id=result.prompt_id, key="status", value=result.status)
update_inference_job_table(prompt_id=result.prompt_id, key="output_path", value=result.output_path)
update_inference_job_table(prompt_id=result.prompt_id, key="output_files", value=result.output_files)
update_inference_job_table(prompt_id=result.prompt_id, key="temp_path", value=result.temp_path)
update_inference_job_table(prompt_id=result.prompt_id, key="temp_files", value=result.temp_files)
update_inference_job_table(prompt_id=result.prompt_id, key="complete_time", value=datetime.now().isoformat())
return {}
def update_inference_job_table(prompt_id, key, value):
logger.info(f"Update inference job table with prompt_id: {prompt_id}, key: {key}, value: {value}")
inference_table.update_item(
Key={
"prompt_id": prompt_id,
},
UpdateExpression=f"set #k = :r",
ExpressionAttributeNames={'#k': key},
ExpressionAttributeValues={':r': value},
ReturnValues="UPDATED_NEW"
)