import base64
import json
import logging
import os
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional

import boto3
from aws_lambda_powertools import Tracer
from sagemaker import Predictor
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer

from common.ddb_service.client import DynamoDbUtilsService
from common.excepts import BadRequestException
from common.response import ok, created
from common.util import s3_scan_files, generate_presigned_url_for_keys, \
    record_latency_metrics, record_count_metrics, get_workflow_name
from libs.comfy_data_types import ComfyExecuteTable, InferenceResult
from libs.enums import ComfyExecuteType, EndpointStatus, ServiceType
from libs.utils import get_endpoint_by_name, response_error, get_endpoint_name_by_workflow_name

tracer = Tracer()
region = os.environ.get('AWS_REGION')
bucket_name = os.environ.get('S3_BUCKET_NAME')
execute_table = os.environ.get('EXECUTE_TABLE')
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
endpoint_instance_id = os.environ.get('ENDPOINT_INSTANCE_ID')
ddb_service = DynamoDbUtilsService(logger=logger)
sqs_url = os.environ.get('MERGE_SQS_URL')
index_name = "endpoint_name-startTime-index"

# Cache of SageMaker predictor clients keyed by endpoint name, so repeated
# invocations within a warm Lambda container reuse the same client.
predictors = {}

# Instance types with more than one GPU; these are eligible for the
# multi-GPU async path that hands work off through the merge SQS queue.
multi_gpu_instance_type_list = ['ml.p5.48xlarge', 'ml.p4d.24xlarge', 'ml.p3.8xlarge', 'ml.p3.16xlarge',
                                'ml.p3dn.24xlarge', 'ml.p2.8xlarge', 'ml.p2.16xlarge', 'ml.g4dn.12xlarge',
                                'ml.g5.12xlarge', 'ml.g5.24xlarge', 'ml.g5.48xlarge']


@dataclass
class PrepareProps:
    need_reboot: Optional[bool] = False
    prepare_type: Optional[str] = "inputs"
    s3_source_path: Optional[str] = None
    local_target_path: Optional[str] = None
    sync_script: Optional[str] = None


@dataclass
class ExecuteEvent:
    prompt_id: str
    prompt: dict
    endpoint_name: Optional[str] = ''
    need_sync: bool = True
    number: Optional[str] = None
    front: Optional[bool] = None
    extra_data: Optional[dict] = None
    client_id: Optional[str] = None
    workflow: Optional[str] = None
    need_prepare: Optional[bool] = False
    prepare_props: Optional[PrepareProps] = None
    multi_async: Optional[bool] = False
    workflow_name: Optional[str] = None


def sen_sqs_msg(message_body, endpoint_name):
    sqs_client = boto3.client('sqs', region_name=region)
    response = sqs_client.send_message(
        QueueUrl=sqs_url,
        MessageBody=json.dumps(message_body),
        MessageGroupId=endpoint_name
    )
    message_id = response['MessageId']
    return message_id


def build_s3_images_request(prompt_id, bucket_name, s3_path):
    s3 = boto3.client('s3', region_name=region)
    response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_path)
    image_video_dict = {}
    for obj in response.get('Contents', []):
        object_key = obj['Key']
        file_name = object_key.split('/')[-1]
        if object_key.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.mp4', '.mov', '.avi')):
            response_data = s3.get_object(Bucket=bucket_name, Key=object_key)
            object_data = response_data['Body'].read()
            encoded_data = base64.b64encode(object_data).decode('utf-8')
            image_video_dict[file_name] = encoded_data
    return {'prompt_id': prompt_id, 'image_video_data': image_video_dict}
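
# Illustrative sketch (not part of the original handler flow): how a caller of
# build_s3_images_request might decode the base64 payload back into local
# files. The helper name and target directory are placeholders for the example.
def decode_s3_images_response(response: dict, target_dir: str = '/tmp/comfy_outputs') -> list:
    os.makedirs(target_dir, exist_ok=True)
    written = []
    for file_name, encoded_data in response.get('image_video_data', {}).items():
        local_path = os.path.join(target_dir, file_name)
        with open(local_path, 'wb') as f:
            f.write(base64.b64decode(encoded_data))
        written.append(local_path)
    return written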
@tracer.capture_method
def invoke_sagemaker_inference(event: ExecuteEvent):
    if not event.endpoint_name and not event.workflow_name:
        raise Exception("Cannot match an available environment, please check your EndpointName or your WorkflowVersion.")

    if event.workflow_name:
        endpoint_name = get_endpoint_name_by_workflow_name(name=event.workflow_name)
    else:
        endpoint_name = event.endpoint_name

    try:
        ep = get_endpoint_by_name(endpoint_name)
    except Exception:
        raise Exception("Please create an endpoint first, then try again.")

    if ep.endpoint_status not in [EndpointStatus.IN_SERVICE.value, EndpointStatus.UPDATING.value]:
        raise Exception(f"Endpoint {endpoint_name} is in {ep.endpoint_status} status, not InService or Updating.")

    if event.workflow:
        event.workflow = get_workflow_name(event.workflow, ep.instance_type)

    logger.info(f"endpoint: {ep}")

    payload = event.__dict__
    logger.info(f"inference payload: {payload}")

    inference_id = str(uuid.uuid4())
    job_status = ComfyExecuteType.CREATED.value
    inference_job = ComfyExecuteTable(
        prompt_id=event.prompt_id,
        # Use the resolved endpoint name so workflow_name-based routing is
        # recorded correctly (event.endpoint_name may be empty in that case).
        endpoint_name=endpoint_name,
        inference_type=ep.endpoint_type,
        instance_id=endpoint_instance_id,
        need_sync=event.need_sync,
        status=job_status,
        prompt_params={'prompt': str(event.prompt),
                       'number': event.number,
                       'front': event.front,
                       'extra_data': str(event.extra_data),
                       'client_id': event.client_id},
        prompt_path='',
        create_time=datetime.now().isoformat(),
        sagemaker_raw={},
        output_path='',
        temp_path='',
        output_files=[],
        temp_files=[],
        multi_async=event.multi_async,
        batch_id='',
        workflow=event.workflow,
    )
    logger.info(f"inference job: {inference_job.__dict__}")

    record_count_metrics(ep_name=ep.endpoint_name,
                         metric_name='InferenceTotal',
                         service=ServiceType.Comfy.value,
                         workflow=inference_job.workflow
                         )

    # Multi-GPU async path: persist the job, then hand it off to the merge SQS
    # queue so the endpoint's GPUs can be coordinated by the queue consumer.
    if event.multi_async and ep.endpoint_type == 'Async' and ep.instance_type in multi_gpu_instance_type_list:
        logger.info(f"executing multi-gpu inference {ep.instance_type} {ep.endpoint_type} {event.multi_async}")
        save_item = inference_job.__dict__
        ddb_service.put_items(execute_table, entries=save_item)
        sen_sqs_msg({"event": payload, "save_item": save_item, "inference_id": inference_id}, endpoint_name)
        return created(data=response_schema(inference_job), decimal=True)

    # Single-GPU async path: submit the job and return immediately; the final
    # status is reconciled later from the async inference output.
    if ep.endpoint_type == 'Async':
        ddb_service.put_items(execute_table, entries=inference_job.__dict__)
        resp = async_inference([payload], inference_id, ep.endpoint_name)
        # TODO status check and save
        logger.info(f"async inference {ep.instance_type} {ep.endpoint_type} {event.multi_async} response: {resp}")
        return created(data=response_schema(inference_job), decimal=True)

    # Real-time path: invoke synchronously and persist the result in place.
    inference_job.start_time = datetime.now().isoformat()
    ddb_service.put_items(execute_table, entries=inference_job.__dict__)
    resp = real_time_inference([payload], inference_id, ep.endpoint_name)
    logger.info(f"real time inference response: {resp}")
    if resp and len(resp) > 0:
        resp = InferenceResult(**resp[0])
        resp = s3_scan_files(resp)
        inference_job.status = resp.status
        inference_job.sagemaker_raw = resp.__dict__
        inference_job.output_path = resp.output_path
        inference_job.output_files = resp.output_files
        inference_job.temp_path = resp.temp_path
        inference_job.temp_files = resp.temp_files
        inference_job.complete_time = datetime.now().isoformat()
        # The job failed unless it reached one of the two success statuses.
        if resp.status not in ('Completed', 'success'):
            inference_job.message = resp.message
        ddb_service.put_items(execute_table, entries=inference_job.__dict__)
        if ep.endpoint_type == 'Real-time':
            inference_job.output_files = generate_presigned_url_for_keys(inference_job.output_path,
                                                                         inference_job.output_files)
            inference_job.temp_files = generate_presigned_url_for_keys(inference_job.temp_path,
                                                                       inference_job.temp_files)
        if resp.status not in ('Completed', 'success'):
            record_count_metrics(ep_name=ep.endpoint_name,
                                 metric_name='InferenceFailed',
                                 service=ServiceType.Comfy.value,
                                 workflow=inference_job.workflow
                                 )
        else:
            record_count_metrics(ep_name=ep.endpoint_name,
                                 metric_name='InferenceSucceed',
                                 service=ServiceType.Comfy.value,
                                 workflow=inference_job.workflow
                                 )
            record_latency_metrics(start_time=inference_job.start_time,
                                   ep_name=ep.endpoint_name,
                                   metric_name='InferenceLatency',
                                   workflow=inference_job.workflow,
                                   service=ServiceType.Comfy.value)
    else:
        logger.info(f"inference error: SageMaker response is empty! {resp}")
        record_count_metrics(ep_name=ep.endpoint_name,
                             metric_name='InferenceFailed',
                             service=ServiceType.Comfy.value,
                             workflow=inference_job.workflow
                             )
    return ok(data=response_schema(inference_job), decimal=True)
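
# Illustrative sketch of the consumer side of the multi-GPU hand-off above: a
# merge-queue worker would unpack the message sen_sqs_msg published like this.
# The worker lives outside this module; the function name is an assumption, and
# a Lambda-style SQS record (payload under 'body') is assumed.
def parse_merge_queue_record(record: dict) -> tuple:
    body = json.loads(record['body'])
    return body['event'], body['save_item'], body['inference_id']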
def response_schema(inference_job: ComfyExecuteTable):
    if not inference_job.output_files:
        inference_job.output_files = []
    if not inference_job.temp_files:
        inference_job.temp_files = []
    data = {
        'prompt_id': inference_job.prompt_id,
        'status': inference_job.status,
        'create_time': inference_job.create_time,
        'endpoint_name': inference_job.endpoint_name,
        'inference_type': inference_job.inference_type,
        'need_sync': inference_job.need_sync,
        'start_time': inference_job.start_time,
        'complete_time': inference_job.complete_time,
        'output_path': inference_job.output_path,
        'output_files': inference_job.output_files,
        'temp_path': inference_job.temp_path,
        'temp_files': inference_job.temp_files,
    }
    return data
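
# Illustrative request body for the handler below (values are placeholders;
# ExecuteEvent requires prompt_id and prompt, everything else is optional):
# {
#     "prompt_id": "11111111-2222-3333-4444-555555555555",
#     "prompt": {"3": {"class_type": "KSampler", "inputs": {"seed": 123}}},
#     "endpoint_name": "comfy-async-ep",
#     "need_sync": true,
#     "multi_async": false
# }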
@tracer.capture_lambda_handler
def handler(raw_event, ctx):
    try:
        logger.info(f"execute start... Received event: {raw_event}")
        logger.info(f"Received ctx: {ctx}")
        event = ExecuteEvent(**json.loads(raw_event['body']))

        if not event.prompt:
            raise BadRequestException("Prompt is required")

        return invoke_sagemaker_inference(event)
    except Exception as e:
        return response_error(e)


@tracer.capture_method
def async_inference(payload: Any, inference_id, endpoint_name):
    tracer.put_annotation(key="inference_id", value=inference_id)
    initial_args = {"InvocationTimeoutSeconds": 3600}
    return get_async_predict_client(endpoint_name).predict_async(data=payload,
                                                                 initial_args=initial_args,
                                                                 inference_id=inference_id)


@tracer.capture_method
def real_time_inference(data: Any, inference_id, endpoint_name):
    tracer.put_annotation(key="inference_id", value=inference_id)
    return get_real_time_predict_client(endpoint_name).predict(data=data, inference_id=inference_id)


# Note: both client factories share the predictors cache, so an endpoint name
# is assumed to serve only one invocation type (real-time or async).
@tracer.capture_method
def get_real_time_predict_client(endpoint_name):
    tracer.put_annotation(key="endpoint_name", value=endpoint_name)
    if endpoint_name in predictors:
        return predictors[endpoint_name]
    predictor = Predictor(endpoint_name)
    predictor.serializer = JSONSerializer()
    predictor.deserializer = JSONDeserializer()
    predictors[endpoint_name] = predictor
    return predictor


@tracer.capture_method
def get_async_predict_client(endpoint_name):
    tracer.put_annotation(key="endpoint_name", value=endpoint_name)
    if endpoint_name in predictors:
        return predictors[endpoint_name]
    predictor = Predictor(endpoint_name)
    predictor = AsyncPredictor(predictor, name=endpoint_name)
    predictor.serializer = JSONSerializer()
    predictor.deserializer = JSONDeserializer()
    predictors[endpoint_name] = predictor
    return predictor
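
# Minimal local smoke-test sketch (an assumption, not part of the deployed
# Lambda): requires AWS credentials, the environment variables read at module
# import, and a deployed endpoint. The prompt and endpoint name are placeholders.
if __name__ == '__main__':
    _test_event = {
        'body': json.dumps({
            'prompt_id': str(uuid.uuid4()),
            'prompt': {'3': {'class_type': 'KSampler', 'inputs': {'seed': 123}}},
            'endpoint_name': 'comfy-real-time-test',
        })
    }
    print(handler(_test_event, None))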