import base64
|
|
import json
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
import boto3
|
|
from aws_lambda_powertools import Tracer
|
|
from sagemaker import Predictor
|
|
from sagemaker.deserializers import JSONDeserializer
|
|
from sagemaker.predictor_async import AsyncPredictor
|
|
from sagemaker.serializers import JSONSerializer
|
|
|
|
from common.ddb_service.client import DynamoDbUtilsService
|
|
from common.excepts import BadRequestException
|
|
from common.response import ok
|
|
from common.util import s3_scan_files
|
|
from libs.comfy_data_types import ComfyExecuteTable, InferenceResult
|
|
from libs.enums import ComfyExecuteType
|
|
from libs.utils import get_endpoint_by_name, response_error
|
|
|
|
# X-Ray tracer used by the @tracer.capture_* decorators below.
tracer = Tracer()

# Deployment configuration injected through Lambda environment variables.
region = os.environ.get('AWS_REGION')
bucket_name = os.environ.get('S3_BUCKET_NAME')
execute_table = os.environ.get('EXECUTE_TABLE')

logger = logging.getLogger(__name__)
# Fall back to ERROR when LOG_LEVEL is unset or empty.
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
endpoint_instance_id = os.environ.get('ENDPOINT_INSTANCE_ID')
ddb_service = DynamoDbUtilsService(logger=logger)

# DynamoDB GSI name; not referenced in this chunk — presumably used by a
# query helper elsewhere in the file. TODO confirm.
index_name = "endpoint_name-startTime-index"
# Per-container cache of SageMaker predictor clients, keyed by endpoint name.
predictors = {}
|
|
|
|
|
|
@dataclass
class PrepareProps:
    """Optional file-preparation settings attached to an execute request."""

    # Kind of resources to prepare; defaults to "inputs".
    prepare_type: Optional[str] = "inputs"
    # S3 location to copy files from — presumably an s3:// URI; verify against caller.
    s3_source_path: Optional[str] = None
    # Target path on the endpoint instance to place the files into — TODO confirm.
    local_target_path: Optional[str] = None
    # Custom sync script to run instead of the default copy — semantics not visible here.
    sync_script: Optional[str] = None
|
|
|
|
|
|
@dataclass
class ExecuteEvent:
    """Request body of the execute API: a ComfyUI prompt plus routing options."""

    # Client-supplied id identifying this prompt execution.
    prompt_id: str
    # ComfyUI workflow/prompt graph to execute.
    prompt: dict
    # Target SageMaker endpoint name; empty string when unspecified.
    endpoint_name: Optional[str] = ''
    # Whether generated files should be synced back after execution.
    need_sync: bool = True
    # The following four fields are forwarded verbatim inside prompt_params.
    number: Optional[str] = None
    front: Optional[bool] = None
    extra_data: Optional[dict] = None
    client_id: Optional[str] = None
    # When True, prepare_props describes files to stage before execution.
    need_prepare: bool = False
    prepare_props: Optional[PrepareProps] = None
|
|
|
|
|
|
def build_s3_images_request(prompt_id, bucket_name, s3_path):
    """Collect media files under an S3 prefix and return them base64-encoded.

    Args:
        prompt_id: Execution id echoed back in the response.
        bucket_name: S3 bucket holding the generated files.
        s3_path: Key prefix to scan for image/video objects.

    Returns:
        dict with 'prompt_id' and 'image_video_data', a mapping of
        file name (last key segment) -> base64-encoded file content.
    """
    # Lower-case extensions treated as returnable image/video outputs.
    media_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.mp4', '.mov', '.avi')

    s3 = boto3.client('s3', region_name=region)
    image_video_dict = {}

    # Use a paginator: a single list_objects_v2 call silently truncates at
    # 1000 keys, dropping outputs for large result sets.
    paginator = s3.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket_name, Prefix=s3_path):
        for obj in page.get('Contents', []):
            object_key = obj['Key']
            if not object_key.lower().endswith(media_extensions):
                continue
            file_name = object_key.split('/')[-1]
            object_data = s3.get_object(Bucket=bucket_name, Key=object_key)['Body'].read()
            image_video_dict[file_name] = base64.b64encode(object_data).decode('utf-8')

    return {'prompt_id': prompt_id, 'image_video_data': image_video_dict}
|
|
|
|
|
|
@tracer.capture_method
def invoke_sagemaker_inference(event: ExecuteEvent):
    """Create an execute-job record and invoke the target SageMaker endpoint.

    Args:
        event: Parsed execute request; event.endpoint_name selects the endpoint.

    Returns:
        ComfyExecuteTable: the job record persisted to DynamoDB. Async jobs
        are stored in CREATED status; synchronous jobs are completed in place.

    Raises:
        Exception: if the endpoint is not in 'InService' status.
    """
    endpoint_name = event.endpoint_name

    ep = get_endpoint_by_name(endpoint_name)

    # Refuse endpoints that are still creating/updating or have failed.
    if ep.endpoint_status != 'InService':
        raise Exception(f"Endpoint {endpoint_name} is {ep.endpoint_status} status, not InService.")

    logger.info(f"endpoint: {ep}")

    # The whole event (dataclass fields) is forwarded as the inference payload.
    payload = event.__dict__
    logger.info('inference payload: {}'.format(payload))

    inference_id = str(uuid.uuid4())

    job_status = ComfyExecuteType.CREATED.value

    inference_job = ComfyExecuteTable(
        prompt_id=event.prompt_id,
        endpoint_name=event.endpoint_name,
        inference_type=ep.endpoint_type,
        instance_id=endpoint_instance_id,
        need_sync=event.need_sync,
        status=job_status,
        prompt_params={'prompt': event.prompt,
                       'number': event.number,
                       'front': event.front,
                       'extra_data': event.extra_data,
                       'client_id': event.client_id},
        prompt_path='',
        create_time=datetime.now().isoformat(),
        start_time=datetime.now().isoformat(),
        sagemaker_raw={},
        output_path='',
    )

    if ep.endpoint_type == 'Async':
        # Async endpoints: fire-and-forget; the job stays in CREATED status
        # and is presumably completed when the async result arrives — confirm
        # against the result-handling Lambda.
        resp = async_inference(payload, inference_id, ep.endpoint_name)
        # inference_job.sagemaker_raw = resp.__dict__
    else:
        # NOTE(review): the real-time call is stubbed out with a hard-coded
        # response, apparently for testing — confirm before production use.
        # resp = real_time_inference(payload, inference_id, ep.endpoint_name)
        resp = {
            "prompt_id": '11111111-1111-1111',
            "instance_id": 'esd-real-time-test-rgihbd',
            "status": 'success',
            "output_path": f's3://{bucket_name}/images/',
            "temp_path": f's3://{bucket_name}/images/'
        }
        logger.info(f"real time inference response: ")
        logger.info(resp)
        resp = InferenceResult(**resp)
        # Fills output_files/temp_files from the S3 paths.
        resp = s3_scan_files(resp)

        # Synchronous path: record completion details on the job record.
        inference_job.status = resp.status
        inference_job.sagemaker_raw = resp.__dict__
        inference_job.output_path = resp.output_path
        inference_job.output_files = resp.output_files
        inference_job.temp_path = resp.temp_path
        inference_job.temp_files = resp.temp_files
        inference_job.complete_time = datetime.now().isoformat()

    ddb_service.put_items(execute_table, entries=inference_job.__dict__)

    return inference_job
|
|
|
|
|
|
@tracer.capture_lambda_handler
def handler(raw_event, ctx):
    """Lambda entry point: validate the execute request and dispatch inference.

    Returns an ok() envelope with the job record on success; any exception
    is converted into an error response by response_error.
    """
    try:
        logger.info(f"execute start... Received event: {raw_event}")
        logger.info(f"Received ctx: {ctx}")

        request_body = json.loads(raw_event['body'])
        event = ExecuteEvent(**request_body)

        # A prompt graph is mandatory; reject empty requests early.
        if not event.prompt:
            raise BadRequestException("Prompt is required")

        job = invoke_sagemaker_inference(event)
        return ok(data=job.__dict__)
    except Exception as e:
        return response_error(e)
|
|
|
|
|
|
@tracer.capture_method
def async_inference(payload: any, inference_id, endpoint_name):
    """Submit an asynchronous SageMaker invocation and return its response handle."""
    tracer.put_annotation(key="inference_id", value=inference_id)
    client = get_async_predict_client(endpoint_name)
    # Allow long-running workflows: up to one hour before SageMaker times out.
    args = {"InvocationTimeoutSeconds": 3600}
    return client.predict_async(data=payload, initial_args=args, inference_id=inference_id)
|
|
|
|
|
|
@tracer.capture_method
def real_time_inference(data: any, inference_id, endpoint_name):
    """Run a synchronous SageMaker invocation and return the deserialized result."""
    tracer.put_annotation(key="inference_id", value=inference_id)
    client = get_real_time_predict_client(endpoint_name)
    return client.predict(data=data, inference_id=inference_id)
|
|
|
|
|
|
@tracer.capture_method
def get_real_time_predict_client(endpoint_name):
    """Return a cached JSON-in/JSON-out real-time Predictor for the endpoint.

    Clients are cached per container in the module-level `predictors` dict.
    The cache key is namespaced by client type: the original code keyed on the
    bare endpoint name, so an AsyncPredictor cached by get_async_predict_client
    for the same endpoint would have been returned here as the wrong type.
    """
    tracer.put_annotation(key="endpoint_name", value=endpoint_name)
    cache_key = f"real-time-{endpoint_name}"
    if cache_key in predictors:
        return predictors[cache_key]

    predictor = Predictor(endpoint_name)
    predictor.serializer = JSONSerializer()
    predictor.deserializer = JSONDeserializer()

    predictors[cache_key] = predictor

    return predictor
|
|
|
|
|
|
@tracer.capture_method
def get_async_predict_client(endpoint_name):
    """Return a cached JSON-in/JSON-out AsyncPredictor for the endpoint.

    Clients are cached per container in the module-level `predictors` dict.
    The cache key is namespaced by client type: the original code keyed on the
    bare endpoint name, so a real-time Predictor cached by
    get_real_time_predict_client for the same endpoint would have been
    returned here as the wrong type.
    """
    tracer.put_annotation(key="endpoint_name", value=endpoint_name)
    cache_key = f"async-{endpoint_name}"
    if cache_key in predictors:
        return predictors[cache_key]

    predictor = Predictor(endpoint_name)
    predictor = AsyncPredictor(predictor, name=endpoint_name)
    predictor.serializer = JSONSerializer()
    predictor.deserializer = JSONDeserializer()

    predictors[cache_key] = predictor

    return predictor
|