stable-diffusion-aws-extension/middleware_api/lambda/inferences/inference_api.py

347 lines
13 KiB
Python

import dataclasses
import json
import logging
import os
import random
from datetime import datetime
from typing import List, Any, Optional
from sagemaker import Predictor
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer
from common.ddb_service.client import DynamoDbUtilsService
from common.response import ok, bad_request
from common.util import generate_presign_url, load_json_from_s3, upload_json_to_s3, split_s3_path
from libs.data_types import CheckPoint, CheckPointStatus
from libs.data_types import InferenceJob, InvocationsRequest, EndpointDeploymentJob
from libs.utils import get_user_roles, check_user_permissions
# --- Environment configuration, resolved once per Lambda container ---
bucket_name = os.environ.get('S3_BUCKET')  # S3 bucket for inference params/outputs
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')  # DDB table: model checkpoints
sagemaker_endpoint_table = os.environ.get('DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME')  # DDB table: endpoint deployments
inference_table_name = os.environ.get('INFERENCE_JOB_TABLE')  # DDB table: inference jobs
user_table = os.environ.get('MULTI_USER_TABLE')  # DDB table: users and their roles
logger = logging.getLogger(__name__)
# Default to ERROR so verbose logging is opt-in via the LOG_LEVEL env var.
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
ddb_service = DynamoDbUtilsService(logger=logger)
@dataclasses.dataclass
class PrepareEvent:
    """Request body for POST /inferences (parsed in prepare_inference)."""
    task_type: str
    # Unified to builtin generics (PEP 585) for consistency with the
    # `dict[...]` annotations already used in this file (requires 3.9+).
    models: dict[str, list[str]]  # [checkpoint_type: names] this is same as checkpoint if confused
    filters: dict[str, Any]
    # Explicit endpoint to use; when empty, one is scheduled via user_id's roles.
    sagemaker_endpoint_name: Optional[str] = ""
    # Requesting user; used for role-based endpoint selection and job ownership.
    user_id: Optional[str] = ""
# POST /inferences
def prepare_inference(raw_event, context):
    """Create an inference job record and a presigned S3 upload slot for its params.

    raw_event: API Gateway proxy event whose 'body' is a JSON-encoded PrepareEvent.
    context: Lambda context; aws_request_id doubles as the inference job id.
    Returns ok(...) with the job id and upload URLs, or bad_request(...) on failure.
    """
    request_id = context.aws_request_id
    try:
        # Parse inside the try block so a malformed body or an unexpected field
        # yields a 400 response instead of an unhandled TypeError/JSONDecodeError
        # (the original parsed before entering the try).
        event = PrepareEvent(**json.loads(raw_event['body']))
        _type = event.task_type

        extra_generate_types = ['extra-single-image', 'extra-batch-images', 'rembg']
        simple_generate_types = ['txt2img', 'img2img']
        if _type not in simple_generate_types and _type not in extra_generate_types:
            return bad_request(
                message=f'task type {event.task_type} should be in {extra_generate_types} or {simple_generate_types}'
            )

        # Pick an InService endpoint, by explicit name or by the user's roles.
        inference_endpoint = _schedule_inference_endpoint(event.sagemaker_endpoint_name, event.user_id)
        endpoint_name = inference_endpoint.endpoint_name
        endpoint_id = inference_endpoint.EndpointDeploymentJobId

        # Presigned S3 slot the caller uploads the full API params to later.
        param_s3_key = f'{get_base_inference_param_s3_key(_type, request_id)}/api_param.json'
        s3_location = f's3://{bucket_name}/{param_s3_key}'
        presign_url = generate_presign_url(bucket_name, param_s3_key)

        inference_job = InferenceJob(
            InferenceJobId=request_id,
            startTime=str(datetime.now()),
            status='created',
            taskType=_type,
            owner_group_or_role=[event.user_id],
            params={
                'input_body_s3': s3_location,
                'input_body_presign_url': presign_url,
                'sagemaker_inference_endpoint_id': endpoint_id,
                'sagemaker_inference_endpoint_name': endpoint_name,
            },
        )
        resp = {
            'inference': {
                'id': request_id,
                'type': _type,
                'api_params_s3_location': s3_location,
                'api_params_s3_upload_url': presign_url,
            }
        }

        if _type in simple_generate_types:
            # txt2img/img2img reference named checkpoints: resolve all of them
            # and fail fast, listing every missing one in a single message.
            ckpts = []
            missing = []
            for ckpt_type, names in event.models.items():
                for name in names:
                    ckpt = _get_checkpoint_by_name(name, ckpt_type)
                    # todo: need check if user has permission for the model
                    if ckpt is None:
                        missing.append({'name': name, 'ckpt_type': ckpt_type})
                    else:
                        ckpts.append(ckpt)

            if missing:
                message = [f'checkpoint with name {c["name"]}, type {c["ckpt_type"]} is not found' for c in missing]
                return bad_request(message=' '.join(message))

            # Group resolved checkpoints by type for the persisted job record.
            used_models = {}
            for ckpt in ckpts:
                used_models.setdefault(ckpt.checkpoint_type, []).append(
                    {
                        'id': ckpt.id,
                        'model_name': ckpt.checkpoint_names[0],
                        's3': ckpt.s3_location,
                        'type': ckpt.checkpoint_type
                    }
                )
            inference_job.params['used_models'] = used_models
            resp['inference']['models'] = [
                {'id': ckpt.id, 'name': ckpt.checkpoint_names, 'type': ckpt.checkpoint_type}
                for ckpt in ckpts
            ]

        # Persist with status 'created'; the /run endpoint moves it to 'inprogress'.
        ddb_service.put_items(inference_table_name, entries=inference_job.__dict__)
        return ok(data=resp)
    except Exception as e:
        # Single API boundary: surface any failure as a 400 with its message.
        return bad_request(message=str(e))
# PUT /v2/inference/{inference_id}/run
def run_inference(event, _):
    """Start the async SageMaker invocation for a previously prepared job.

    event: API Gateway proxy event; pathParameters['id'] is the inference job id.
    Returns ok(...) with job status and the async output S3 path, or
    bad_request(...) when the job does not exist.
    """
    inference_id = event['pathParameters']['id']
    # Look up the prepared job by id.
    inference_raw = ddb_service.get_item(inference_table_name, {
        'InferenceJobId': inference_id
    })
    # Validate explicitly: the original used `assert`, which is stripped under
    # `python -O` and otherwise surfaces as an opaque 500.
    if inference_raw is None or len(inference_raw) == 0:
        return bad_request(message=f'inference job with id {inference_id} is not found')
    inference_job = InferenceJob(**inference_raw)
    endpoint_name = inference_job.params['sagemaker_inference_endpoint_name']

    # extra-* tasks carry no checkpoints; only attach model info when present.
    models = {}
    if 'used_models' in inference_job.params:
        models = {
            "space_free_size": 4e10,
            **inference_job.params['used_models'],
        }
    payload = InvocationsRequest(
        task=inference_job.taskType,
        username="test",
        models=models,
        param_s3=inference_job.params['input_body_s3']
    )

    # Fire the async invocation; the params body was uploaded to S3 beforehand.
    predictor = AsyncPredictor(Predictor(endpoint_name), name=endpoint_name)
    predictor.serializer = JSONSerializer()
    predictor.deserializer = JSONDeserializer()
    prediction = predictor.predict_async(
        data=payload.__dict__,
        initial_args={"InvocationTimeoutSeconds": 3600},
        inference_id=inference_id,
    )
    output_path = prediction.output_path

    # Record progress and where the async result will eventually land.
    inference_job.status = 'inprogress'
    inference_job.params['output_path'] = output_path
    ddb_service.put_items(inference_table_name, inference_job.__dict__)

    return ok(data={
        'inference': {
            'inference_id': inference_id,
            'status': inference_job.status,
            'endpoint_name': endpoint_name,
            'output_path': output_path
        }
    })
# GET /inferences?last_evaluated_key=xxx&limit=10&username=USER_NAME&name=SageMaker_Endpoint_Name&filter=key:value,key:value
def list_all_inference_jobs(event, ctx):
    """List inference jobs, optionally filtered to those a user may see.

    When a non-empty 'username' query parameter is present, only jobs whose
    owner_group_or_role passes check_user_permissions for that user's roles
    are returned; otherwise every job is returned.
    """
    # todo: support pagination via 'limit' / 'last_evaluated_key' query params
    parameters = event['queryStringParameters']
    username = None
    if parameters:
        # Treat missing or empty username as "no filter".
        username = parameters.get('username') or None

    user_roles = []
    if username:
        user_roles = get_user_roles(ddb_service=ddb_service, user_table_name=user_table, username=username)

    results = []
    for row in ddb_service.scan(inference_table_name, filters=None):
        inference = InferenceJob(**(ddb_service.deserialize(row)))
        # No username -> unfiltered listing; otherwise enforce role ownership.
        if not username or check_user_permissions(inference.owner_group_or_role, user_roles, username):
            results.append(inference.__dict__)

    return ok(data={'inferences': results}, decimal=True)
# POST /inference-api
def inference_l2(raw_event, context):
    """One-shot L2 API: prepare a job, upload merged params to S3, run it.

    raw_event: plain dict (direct invoke) that must contain task_type,
    sagemaker_endpoint_name and models; it is merged over the stored template.
    Returns a plain dict with status 200/400/500 plus job details.
    """
    request_id = context.aws_request_id
    if 'task_type' not in raw_event or 'sagemaker_endpoint_name' not in raw_event or 'models' not in raw_event:
        return {
            'status': 400,
            'err': f'task_type, sagemaker_endpoint_name, models should be in the request body'
        }
    task_type = raw_event['task_type']
    ep_name = raw_event['sagemaker_endpoint_name']

    prepare_infer_event = {
        'sagemaker_endpoint_name': ep_name,
        'task_type': task_type,
        'models': raw_event['models'],
        'filters': {
            'creator': 'l2api'
        }
    }
    # prepare_inference parses raw_event['body'] as a JSON string (API Gateway
    # proxy shape); the original passed the bare dict, which raised KeyError.
    prepare_resp = prepare_inference({'body': json.dumps(prepare_infer_event)}, context)
    # NOTE(review): assumes ok()/bad_request() expose their payload directly
    # under 'inference' — confirm against common.response.
    if 'inference' not in prepare_resp or 'api_params_s3_location' not in prepare_resp['inference']:
        return {
            'status': 500,
            'err': f'fail to prepare inference for {request_id}, no s3 location is generated'
        }
    s3_location = prepare_resp['inference']['api_params_s3_location']
    _, s3_file_key = split_s3_path(s3_location)

    # Merge the caller's parameters over the stored template and upload.
    param_template = load_json_from_s3(bucket_name, 'template/inferenceTemplate.json')
    merged_param = {**param_template, **raw_event}
    upload_json_to_s3(bucket_name, s3_file_key, merged_param)

    # run_inference reads event['pathParameters']['id']; the original sent
    # 'pathStringParameters'/'inference_id', which never matched that lookup.
    run_infer_resp = run_inference({
        'pathParameters': {
            'id': request_id
        }
    }, context)
    return {
        'status': 200,
        'inference': {
            'inference_id': request_id,
            'status': run_infer_resp['inference']['status'],
            'output_path': run_infer_resp['inference']['output_path'],
            'models': prepare_resp['inference']['models'],
            'api_params_s3_location': s3_location,
            'type': task_type,
            'endpoint_name': ep_name
        }
    }
# fixme: this is a very expensive function
def _get_checkpoint_by_name(ckpt_name, model_type, status='Active') -> Optional[CheckPoint]:
    """Find one checkpoint by name/type/status, or None when no match exists.

    Scans the whole checkpoint table (names are not indexed), hence the fixme.
    Return annotation fixed to Optional: the original declared CheckPoint but
    returned None on a miss.
    """
    # 'None'/'Automatic' VAE are virtual choices with no S3 artifact behind them.
    if model_type == 'VAE' and ckpt_name in ['None', 'Automatic']:
        return CheckPoint(
            id=model_type,
            checkpoint_names=[ckpt_name],
            s3_location='None',
            checkpoint_type=model_type,
            checkpoint_status=CheckPointStatus.Active,
            timestamp=0,
        )
    checkpoint_raw = ddb_service.client.scan(
        TableName=checkpoint_table,
        FilterExpression='contains(checkpoint_names, :checkpointName) and checkpoint_type=:model_type and checkpoint_status=:checkpoint_status',
        ExpressionAttributeValues={
            ':checkpointName': {'S': ckpt_name},
            ':model_type': {'S': model_type},
            ':checkpoint_status': {'S': status}
        }
    )
    # Check for None BEFORE unpacking: the original built ScanOutput(**raw)
    # first, so a None response would already have raised a TypeError.
    if checkpoint_raw is None:
        return None
    # Local import kept, presumably to avoid a module-level cycle — TODO confirm.
    from common.ddb_service.types_ import ScanOutput
    named_ = ScanOutput(**checkpoint_raw)
    if len(named_['Items']) == 0:
        return None
    return CheckPoint(**ddb_service.deserialize(named_['Items'][0]))
def get_base_inference_param_s3_key(_type: str, request_id: str) -> str:
    """Build the S3 key prefix under which a job's api_param.json is stored."""
    return '/'.join((_type, 'infer_v2', request_id))
# currently only two scheduling ways: by endpoint name and by user
def _schedule_inference_endpoint(endpoint_name, user_id) -> EndpointDeploymentJob:
    """Pick the endpoint to run on: by explicit name, else randomly among the
    InService endpoints the user's roles are allowed to use.

    Raises Exception when no usable endpoint exists (callers' try/except
    converts this into a 400 response).
    """
    # fixme: endpoint is not indexed by name, and this is very expensive query
    # fixme: we can either add index for endpoint name or make endpoint as the partition key
    if endpoint_name:
        matched = ddb_service.scan(sagemaker_endpoint_table, filters={
            'endpoint_name': endpoint_name
        })
        # Check emptiness BEFORE indexing: the original took [0] first, so a
        # missing endpoint raised IndexError and the not-found message below
        # was unreachable.
        if not matched:
            raise Exception(f'sagemaker endpoint with name {endpoint_name} is not found')
        sagemaker_endpoint_raw = ddb_service.deserialize(matched[0])
        if sagemaker_endpoint_raw['endpoint_status'] != 'InService':
            raise Exception(f'sagemaker endpoint is not ready with status: {sagemaker_endpoint_raw["endpoint_status"]}')
        return EndpointDeploymentJob(**sagemaker_endpoint_raw)

    if user_id:
        user_roles = get_user_roles(ddb_service, user_table, user_id)
        available_endpoints = []
        for row in ddb_service.scan(sagemaker_endpoint_table, filters=None):
            endpoint = EndpointDeploymentJob(**ddb_service.deserialize(row))
            # Skip endpoints that are not serving or already soft-deleted.
            if endpoint.endpoint_status != 'InService' or endpoint.status == 'deleted':
                continue
            if check_user_permissions(endpoint.owner_group_or_role, user_roles, user_id):
                available_endpoints.append(endpoint)
        if not available_endpoints:
            raise Exception(f'no available Endpoints for user "{user_id}"')
        # Naive load spreading: pick one of the permitted endpoints at random.
        return random.choice(available_endpoints)

    # The original fell through and implicitly returned None, which crashed
    # callers later with an opaque AttributeError; fail loudly here instead.
    raise Exception('either sagemaker_endpoint_name or user_id is required to schedule an inference endpoint')