stable-diffusion-aws-extension/middleware_api/comfy/merge_execute.py

119 lines
4.8 KiB
Python

import json
import logging
import os
import uuid
from decimal import Decimal
import boto3
from aws_lambda_powertools import Tracer
from common.ddb_service.client import DynamoDbUtilsService
from common.response import ok
from libs.utils import response_error
from execute import async_inference
from execute import ComfyExecuteTable
# Powertools tracer; used below to decorate the Lambda entry point.
tracer = Tracer()
# Module logger. Level comes from the LOG_LEVEL environment variable and
# falls back to ERROR when it is unset or empty.
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
# SQS client and queue URL. Neither is referenced by the code visible in this
# module; kept because other modules may import them.
sqs = boto3.client('sqs')
sqs_url = os.environ.get('MERGE_SQS_URL')
# Name of the DynamoDB table holding execute jobs (EXECUTE_TABLE env var).
execute_table_name = os.environ.get('EXECUTE_TABLE')
# Project DynamoDB helper; only referenced from commented-out code in this module.
ddb_service = DynamoDbUtilsService(logger=logger)
# boto3 Table handle used by update_execute_job_table().
# NOTE(review): built at import time from EXECUTE_TABLE — confirm the variable
# is always set in the deployed environment.
dynamodb = boto3.resource('dynamodb')
execute_table = dynamodb.Table(execute_table_name)
def update_execute_job_table(prompt_id, key, value):
    """Set a single attribute on an existing execute-job item.

    The update is conditional on ``attribute_exists(prompt_id)``, so it only
    modifies an item that is already in the table and never creates a new one.

    Args:
        prompt_id: Value of the ``prompt_id`` key identifying the item.
        key: Name of the attribute to set. It is passed through an expression
            attribute name placeholder (``#k``) rather than interpolated into
            the update expression.
        value: New value for the attribute.

    Raises:
        Exception: Any error raised by ``update_item`` (including a failed
            condition check) is logged and re-raised unchanged.
    """
    logger.info(f"Update execute table with prompt_id: {prompt_id}, key: {key}, value: {value}")
    try:
        execute_table.update_item(
            Key={
                "prompt_id": prompt_id,
            },
            UpdateExpression="set #k = :r",
            ExpressionAttributeNames={'#k': key},
            ExpressionAttributeValues={':r': value},
            ConditionExpression="attribute_exists(prompt_id)",
            ReturnValues="UPDATED_NEW",
        )
    except Exception as e:
        logger.error(f"Update execute job table error: {e}")
        # Bare raise keeps the original traceback intact for the caller.
        raise
def convert_float_to_decimal(data):
    """Return ``data`` with every ``float`` replaced by an equivalent ``Decimal``.

    Dicts and lists are walked recursively and rebuilt as new containers;
    dict keys are left untouched. Any other value (str, int, bool, None,
    tuples, ...) is returned unchanged.

    Args:
        data: A float, dict, list, or any other value.

    Returns:
        The converted structure, or ``data`` itself when nothing applies.
    """
    if isinstance(data, float):
        # Go through str() so 0.1 becomes Decimal('0.1') rather than the
        # long binary expansion Decimal(0.1) would produce.
        return Decimal(str(data))
    if isinstance(data, dict):
        return {key: convert_float_to_decimal(value) for key, value in data.items()}
    if isinstance(data, list):
        return [convert_float_to_decimal(item) for item in data]
    return data
@tracer.capture_lambda_handler
def handler(raw_event, ctx):
    """Group queued execute requests by endpoint and submit each group as one batch.

    Every record in ``raw_event['Records']`` is expected to carry a JSON
    ``body`` with a ``save_item`` mapping (keyword arguments for
    ``ComfyExecuteTable``) and an ``event`` payload. Records are grouped by
    the job's ``endpoint_name``; each group gets one fresh UUID batch id,
    which is written to the job's ``batch_id`` attribute in the execute
    table. After all records are read, ``async_inference`` is called once
    per endpoint with the collected events.

    Records that are empty, lack a ``body``, or have an empty ``body`` are
    logged and skipped; they do not stop the remaining records from being
    processed.

    Args:
        raw_event: Lambda event; presumably an SQS batch (``Records`` list
            with ``body`` strings) — confirm against the event source.
        ctx: Lambda context object (unused).

    Returns:
        ``ok()`` on success or when there are no records, otherwise
        ``response_error(e)`` for any exception raised while processing.
    """
    logger.info(f"receive execute reqs start... Received event: {raw_event}")
    try:
        if 'Records' not in raw_event or not raw_event['Records']:
            logger.error("ignore empty records msg")
            return ok()

        # endpoint_name -> events collected for that endpoint in this invocation
        execute_merge_req = {}
        # endpoint_name -> batch id shared by all events of that endpoint
        execute_merge_req_batch_id = {}

        for message in raw_event['Records']:
            # Skip unusable records instead of returning: returning here would
            # drop every request already grouped and all records still to come.
            if not message or 'body' not in message:
                logger.error("ignore empty msg")
                continue
            if not message['body']:
                logger.error("ignore empty body msg")
                continue

            merge_job = json.loads(message['body'])
            logger.info(F'merge job: {merge_job}')
            inference_job = ComfyExecuteTable(**merge_job["save_item"])
            event = merge_job["event"]
            endpoint_name = inference_job.endpoint_name

            if execute_merge_req.get(endpoint_name):
                execute_merge_req[endpoint_name].append(event)
            else:
                # First event for this endpoint: open a group with a new batch id.
                execute_merge_req[endpoint_name] = [event]
                execute_merge_req_batch_id[endpoint_name] = str(uuid.uuid4())

            inference_job.batch_id = execute_merge_req_batch_id[endpoint_name]
            logger.info(F'update inference job batch_id: {inference_job.batch_id}, prompt_id: {inference_job.prompt_id}')
            update_execute_job_table(prompt_id=inference_job.prompt_id, key="batch_id", value=inference_job.batch_id)

        for endpoint_name, events in execute_merge_req.items():
            resp = async_inference(events, execute_merge_req_batch_id[endpoint_name], endpoint_name)
            # TODO status check and save
            logger.info(f"batch async inference response: {resp}")

        logger.info("receive execute reqs end...")
        return ok()
    except Exception as e:
        return response_error(e)