227 lines
9.5 KiB
Python
227 lines
9.5 KiB
Python
import dataclasses
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
from datetime import datetime
|
|
from typing import List, Any, Optional
|
|
|
|
from common.const import PERMISSION_INFERENCE_ALL, PERMISSION_INFERENCE_CREATE
|
|
from common.ddb_service.client import DynamoDbUtilsService
|
|
from common.response import bad_request, created
|
|
from common.util import generate_presign_url
|
|
from start_inference_job import start_inference_job
|
|
from libs.data_types import CheckPoint, CheckPointStatus
|
|
from libs.data_types import InferenceJob, EndpointDeploymentJob
|
|
from libs.enums import EndpointStatus
|
|
from libs.utils import get_user_roles, check_user_permissions, permissions_check, response_error, log_execution_time
|
|
|
|
# Resource names, read from environment variables at import time
# (each is None when the variable is unset).
bucket_name = os.getenv('S3_BUCKET_NAME')
checkpoint_table = os.getenv('CHECKPOINT_TABLE')
sagemaker_endpoint_table = os.getenv('DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME')
inference_table_name = os.getenv('INFERENCE_JOB_TABLE')
user_table = os.getenv('MULTI_USER_TABLE')

logger = logging.getLogger(__name__)
# An unset or empty LOG_LEVEL falls back to ERROR.
logger.setLevel(os.getenv('LOG_LEVEL') or logging.ERROR)

# Shared DynamoDB helper used by every function in this module.
ddb_service = DynamoDbUtilsService(logger=logger)
|
|
|
|
|
|
@dataclasses.dataclass
class CreateInferenceEvent:
    """Request body of ``POST /inferences``.

    The handler builds this with ``CreateInferenceEvent(**body)``, so each
    field corresponds to a top-level JSON key. Keys that are not fields make
    the generated ``__init__`` raise ``TypeError``.
    """

    task_type: str
    # {checkpoint_type: [checkpoint names]} -- the same thing as "checkpoint" if confused
    models: dict[str, List[str]]
    sagemaker_endpoint_name: Optional[str] = ""
    inference_type: Optional[str] = None
    filters: Optional[dict[str, Any]] = None
    # todo user_id is not used in this lambda, but we need to keep it for the compatibility with the old code
    user_id: Optional[str] = ""
    # Optional inline API params (a JSON string); when absent the caller uploads them to S3.
    payload_string: Optional[str] = None
|
|
|
|
|
|
# POST /inferences
def handler(raw_event, context):
    """Create an inference job from a ``POST /inferences`` request.

    Steps:
      1. Parse ``raw_event['body']`` into a :class:`CreateInferenceEvent`.
      2. Validate ``payload_string`` (when given), permissions and the task type.
      3. Pick a SageMaker endpoint via ``_schedule_inference_endpoint``.
      4. For txt2img/img2img, resolve every requested checkpoint.
      5. Persist an :class:`InferenceJob` through ``ddb_service.put_items``.

    When ``payload_string`` is non-empty the job is started immediately with
    ``start_inference_job``. Otherwise a presigned S3 URL is returned, and the
    caller uploads the API params to it.

    Args:
        raw_event: the incoming event; ``raw_event['body']`` must be a JSON string.
        context: the invocation context; ``context.aws_request_id`` is the job id.

    Returns:
        The response from ``bad_request``, ``created`` or ``start_inference_job``.
        Any exception raised along the way is converted by ``response_error``.
    """
    try:
        logger.info(json.dumps(raw_event, default=str))
        request_id = context.aws_request_id

        # Parse the body once; it is used both for logging and to build the event.
        body = json.loads(raw_event['body'])
        logger.info(json.dumps(body))
        event = CreateInferenceEvent(**body)

        if event.payload_string:
            try:
                json.loads(event.payload_string)
            except json.JSONDecodeError:
                return bad_request(message='payload_string must be valid json string')

        username = permissions_check(raw_event, [PERMISSION_INFERENCE_ALL, PERMISSION_INFERENCE_CREATE])

        _type = event.task_type
        extra_generate_types = ['extra-single-image', 'extra-batch-images', 'rembg']
        simple_generate_types = ['txt2img', 'img2img']

        if _type not in simple_generate_types and _type not in extra_generate_types:
            return bad_request(
                message=f'task type {event.task_type} should be in {extra_generate_types} or {simple_generate_types}'
            )

        # check the endpoint table for endpoint status and existence
        inference_endpoint = _schedule_inference_endpoint(event.sagemaker_endpoint_name, event.inference_type,
                                                          username)
        endpoint_name = inference_endpoint.endpoint_name
        endpoint_id = inference_endpoint.EndpointDeploymentJobId
        instance_type = inference_endpoint.instance_type

        # generate param s3 location for upload
        param_s3_key = f'{get_base_inference_param_s3_key(_type, request_id)}/api_param.json'
        s3_location = f's3://{bucket_name}/{param_s3_key}'

        # Without an inline payload the caller must upload the params itself, so
        # hand out a presigned URL. Use the same truthiness test as the
        # start-now branch at the end: an empty payload_string must not leave
        # the caller with neither a started job nor an upload URL.
        presign_url = None
        if not event.payload_string:
            presign_url = generate_presign_url(bucket_name, param_s3_key)

        # One timestamp so createTime and startTime agree.
        now = str(datetime.now())
        inference_job = InferenceJob(
            InferenceJobId=request_id,
            createTime=now,
            startTime=now,
            status='created',
            taskType=_type,
            inference_type=event.inference_type,
            owner_group_or_role=[username],
            payload_string=event.payload_string,
            params={
                'input_body_s3': s3_location,
                'input_body_presign_url': presign_url,
                'sagemaker_inference_endpoint_id': endpoint_id,
                'sagemaker_inference_instance_type': instance_type,
                'sagemaker_inference_endpoint_name': endpoint_name,
            },
        )

        resp = {
            'inference': {
                'id': request_id,
                'type': _type,
                'api_params_s3_location': s3_location,
                'api_params_s3_upload_url': presign_url,
            }
        }

        if _type in simple_generate_types:
            # check if model (checkpoint) path(s) exist; return an error if not
            ckpts, missing = _resolve_checkpoints(event.models)
            if missing:
                message = [f'checkpoint with name {name}, type {ckpt_type} is not found'
                           for ckpt_type, name in missing]
                return bad_request(message=' '.join(message))

            inference_job.params['used_models'] = _group_used_models(ckpts)
            resp['inference']['models'] = [{'id': ckpt.id, 'name': ckpt.checkpoint_names, 'type': ckpt.checkpoint_type}
                                           for ckpt in ckpts]

        # create inference job with param location in ddb, status set to Created
        ddb_service.put_items(inference_table_name, entries=inference_job.__dict__)

        if event.payload_string:
            return start_inference_job(inference_job, username)

        return created(data=resp)
    except Exception as e:
        return response_error(e)


def _resolve_checkpoints(models):
    """Look up each requested checkpoint.

    Returns ``(found, missing)``. ``found`` lists the checkpoint objects, and
    ``missing`` lists ``(ckpt_type, name)`` pairs in request order.
    """
    found = []
    missing = []
    for ckpt_type, names in models.items():
        for name in names:
            ckpt = _get_checkpoint_by_name(name, ckpt_type)
            # todo: need check if user has permission for the model
            if ckpt is None:
                missing.append((ckpt_type, name))
            else:
                found.append(ckpt)
    return found, missing


def _group_used_models(ckpts):
    """Group resolved checkpoints by ``checkpoint_type`` into the job's ``used_models`` param."""
    used_models = {}
    for ckpt in ckpts:
        used_models.setdefault(ckpt.checkpoint_type, []).append(
            {
                'id': ckpt.id,
                'model_name': ckpt.checkpoint_names[0],
                's3': ckpt.s3_location,
                'type': ckpt.checkpoint_type,
            }
        )
    return used_models
|
|
|
|
|
|
# fixme: this is a very expensive function (full-table scan)
@log_execution_time
def _get_checkpoint_by_name(ckpt_name, model_type, status='Active') -> Optional[CheckPoint]:
    """Find the first checkpoint of ``model_type`` whose names contain ``ckpt_name``.

    For ``VAE`` with the pseudo names ``'None'``/``'Automatic'``, a synthetic
    placeholder checkpoint is returned without touching DynamoDB.

    Args:
        ckpt_name: a name expected in the item's ``checkpoint_names`` attribute.
        model_type: the ``checkpoint_type`` to match.
        status: the ``checkpoint_status`` to match (defaults to ``'Active'``).

    Returns:
        The first matching ``CheckPoint``, or ``None`` when no item matches.
    """
    if model_type == 'VAE' and ckpt_name in ['None', 'Automatic']:
        return CheckPoint(
            id=model_type,
            checkpoint_names=[ckpt_name],
            s3_location='None',
            checkpoint_type=model_type,
            checkpoint_status=CheckPointStatus.Active,
            timestamp=0,
        )

    scan_kwargs = {
        'TableName': checkpoint_table,
        'FilterExpression': 'contains(checkpoint_names, :checkpointName) and checkpoint_type=:model_type and checkpoint_status=:checkpoint_status',
        'ExpressionAttributeValues': {
            ':checkpointName': {'S': ckpt_name},
            ':model_type': {'S': model_type},
            ':checkpoint_status': {'S': status}
        },
    }

    # A Scan reads at most 1 MB per call, and FilterExpression is applied per
    # page. A page can therefore be empty while later pages still hold a match,
    # so keep following LastEvaluatedKey until a match is found or the table is
    # exhausted.
    while True:
        page = ddb_service.client.scan(**scan_kwargs)
        if not page:
            return None

        items = page.get('Items') or []
        if items:
            return CheckPoint(**ddb_service.deserialize(items[0]))

        last_key = page.get('LastEvaluatedKey')
        if not last_key:
            return None
        scan_kwargs['ExclusiveStartKey'] = last_key
|
|
|
|
|
|
def get_base_inference_param_s3_key(_type: str, request_id: str) -> str:
    """Build the S3 key prefix for an inference job's API parameters.

    The layout is ``<task type>/infer_v2/<request id>``. Callers append the
    file name themselves.
    """
    return '{}/infer_v2/{}'.format(_type, request_id)
|
|
|
|
|
|
# currently only two scheduling ways: by endpoint name and by user
@log_execution_time
def _schedule_inference_endpoint(endpoint_name, inference_type, user_id):
    """Choose the SageMaker endpoint that will run an inference job.

    * When ``endpoint_name`` is given, the endpoint with that name is used.
      Its ``endpoint_status`` must be ``'InService'``.
    * Otherwise, when ``user_id`` is given, one endpoint is picked at random.
      Candidates are not ``'deleted'``, are updating or in service, match
      ``inference_type``, and pass ``check_user_permissions`` for the user's
      roles.

    Returns:
        An ``EndpointDeploymentJob``.

    Raises:
        Exception: when no suitable endpoint exists, or when neither
            ``endpoint_name`` nor ``user_id`` is provided.
    """
    # fixme: endpoint is not indexed by name, and this is very expensive query
    # fixme: we can either add index for endpoint name or make endpoint as the partition key
    if endpoint_name:
        # NOTE(review): this path does not check 'deleted' status, inference_type
        # or user permissions, unlike the per-user path below -- confirm intended.
        rows = ddb_service.scan(sagemaker_endpoint_table, filters={
            'endpoint_name': endpoint_name
        })
        # Guard before indexing so an unknown name yields the intended message
        # instead of an IndexError.
        if not rows:
            raise Exception(f'sagemaker endpoint with name {endpoint_name} is not found')

        sagemaker_endpoint_raw = ddb_service.deserialize(rows[0])
        if not sagemaker_endpoint_raw:
            raise Exception(f'sagemaker endpoint with name {endpoint_name} is not found')

        if sagemaker_endpoint_raw['endpoint_status'] != 'InService':
            raise Exception(f'sagemaker endpoint is not ready with status: {sagemaker_endpoint_raw["endpoint_status"]}')
        return EndpointDeploymentJob(**sagemaker_endpoint_raw)

    if user_id:
        sagemaker_endpoint_raws = ddb_service.scan(sagemaker_endpoint_table, filters=None) or []
        user_roles = get_user_roles(ddb_service, user_table, user_id)
        schedulable_statuses = (EndpointStatus.UPDATING.value, EndpointStatus.IN_SERVICE.value)

        available_endpoints = []
        for row in sagemaker_endpoint_raws:
            endpoint = EndpointDeploymentJob(**ddb_service.deserialize(row))
            if endpoint.status == 'deleted':
                continue
            if endpoint.endpoint_status not in schedulable_statuses:
                continue
            if endpoint.endpoint_type != inference_type:
                continue
            if check_user_permissions(endpoint.owner_group_or_role, user_roles, user_id):
                available_endpoints.append(endpoint)

        if not available_endpoints:
            raise Exception(f'no available {inference_type} endpoints for user "{user_id}"')

        # Simple load spreading across all eligible endpoints.
        return random.choice(available_endpoints)

    # Previously fell through and returned None, which crashed the caller with
    # an AttributeError; fail with an explicit message instead.
    raise Exception('cannot schedule inference endpoint: neither endpoint name nor user id was provided')
|