stable-diffusion-aws-extension/middleware_api/inferences/start_inference_job.py

import json
import logging
import os
from datetime import datetime
from aws_lambda_powertools import Tracer
from sagemaker import Predictor
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer
from common.const import PERMISSION_INFERENCE_ALL
from common.ddb_service.client import DynamoDbUtilsService
from common.excepts import BadRequestException
from common.response import accepted
from common.util import record_latency_metrics, record_count_metrics
from get_inference_job import get_infer_data
from inference_libs import parse_sagemaker_result, update_inference_job_table
from libs.data_types import InferenceJob, InvocationRequest
from libs.enums import EndpointType
from libs.utils import response_error, permissions_check, log_json

tracer = Tracer()

logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)

inference_table_name = os.environ.get('INFERENCE_JOB_TABLE')
ddb_service = DynamoDbUtilsService(logger=logger)

# Predictor cache, keyed by endpoint name and shared by the real-time and
# async helpers below; module-level state survives warm Lambda invocations,
# so clients are only constructed once per container.
predictors = {}


@tracer.capture_lambda_handler
def handler(event: dict, _):
    try:
        logger.info(json.dumps(event))

        inference_id = event['pathParameters']['id']
        if not inference_id:
            raise BadRequestException("InferenceJobId is required")

        username = permissions_check(event, [PERMISSION_INFERENCE_ALL])

        # Fetch the inference job record from DynamoDB by job id.
        inference_raw = ddb_service.get_item(inference_table_name, {
            'InferenceJobId': inference_id
        })
        if inference_raw is None or len(inference_raw) == 0:
            raise BadRequestException(f"InferenceJobId {inference_id} not found")

        job = InferenceJob(**inference_raw)

        return inference_start(job, username)
    except Exception as e:
        return response_error(e)
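
# For reference, a hedged sketch of the minimal API Gateway proxy event this
# handler expects; the job id below is hypothetical, and a real event also
# carries whatever identity context permissions_check() validates:
#
#   {
#       "pathParameters": {"id": "infer-12345678"},
#       ...
#   }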


@tracer.capture_method
def inference_start(job: InferenceJob, username):
    endpoint_name = job.params['sagemaker_inference_endpoint_name']

    models = {}
    if 'used_models' in job.params:
        models = {
            "space_free_size": 4e10,
            **job.params['used_models'],
        }

    payload = InvocationRequest(
        id=job.InferenceJobId,
        task=job.taskType,
        workflow=job.workflow,
        username=username,
        models=models,
        param_s3=job.params['input_body_s3'],
        payload_string=job.payload_string
    )

    log_json("inference job", job.__dict__)
    log_json("inference invoke payload", payload.__dict__)

    if job.inference_type == EndpointType.RealTime.value:
        update_inference_job_table(job.InferenceJobId, 'startTime', datetime.now().isoformat())
        return real_time_inference(payload, job, endpoint_name)

    return async_inference(payload, job, endpoint_name)


@tracer.capture_method
def real_time_inference(payload: InvocationRequest, job: InferenceJob, ep_name: str):
    tracer.put_annotation(key="InferenceJobId", value=job.InferenceJobId)

    sagemaker_out = predictor_real_time_predict(endpoint_name=ep_name,
                                                data=payload.__dict__,
                                                inference_id=job.InferenceJobId,
                                                )

    if 'error' in sagemaker_out:
        record_count_metrics(ep_name=ep_name,
                             metric_name='InferenceFailed',
                             workflow=job.workflow,
                             )
        update_inference_job_table(job.InferenceJobId, 'sagemakerRaw', str(sagemaker_out))
        raise Exception(str(sagemaker_out))

    parse_sagemaker_result(sagemaker_out, job.createTime, job.InferenceJobId, job.taskType, ep_name)

    record_count_metrics(ep_name=ep_name,
                         metric_name='InferenceSucceed',
                         workflow=job.workflow,
                         )
    record_latency_metrics(start_time=sagemaker_out['start_time'],
                           ep_name=ep_name,
                           metric_name='InferenceLatency',
                           workflow=job.workflow,
                           )

    return get_infer_data(job.InferenceJobId)
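
# Note: the structure of sagemaker_out (an 'error' key on failure, a
# 'start_time' timestamp on success) is the inference container's response
# contract as assumed from the checks above, not something enforced here.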


@tracer.capture_method
def get_real_time_predict_client(endpoint_name):
    tracer.put_annotation(key="endpoint_name", value=endpoint_name)

    if endpoint_name in predictors:
        return predictors[endpoint_name]

    predictor = Predictor(endpoint_name)
    predictor.serializer = JSONSerializer()
    predictor.deserializer = JSONDeserializer()
    predictors[endpoint_name] = predictor

    return predictor


@tracer.capture_method
def get_async_predict_client(endpoint_name):
    tracer.put_annotation(key="endpoint_name", value=endpoint_name)

    if endpoint_name in predictors:
        return predictors[endpoint_name]

    predictor = Predictor(endpoint_name)
    predictor = AsyncPredictor(predictor, name=endpoint_name)
    predictor.serializer = JSONSerializer()
    predictor.deserializer = JSONDeserializer()
    predictors[endpoint_name] = predictor

    return predictor


@tracer.capture_method
def predictor_real_time_predict(endpoint_name, data, inference_id):
    return get_real_time_predict_client(endpoint_name).predict(data=data, inference_id=inference_id)


@tracer.capture_method
def predictor_async_predict(endpoint_name, data, inference_id):
    tracer.put_annotation(key="inference_id", value=inference_id)
    initial_args = {"InvocationTimeoutSeconds": 3600}
    return get_async_predict_client(endpoint_name).predict_async(data=data,
                                                                 initial_args=initial_args,
                                                                 inference_id=inference_id)
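
# Note: with SageMaker async inference, predict_async() uploads the payload
# to S3 and returns immediately; the result later appears at the prediction's
# output_path. InvocationTimeoutSeconds is the service-side processing cap,
# and 3600 seconds is the maximum SageMaker accepts.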


@tracer.capture_method
def async_inference(payload: InvocationRequest, job: InferenceJob, endpoint_name):
    tracer.put_annotation(key="inference_id", value=job.InferenceJobId)

    prediction = predictor_async_predict(endpoint_name=endpoint_name,
                                         data=payload.__dict__,
                                         inference_id=job.InferenceJobId)
    logger.info(f"prediction: {prediction}")
    output_path = prediction.output_path

    # Mark the job as in progress and persist the async output location.
    job.status = 'inprogress'
    job.params['output_path'] = output_path
    ddb_service.put_items(inference_table_name, job.__dict__)

    data = {
        'InferenceJobId': job.InferenceJobId,
        'status': job.status,
        # TODO: the nested 'inference' field will be removed in the next version
        'inference': {
            'inference_id': job.InferenceJobId,
            'status': job.status,
            'endpoint_name': endpoint_name,
            'output_path': output_path
        }
    }

    return accepted(data=data)
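

# A minimal local smoke test, included as a hedged sketch only: it assumes
# valid AWS credentials, the INFERENCE_JOB_TABLE environment variable, an
# existing job record, and a live SageMaker endpoint. The job id below is
# hypothetical.
if __name__ == "__main__":
    example_event = {
        "pathParameters": {"id": "example-inference-job-id"},
    }
    print(handler(example_event, None))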