347 lines
13 KiB
Python
347 lines
13 KiB
Python
import dataclasses
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
from datetime import datetime
|
|
from typing import List, Any, Optional
|
|
|
|
from sagemaker import Predictor
|
|
from sagemaker.deserializers import JSONDeserializer
|
|
from sagemaker.predictor_async import AsyncPredictor
|
|
from sagemaker.serializers import JSONSerializer
|
|
|
|
from common.ddb_service.client import DynamoDbUtilsService
|
|
from common.response import ok, bad_request
|
|
from common.util import generate_presign_url, load_json_from_s3, upload_json_to_s3, split_s3_path
|
|
from libs.data_types import CheckPoint, CheckPointStatus
|
|
from libs.data_types import InferenceJob, InvocationsRequest, EndpointDeploymentJob
|
|
from libs.utils import get_user_roles, check_user_permissions
|
|
|
|
# Environment-driven configuration: all bucket/table names are injected by the
# deployment stack. Any of these may be None when the variable is unset.
bucket_name = os.environ.get('S3_BUCKET')
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')
sagemaker_endpoint_table = os.environ.get('DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME')
inference_table_name = os.environ.get('INFERENCE_JOB_TABLE')
user_table = os.environ.get('MULTI_USER_TABLE')

logger = logging.getLogger(__name__)
# LOG_LEVEL may be a level name or number; fall back to ERROR when unset/empty.
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)

# Shared DynamoDB helper used by every handler in this module.
ddb_service = DynamoDbUtilsService(logger=logger)
|
|
|
|
|
|
@dataclasses.dataclass
class PrepareEvent:
    """Parsed request body for POST /inferences (see prepare_inference)."""

    # one of the task types accepted by prepare_inference
    # (e.g. 'txt2img', 'img2img', 'extra-single-image', ...)
    task_type: str
    models: dict[str, List[str]]  # [checkpoint_type: names] this is same as checkpoint if confused
    filters: dict[str, Any]
    # either an explicit endpoint name or a user_id is used to pick an
    # endpoint in _schedule_inference_endpoint
    sagemaker_endpoint_name: Optional[str] = ""
    user_id: Optional[str] = ""
|
|
|
|
|
|
# POST /inferences
def prepare_inference(raw_event, context):
    """Create an inference job record and hand back an upload location.

    Parses ``raw_event['body']`` (a JSON string) into a PrepareEvent,
    schedules a SageMaker endpoint for it, validates referenced checkpoints
    (for txt2img/img2img only), stores the job in DynamoDB with status
    'created', and returns a presigned S3 URL the caller must upload its
    API params to before calling run. Any failure after parsing surfaces
    as a bad_request response.
    """
    # The Lambda request id doubles as the inference job id.
    request_id = context.aws_request_id
    event = PrepareEvent(**json.loads(raw_event['body']))
    _type = event.task_type

    try:
        # 'extra' task types need no checkpoint validation; 'simple' ones do.
        extra_generate_types = ['extra-single-image', 'extra-batch-images', 'rembg']
        simple_generate_types = ['txt2img', 'img2img']

        if _type not in simple_generate_types and _type not in extra_generate_types:
            return bad_request(
                message=f'task type {event.task_type} should be in {extra_generate_types} or {simple_generate_types}'
            )

        # check if endpoint table for endpoint status and existence
        inference_endpoint = _schedule_inference_endpoint(event.sagemaker_endpoint_name, event.user_id)
        endpoint_name = inference_endpoint.endpoint_name
        endpoint_id = inference_endpoint.EndpointDeploymentJobId

        # generate param s3 location for upload
        param_s3_key = f'{get_base_inference_param_s3_key(_type, request_id)}/api_param.json'
        s3_location = f's3://{bucket_name}/{param_s3_key}'
        presign_url = generate_presign_url(bucket_name, param_s3_key)
        inference_job = InferenceJob(
            InferenceJobId=request_id,
            startTime=str(datetime.now()),
            status='created',
            taskType=_type,
            owner_group_or_role=[event.user_id],
            params={
                'input_body_s3': s3_location,
                'input_body_presign_url': presign_url,
                'sagemaker_inference_endpoint_id': endpoint_id,
                'sagemaker_inference_endpoint_name': endpoint_name,
            },
        )
        resp = {
            'inference': {
                'id': request_id,
                'type': _type,
                'api_params_s3_location': s3_location,
                'api_params_s3_upload_url': presign_url,
            }
        }

        if _type in simple_generate_types:
            # check if model(checkpoint) path(s) exists. return error if not
            ckpts = []
            ckpts_to_upload = []  # collects every (name, type) pair not found in DDB
            for ckpt_type, names in event.models.items():
                for name in names:
                    ckpt = _get_checkpoint_by_name(name, ckpt_type)
                    # todo: need check if user has permission for the model
                    if ckpt is None:
                        ckpts_to_upload.append({
                            'name': name,
                            'ckpt_type': ckpt_type
                        })
                    else:
                        ckpts.append(ckpt)

            # Reject the whole request if any referenced checkpoint is missing.
            if len(ckpts_to_upload) > 0:
                message = [f'checkpoint with name {c["name"]}, type {c["ckpt_type"]} is not found' for c in
                           ckpts_to_upload]
                return bad_request(message=' '.join(message))

            # create inference job with param location in ddb, status set to Created
            # Group resolved checkpoints by type for the invocation payload.
            used_models = {}
            for ckpt in ckpts:
                if ckpt.checkpoint_type not in used_models:
                    used_models[ckpt.checkpoint_type] = []

                used_models[ckpt.checkpoint_type].append(
                    {
                        'id': ckpt.id,
                        'model_name': ckpt.checkpoint_names[0],
                        's3': ckpt.s3_location,
                        'type': ckpt.checkpoint_type
                    }
                )

            inference_job.params['used_models'] = used_models
            resp['inference']['models'] = [{'id': ckpt.id, 'name': ckpt.checkpoint_names, 'type': ckpt.checkpoint_type}
                                           for ckpt in ckpts]

        ddb_service.put_items(inference_table_name, entries=inference_job.__dict__)
        return ok(data=resp)
    except Exception as e:
        # Broad catch by design: every failure surfaces as a 400 response.
        return bad_request(message=str(e))
|
|
|
|
|
|
# PUT /v2/inference/{inference_id}/run
def run_inference(event, _):
    """Start an async SageMaker invocation for a previously prepared job.

    Expects ``event['pathParameters']['id']`` to identify the job created by
    prepare_inference. Loads the job from DynamoDB, posts an async request
    to the endpoint recorded in the job params, marks the job 'inprogress'
    and returns the prediction output path.
    """
    inference_id = event['pathParameters']['id']

    # get the inference job from ddb by job id
    inference_raw = ddb_service.get_item(inference_table_name, {
        'InferenceJobId': inference_id
    })

    # Fail explicitly instead of asserting: `assert` is stripped under
    # `python -O`, which would let a missing job fall through to a KeyError.
    if not inference_raw:
        return bad_request(message=f'inference job {inference_id} is not found')

    inference_job = InferenceJob(**inference_raw)
    endpoint_name = inference_job.params['sagemaker_inference_endpoint_name']

    models = {}
    if 'used_models' in inference_job.params:
        models = {
            # NOTE(review): presumably free disk space (bytes) required on the
            # endpoint before pulling models — confirm against the container.
            "space_free_size": 4e10,
            **inference_job.params['used_models'],
        }

    payload = InvocationsRequest(
        task=inference_job.taskType,
        username="test",
        models=models,
        param_s3=inference_job.params['input_body_s3']
    )

    # start async inference
    predictor = Predictor(endpoint_name)
    initial_args = {"InvocationTimeoutSeconds": 3600}
    predictor = AsyncPredictor(predictor, name=endpoint_name)
    predictor.serializer = JSONSerializer()
    predictor.deserializer = JSONDeserializer()
    prediction = predictor.predict_async(data=payload.__dict__, initial_args=initial_args, inference_id=inference_id)
    output_path = prediction.output_path

    # update the ddb job status to 'inprogress' and save to ddb
    inference_job.status = 'inprogress'
    inference_job.params['output_path'] = output_path
    ddb_service.put_items(inference_table_name, inference_job.__dict__)

    data = {
        'inference': {
            'inference_id': inference_id,
            'status': inference_job.status,
            'endpoint_name': endpoint_name,
            'output_path': output_path
        }
    }

    return ok(data=data)
|
|
|
|
|
|
# GET /inferences?last_evaluated_key=xxx&limit=10&username=USER_NAME&name=SageMaker_Endpoint_Name&filter=key:value,key:value
def list_all_inference_jobs(event, ctx):
    """List inference jobs, optionally restricted to what a user may see.

    When a ``username`` query parameter is supplied, only jobs whose owner
    group/role the user's roles permit are returned; otherwise every job in
    the table is returned.
    """
    # todo: support pagination later (limit / last_evaluated_key)

    # queryStringParameters may be absent or null on requests with no query
    # string; guard instead of raising KeyError.
    parameters = event.get('queryStringParameters') or {}
    username = parameters.get('username') or None

    scan_rows = ddb_service.scan(inference_table_name, filters=None)

    user_roles = []
    if username:
        user_roles = get_user_roles(ddb_service=ddb_service, user_table_name=user_table, username=username)

    results = []
    for row in scan_rows:
        inference = InferenceJob(**(ddb_service.deserialize(row)))
        # Without a username filter everything is visible; with one, the
        # user's roles must grant access to the job's owner group/role.
        if not username or check_user_permissions(inference.owner_group_or_role, user_roles, username):
            results.append(inference.__dict__)

    data = {
        'inferences': results
    }

    return ok(data=data, decimal=True)
|
|
|
|
|
|
# POST /inference-api
def inference_l2(raw_event, context):
    """One-shot L2 API: prepare a job, upload merged params, then run it.

    ``raw_event`` here is the already-parsed request body (not an
    API-Gateway envelope) and must contain ``task_type``,
    ``sagemaker_endpoint_name`` and ``models``.
    """
    request_id = context.aws_request_id
    if 'task_type' not in raw_event or 'sagemaker_endpoint_name' not in raw_event or 'models' not in raw_event:
        return {
            'status': 400,
            'err': 'task_type, sagemaker_endpoint_name, models should be in the request body'
        }
    task_type = raw_event['task_type']
    ep_name = raw_event['sagemaker_endpoint_name']

    prepare_infer_event = {
        'sagemaker_endpoint_name': raw_event['sagemaker_endpoint_name'],
        'task_type': task_type,
        'models': raw_event['models'],
        'filters': {
            'creator': 'l2api'
        }
    }
    # prepare_inference parses json.loads(raw_event['body']), so wrap the
    # payload the way API Gateway would. The original passed the bare dict,
    # which could never be parsed.
    prepare_resp = prepare_inference({'body': json.dumps(prepare_infer_event)}, context)

    # NOTE(review): assumes ok()/bad_request() responses expose the handler
    # data under 'inference' — confirm against common.response.
    if 'inference' not in prepare_resp or 'api_params_s3_location' not in prepare_resp['inference']:
        return {
            'status': 500,
            'err': f'fail to prepare inference for {request_id}, no s3 location is generated'
        }

    s3_location = prepare_resp['inference']['api_params_s3_location']
    bucket, s3_file_key = split_s3_path(s3_location)

    # merge the caller's parameters over the stored template
    param_template = load_json_from_s3(bucket_name, 'template/inferenceTemplate.json')
    merged_param = {**param_template, **raw_event}
    upload_json_to_s3(bucket_name, s3_file_key, merged_param)

    # run_inference reads event['pathParameters']['id']; the original passed
    # 'pathStringParameters'/'inference_id', which could never match.
    run_infer_resp = run_inference({
        'pathParameters': {
            'id': request_id
        }
    }, context)

    return {
        'status': 200,
        'inference': {
            'inference_id': request_id,
            'status': run_infer_resp['inference']['status'],
            'output_path': run_infer_resp['inference']['output_path'],
            'models': prepare_resp['inference']['models'],
            'api_params_s3_location': s3_location,
            'type': task_type,
            'endpoint_name': ep_name
        }
    }
|
|
|
|
|
|
# fixme: this is a very expensive function (full table scan per lookup)
def _get_checkpoint_by_name(ckpt_name, model_type, status='Active') -> Optional[CheckPoint]:
    """Find a checkpoint whose checkpoint_names contain ``ckpt_name``.

    Returns None when no matching checkpoint exists. The VAE selections
    'None'/'Automatic' are special-cased to a synthetic placeholder
    checkpoint with no S3 location.
    """
    if model_type == 'VAE' and ckpt_name in ['None', 'Automatic']:
        return CheckPoint(
            id=model_type,
            checkpoint_names=[ckpt_name],
            s3_location='None',
            checkpoint_type=model_type,
            checkpoint_status=CheckPointStatus.Active,
            timestamp=0,
        )

    checkpoint_raw = ddb_service.client.scan(
        TableName=checkpoint_table,
        FilterExpression='contains(checkpoint_names, :checkpointName) and checkpoint_type=:model_type and checkpoint_status=:checkpoint_status',
        ExpressionAttributeValues={
            ':checkpointName': {'S': ckpt_name},
            ':model_type': {'S': model_type},
            ':checkpoint_status': {'S': status}
        }
    )

    # Guard BEFORE dereferencing the response: the original constructed
    # ScanOutput(**checkpoint_raw) first, which raises on a None result and
    # made its own `checkpoint_raw is None` check unreachable.
    if checkpoint_raw is None or not checkpoint_raw.get('Items'):
        return None

    from common.ddb_service.types_ import ScanOutput
    named_ = ScanOutput(**checkpoint_raw)
    return CheckPoint(**ddb_service.deserialize(named_['Items'][0]))
|
|
|
|
|
|
def get_base_inference_param_s3_key(_type: str, request_id: str) -> str:
    """Return the S3 key prefix under which a job's request params are stored."""
    return '/'.join((_type, 'infer_v2', request_id))
|
|
|
|
|
|
# currently only two scheduling ways: by endpoint name and by user
def _schedule_inference_endpoint(endpoint_name, user_id):
    """Pick the SageMaker endpoint that should serve a new inference job.

    With ``endpoint_name``: return that endpoint, requiring it to be
    InService. With only ``user_id``: pick a random InService, non-deleted
    endpoint the user's roles permit. Raises Exception when nothing
    qualifies or when neither selector is supplied.
    """
    # fixme: endpoint is not indexed by name, and this is very expensive query
    # fixme: we can either add index for endpoint name or make endpoint as the partition key
    if endpoint_name:
        matches = ddb_service.scan(sagemaker_endpoint_table, filters={
            'endpoint_name': endpoint_name
        })
        # Check for an empty scan BEFORE indexing [0]: the original indexed
        # first and raised IndexError instead of the intended message.
        if not matches:
            raise Exception(f'sagemaker endpoint with name {endpoint_name} is not found')

        sagemaker_endpoint_raw = ddb_service.deserialize(matches[0])
        if sagemaker_endpoint_raw is None or len(sagemaker_endpoint_raw) == 0:
            raise Exception(f'sagemaker endpoint with name {endpoint_name} is not found')

        if sagemaker_endpoint_raw['endpoint_status'] != 'InService':
            raise Exception(f'sagemaker endpoint is not ready with status: {sagemaker_endpoint_raw["endpoint_status"]}')
        return EndpointDeploymentJob(**sagemaker_endpoint_raw)

    if user_id:
        sagemaker_endpoint_raws = ddb_service.scan(sagemaker_endpoint_table, filters=None)
        user_roles = get_user_roles(ddb_service, user_table, user_id)
        available_endpoints = []
        for row in sagemaker_endpoint_raws:
            endpoint = EndpointDeploymentJob(**ddb_service.deserialize(row))
            # skip endpoints that are not serving or already torn down
            if endpoint.endpoint_status != 'InService' or endpoint.status == 'deleted':
                continue

            if check_user_permissions(endpoint.owner_group_or_role, user_roles, user_id):
                available_endpoints.append(endpoint)

        if len(available_endpoints) == 0:
            raise Exception(f'no available Endpoints for user "{user_id}"')

        # naive load spreading: any permitted InService endpoint will do
        return random.choice(available_endpoints)

    # Neither selector supplied: fail loudly instead of silently returning
    # None (the original fell off the end, deferring the failure to an
    # AttributeError at the call site).
    raise Exception('either sagemaker_endpoint_name or user_id is required to schedule an endpoint')
|