198 lines
6.7 KiB
Python
198 lines
6.7 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
from datetime import datetime
|
|
|
|
from aws_lambda_powertools import Tracer
|
|
from sagemaker import Predictor
|
|
from sagemaker.deserializers import JSONDeserializer
|
|
from sagemaker.predictor_async import AsyncPredictor
|
|
from sagemaker.serializers import JSONSerializer
|
|
|
|
from common.const import PERMISSION_INFERENCE_ALL
|
|
from common.ddb_service.client import DynamoDbUtilsService
|
|
from common.excepts import BadRequestException
|
|
from common.response import accepted
|
|
from common.util import record_latency_metrics, record_count_metrics
|
|
from get_inference_job import get_infer_data
|
|
from inference_libs import parse_sagemaker_result, update_inference_job_table
|
|
from libs.data_types import InferenceJob, InvocationRequest
|
|
from libs.enums import EndpointType
|
|
from libs.utils import response_error, permissions_check, log_json
|
|
|
|
# Tracer used to decorate the Lambda handler and the helper functions in this module.
tracer = Tracer()

# Module logger; the level is taken from LOG_LEVEL and falls back to ERROR
# when that variable is unset or empty.
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv('LOG_LEVEL') or logging.ERROR)

# Table name handed to ddb_service for inference-job reads and writes
# (None when INFERENCE_JOB_TABLE is not set).
inference_table_name = os.getenv('INFERENCE_JOB_TABLE')

ddb_service = DynamoDbUtilsService(logger=logger)

# Cache of predictor clients, filled lazily by the get_*_predict_client helpers.
predictors = dict()
|
|
|
|
|
|
@tracer.capture_lambda_handler
def handler(event: dict, _):
    """Start the inference job named by the ``id`` path parameter.

    The job id is read from ``event['pathParameters']['id']``, the caller is
    checked against ``PERMISSION_INFERENCE_ALL``, the job record is loaded
    from the inference job table and then passed to ``inference_start``.

    Args:
        event: The Lambda invocation event.
        _: The Lambda context (unused).

    Returns:
        The result of ``inference_start`` on success; otherwise whatever
        ``response_error`` builds from the raised exception.
    """
    try:
        logger.info(json.dumps(event))

        # A missing or null ``pathParameters`` / ``id`` is reported as a bad
        # request rather than surfacing as a KeyError/TypeError.
        path_params = event.get('pathParameters') or {}
        inference_id = path_params.get('id')
        if not inference_id:
            raise BadRequestException("InferenceJobId is required")

        username = permissions_check(event, [PERMISSION_INFERENCE_ALL])

        # get the inference job from ddb by job id
        inference_raw = ddb_service.get_item(inference_table_name, {
            'InferenceJobId': inference_id
        })
        if not inference_raw:
            raise BadRequestException(f"InferenceJobId {inference_id} not found")

        job = InferenceJob(**inference_raw)
        return inference_start(job, username)
    except Exception as e:
        # Top-level boundary: every failure is turned into an error response.
        return response_error(e)
|
|
|
|
|
|
@tracer.capture_method
def inference_start(job: InferenceJob, username):
    """Build the invocation payload for ``job`` and dispatch it.

    Real-time jobs get their ``startTime`` recorded and are run through
    ``real_time_inference``; every other job goes to ``async_inference``.

    Args:
        job: The inference job record to run.
        username: Name placed into the invocation payload.

    Returns:
        Whatever the selected inference path returns.
    """
    params = job.params
    endpoint_name = params['sagemaker_inference_endpoint_name']

    # Values in ``used_models`` override the default ``space_free_size``.
    if 'used_models' in params:
        models = {"space_free_size": 4e10, **params['used_models']}
    else:
        models = {}

    payload = InvocationRequest(
        id=job.InferenceJobId,
        task=job.taskType,
        workflow=job.workflow,
        username=username,
        models=models,
        param_s3=params['input_body_s3'],
        payload_string=job.payload_string,
    )

    log_json("inference job", job.__dict__)
    log_json("inference invoke payload", payload.__dict__)

    is_real_time = job.inference_type == EndpointType.RealTime.value
    if not is_real_time:
        return async_inference(payload, job, endpoint_name)

    started_at = datetime.now().isoformat()
    update_inference_job_table(job.InferenceJobId, 'startTime', started_at)
    return real_time_inference(payload, job, endpoint_name)
|
|
|
|
|
|
@tracer.capture_method
def real_time_inference(payload: InvocationRequest, job: InferenceJob, ep_name: str):
    """Invoke the endpoint synchronously and record the outcome.

    Args:
        payload: Invocation request; its ``__dict__`` is sent to the endpoint.
        job: The inference job being run.
        ep_name: Name of the endpoint to invoke.

    Returns:
        The result of ``get_infer_data`` for the job.

    Raises:
        Exception: When the endpoint output contains an ``'error'`` key. The
            raw output is stored under ``sagemakerRaw`` and a failure metric
            is recorded before raising.
    """
    job_id = job.InferenceJobId
    tracer.put_annotation(key="InferenceJobId", value=job_id)

    result = predictor_real_time_predict(
        endpoint_name=ep_name,
        data=payload.__dict__,
        inference_id=job_id,
    )

    if 'error' in result:
        record_count_metrics(
            ep_name=ep_name,
            metric_name='InferenceFailed',
            workflow=job.workflow,
        )
        failure = str(result)
        update_inference_job_table(job_id, 'sagemakerRaw', failure)
        raise Exception(failure)

    parse_sagemaker_result(result, job.createTime, job_id, job.taskType, ep_name)

    record_count_metrics(
        ep_name=ep_name,
        metric_name='InferenceSucceed',
        workflow=job.workflow,
    )
    record_latency_metrics(
        start_time=result['start_time'],
        ep_name=ep_name,
        metric_name='InferenceLatency',
        workflow=job.workflow,
    )

    return get_infer_data(job_id)
|
|
|
|
|
|
@tracer.capture_method
def get_real_time_predict_client(endpoint_name):
    """Return the cached ``Predictor`` for ``endpoint_name``, creating it if needed.

    A newly created predictor is configured with JSON (de)serializers and
    stored in the module-level ``predictors`` cache.
    """
    tracer.put_annotation(key="endpoint_name", value=endpoint_name)

    try:
        return predictors[endpoint_name]
    except KeyError:
        # Not cached yet; build it below.
        pass

    client = Predictor(endpoint_name)
    client.serializer = JSONSerializer()
    client.deserializer = JSONDeserializer()
    predictors[endpoint_name] = client
    return client
|
|
|
|
|
|
@tracer.capture_method
def get_async_predict_client(endpoint_name):
    """Return the cached ``AsyncPredictor`` for ``endpoint_name``, creating it if needed.

    A newly created predictor wraps a ``Predictor`` for the endpoint, is
    configured with JSON (de)serializers and is stored in the module-level
    ``predictors`` cache.
    """
    tracer.put_annotation(key="endpoint_name", value=endpoint_name)

    # ``predictors`` is also used by get_real_time_predict_client, which keys
    # it by the bare endpoint name. Namespacing the async key keeps a cached
    # real-time Predictor from being returned here, where predict_async is
    # required.
    cache_key = ('async', endpoint_name)
    cached = predictors.get(cache_key)
    if cached is not None:
        return cached

    predictor = AsyncPredictor(Predictor(endpoint_name), name=endpoint_name)
    predictor.serializer = JSONSerializer()
    predictor.deserializer = JSONDeserializer()

    predictors[cache_key] = predictor
    return predictor
|
|
|
|
|
|
@tracer.capture_method
def predictor_real_time_predict(endpoint_name, data, inference_id):
    """Call ``predict`` on the real-time client for ``endpoint_name`` and return its result."""
    client = get_real_time_predict_client(endpoint_name)
    return client.predict(data=data, inference_id=inference_id)
|
|
|
|
|
|
@tracer.capture_method
def predictor_async_predict(endpoint_name, data, inference_id):
    """Call ``predict_async`` on the async client for ``endpoint_name``.

    The request is sent with ``InvocationTimeoutSeconds`` set to 3600.

    Returns:
        Whatever ``predict_async`` returns for the submitted request.
    """
    tracer.put_annotation(key="inference_id", value=inference_id)

    client = get_async_predict_client(endpoint_name)
    return client.predict_async(
        data=data,
        initial_args={"InvocationTimeoutSeconds": 3600},
        inference_id=inference_id,
    )
|
|
|
|
|
|
@tracer.capture_method
def async_inference(payload: InvocationRequest, job: InferenceJob, endpoint_name):
    """Submit the job to the async endpoint and persist its in-progress state.

    Args:
        payload: Invocation request; its ``__dict__`` is sent to the endpoint.
        job: The inference job being run; its ``status`` and
            ``params['output_path']`` are updated in place and the whole
            record is written back to the inference job table.
        endpoint_name: Name of the endpoint to invoke.

    Returns:
        The ``accepted`` response carrying the job id, status and the legacy
        ``inference`` section.
    """
    job_id = job.InferenceJobId
    tracer.put_annotation(key="inference_id", value=job_id)

    prediction = predictor_async_predict(
        endpoint_name=endpoint_name,
        data=payload.__dict__,
        inference_id=job_id,
    )
    logger.info(f"prediction: {prediction}")
    output_path = prediction.output_path

    # update the ddb job status to 'inprogress' and save to ddb
    job.status = 'inprogress'
    job.params['output_path'] = output_path
    ddb_service.put_items(inference_table_name, job.__dict__)

    # todo inference will remove in the next version
    legacy_section = {
        'inference_id': job_id,
        'status': job.status,
        'endpoint_name': endpoint_name,
        'output_path': output_path,
    }
    response_body = {
        'InferenceJobId': job_id,
        'status': job.status,
        'inference': legacy_section,
    }

    return accepted(data=response_body)
|