# stable-diffusion-aws-extension/middleware_api/lambda/inference_v2/inference_api.py
import dataclasses
import logging
import os
from datetime import datetime
from typing import Any, List, Optional

from sagemaker import Predictor
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer

from common.ddb_service.client import DynamoDbUtilsService
from common.util import generate_presign_url, load_json_from_s3, upload_json_to_s3, split_s3_path
from inference_v2._types import InferenceJob, InvocationsRequest
from model_and_train._types import CheckPoint
# Runtime configuration, injected by the Lambda environment.
# S3 bucket used for inference params, templates and outputs.
bucket_name = os.environ.get('S3_BUCKET')
# DynamoDB table holding model checkpoint records.
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')
# DynamoDB table tracking SageMaker endpoint deployments.
sagemaker_endpoint_table = os.environ.get('DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME')
# DynamoDB table tracking inference jobs.
inference_table_name = os.environ.get('DDB_INFERENCE_TABLE_NAME')
logger = logging.getLogger('inference_v2')
# Shared DynamoDB helper used by every handler in this module.
ddb_service = DynamoDbUtilsService(logger=logger)
@dataclasses.dataclass
class PrepareEvent:
    # Request body for POST /v2/inference; constructed via PrepareEvent(**raw_event),
    # so the field names here are the public request schema.
    sagemaker_endpoint_name: str  # target SageMaker endpoint to invoke
    task_type: str  # expected to be 'txt2img' or 'img2img' (validated in prepare_inference)
    models: dict[str, List[str]]  # [checkpoint_type: names] this is same as checkpoint if confused
    filters: dict[str, Any]  # caller-supplied filters; not read by prepare_inference itself
# POST /v2/inference
def prepare_inference(raw_event, context):
    """Create an inference job record and return where to upload its params.

    Validates the task type, checks that the target SageMaker endpoint exists
    and is InService, and verifies every requested checkpoint is registered.
    On success a 'created' InferenceJob row is written to DynamoDB and the
    response carries the S3 location plus a presigned upload URL for the
    request parameters; the caller then triggers execution via run_inference.

    Returns a dict with 'status' 200 and an 'inference' payload on success,
    or 'status' 400/500 with 'err'/'error' details on failure.
    """
    request_id = context.aws_request_id
    event = PrepareEvent(**raw_event)
    _type = event.task_type
    if _type not in ['txt2img', 'img2img']:
        return {
            'status': 400,
            'err': f'task type {event.task_type} should be either txt2img or img2img'
        }

    # check endpoint table for endpoint status and existence
    # fixme: endpoint is not indexed by name, and this is very expensive query
    # fixme: we can either add index for endpoint name or make endpoint as the partition key
    scan_rows = ddb_service.scan(sagemaker_endpoint_table, filters={
        'endpoint_name': event.sagemaker_endpoint_name
    })
    # bugfix: the original indexed [0] before checking emptiness, so a missing
    # endpoint raised IndexError and the not-found branch below was unreachable
    if not scan_rows:
        return {
            'status': 500,
            'error': f'sagemaker endpoint with name {event.sagemaker_endpoint_name} is not found'
        }
    sagemaker_endpoint_raw = ddb_service.deserialize(scan_rows[0])
    if sagemaker_endpoint_raw is None or len(sagemaker_endpoint_raw) == 0:
        return {
            'status': 500,
            'error': f'sagemaker endpoint with name {event.sagemaker_endpoint_name} is not found'
        }
    if sagemaker_endpoint_raw['endpoint_status'] != 'InService':
        return {
            'status': 400,
            'error': f'sagemaker endpoint is not ready with status: {sagemaker_endpoint_raw["endpoint_status"]}'
        }
    endpoint_name = sagemaker_endpoint_raw['endpoint_name']
    endpoint_id = sagemaker_endpoint_raw['EndpointDeploymentJobId']

    # check if model(checkpoint) path(s) exists; collect every miss so the
    # caller gets the full list of unknown checkpoints in one response
    ckpts = []
    missing = []
    for ckpt_type, names in event.models.items():
        for name in names:
            ckpt = _get_checkpoint_by_name(name, ckpt_type)
            if ckpt is None:
                missing.append({'name': name, 'ckpt_type': ckpt_type})
            else:
                ckpts.append(ckpt)
    if missing:
        return {
            'status': 400,
            'error': [f'checkpoint with name {c["name"]}, type {c["ckpt_type"]} is not found' for c in missing]
        }

    # generate param s3 location for upload
    param_s3_key = f'{get_base_inference_param_s3_key(_type, request_id)}/api_param.json'
    s3_location = f's3://{bucket_name}/{param_s3_key}'
    presign_url = generate_presign_url(bucket_name, param_s3_key)

    # create inference job with param location in ddb, status set to 'created';
    # group the resolved checkpoints by their type for the job record
    used_models = {}
    for ckpt in ckpts:
        used_models.setdefault(ckpt.checkpoint_type, []).append({
            'id': ckpt.id,
            'model_name': ckpt.checkpoint_names[0],
            's3': ckpt.s3_location,
            'type': ckpt.checkpoint_type
        })
    inference_job = InferenceJob(
        InferenceJobId=request_id,
        startTime=str(datetime.now()),
        status='created',
        taskType=_type,
        params={
            'input_body_s3': s3_location,
            'input_body_presign_url': presign_url,
            'used_models': used_models,
            'sagemaker_inference_endpoint_id': endpoint_id,
            'sagemaker_inference_endpoint_name': endpoint_name,
        },
    )
    ddb_service.put_items(inference_table_name, entries=inference_job.__dict__)
    return {
        'status': 200,
        'inference': {
            'id': request_id,
            'type': _type,
            'api_params_s3_location': s3_location,
            'api_params_s3_upload_url': presign_url,
            'models': [{'id': ckpt.id, 'name': ckpt.checkpoint_names, 'type': ckpt.checkpoint_type} for ckpt in ckpts]
        }
    }
# PUT /v2/inference/{inference_id}/run
def run_inference(event, _):
    """Start the async SageMaker invocation for a previously prepared job.

    Loads the InferenceJob row by the path's inference_id, posts an async
    prediction request (the job's params must already be uploaded to the S3
    location produced by prepare_inference), then marks the job 'inprogress'
    and records the async output path in DynamoDB.

    Returns 'status' 200 with the job/endpoint/output info on success, or an
    error dict with 'statusCode' '500' when the path parameter or job is missing.
    """
    if 'pathStringParameters' not in event:
        return {
            'statusCode': '500',
            'error': 'path parameter /v2/inference/{inference_id}/run are needed'
        }
    infer_id = event['pathStringParameters'].get('inference_id')
    if not infer_id or len(infer_id) == 0:
        return {
            'statusCode': '500',
            'error': 'path parameter /v2/inference/{inference_id}/run are needed, typically inference id is not found'
        }

    # get the inference job from ddb by job id
    inference_raw = ddb_service.get_item(inference_table_name, {
        'InferenceJobId': infer_id
    })
    # bugfix: this was an `assert`, which is stripped under `python -O` and
    # otherwise surfaced as an AssertionError instead of an error response
    if inference_raw is None or len(inference_raw) == 0:
        return {
            'statusCode': '500',
            'error': f'inference job with id {infer_id} is not found'
        }
    inference_job = InferenceJob(**inference_raw)
    endpoint_name = inference_job.params['sagemaker_inference_endpoint_name']
    payload = InvocationsRequest(
        task=inference_job.taskType,
        username="test",
        models={
            # reserve free disk space on the endpoint before loading models
            "space_free_size": 4e10,
            **inference_job.params['used_models'],
        },
        param_s3=inference_job.params['input_body_s3']
    )

    # start async inference against the recorded endpoint
    predictor = Predictor(endpoint_name)
    initial_args = {"InvocationTimeoutSeconds": 3600}
    predictor = AsyncPredictor(predictor, name=endpoint_name)
    predictor.serializer = JSONSerializer()
    predictor.deserializer = JSONDeserializer()
    prediction = predictor.predict_async(data=payload.__dict__, initial_args=initial_args, inference_id=infer_id)
    output_path = prediction.output_path

    # update the ddb job status to 'inprogress' and save to ddb
    inference_job.status = 'inprogress'
    inference_job.params['output_path'] = output_path
    ddb_service.put_items(inference_table_name, inference_job.__dict__)
    return {
        'status': 200,
        'inference': {
            'inference_id': infer_id,
            'status': inference_job.status,
            'endpoint_name': endpoint_name,
            'output_path': output_path
        }
    }
# POST /inference-api
def inference_l2(raw_event, context):
    """One-shot inference API: prepare the job, upload merged params, run it.

    Convenience wrapper around prepare_inference + run_inference: validates
    the required request fields, prepares the job, merges the caller's body
    over the stored template JSON, uploads it to the job's param S3 key, and
    then triggers the async run. Each downstream step's failure is surfaced
    as a 400/500 error dict instead of propagating a KeyError.
    """
    request_id = context.aws_request_id
    required_fields = ('task_type', 'sagemaker_endpoint_name', 'models')
    if any(field not in raw_event for field in required_fields):
        return {
            'status': 400,
            'err': f'task_type, sagemaker_endpoint_name, models should be in the request body'
        }
    task_type = raw_event['task_type']
    ep_name = raw_event['sagemaker_endpoint_name']
    prepare_infer_event = {
        'sagemaker_endpoint_name': ep_name,
        'task_type': task_type,
        'models': raw_event['models'],
        'filters': {
            'creator': 'l2api'
        }
    }
    prepare_resp = prepare_inference(prepare_infer_event, context)
    if 'inference' not in prepare_resp or 'api_params_s3_location' not in prepare_resp['inference']:
        return {
            'status': 500,
            'err': f'fail to prepare inference for {request_id}, no s3 location is generated'
        }
    s3_location = prepare_resp['inference']['api_params_s3_location']
    # the bucket parsed from the location is ignored on purpose: prepare_inference
    # built it from the configured bucket_name, so they are the same bucket
    _bucket, s3_file_key = split_s3_path(s3_location)

    # merge the caller's parameters over the stored template (caller wins)
    param_template = load_json_from_s3(bucket_name, 'template/inferenceTemplate.json')
    merged_param = {**param_template, **raw_event}
    upload_json_to_s3(bucket_name, s3_file_key, merged_param)

    run_infer_resp = run_inference({
        'pathStringParameters': {
            'inference_id': request_id
        }
    }, context)
    # bugfix: run_inference returns an error dict with no 'inference' key on
    # failure; the original indexed it unconditionally and raised KeyError
    if 'inference' not in run_infer_resp:
        return {
            'status': 500,
            'err': f'fail to run inference for {request_id}',
            'detail': run_infer_resp
        }
    return {
        'status': 200,
        'inference': {
            'inference_id': request_id,
            'status': run_infer_resp['inference']['status'],
            'output_path': run_infer_resp['inference']['output_path'],
            'models': prepare_resp['inference']['models'],
            'api_params_s3_location': s3_location,
            'type': task_type,
            'endpoint_name': ep_name
        }
    }
# fixme: this is a very expensive function (full table scan per lookup)
def _get_checkpoint_by_name(ckpt_name, model_type, status='Active') -> Optional[CheckPoint]:
    """Scan the checkpoint table for a checkpoint matching name, type and status.

    Returns the first matching CheckPoint, or None when the scan yields
    nothing. The scan matches ckpt_name as a member of the row's
    checkpoint_names list.
    """
    checkpoint_raw = ddb_service.client.scan(
        TableName=checkpoint_table,
        FilterExpression='contains(checkpoint_names, :checkpointName) and checkpoint_type=:model_type and checkpoint_status=:checkpoint_status',
        ExpressionAttributeValues={
            ':checkpointName': {'S': ckpt_name},
            ':model_type': {'S': model_type},
            ':checkpoint_status': {'S': status}
        }
    )
    # bugfix: the None check must run before unpacking into ScanOutput —
    # the original did ScanOutput(**checkpoint_raw) first, which would have
    # raised TypeError on a None response before its own None check
    if checkpoint_raw is None:
        return None
    from common.ddb_service.types_ import ScanOutput
    named_ = ScanOutput(**checkpoint_raw)
    if len(named_['Items']) == 0:
        return None
    return CheckPoint(**ddb_service.deserialize(named_['Items'][0]))
def get_base_inference_param_s3_key(_type: str, request_id: str) -> str:
    """Build the S3 key prefix under which a job's inference params live."""
    return '/'.join((_type, 'infer_v2', request_id))