# stable-diffusion-aws-extension/middleware_api/lambda/models/model_api.py

import dataclasses
import datetime
import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional
import boto3
from botocore.exceptions import ClientError
from sagemaker import Predictor
from sagemaker.predictor_async import AsyncPredictor
from libs.common_tools import complete_multipart_upload, split_s3_path, DecimalEncoder
from libs.common_tools import get_base_model_s3_key, get_base_checkpoint_s3_key, \
batch_get_s3_multipart_signed_urls
from common.ddb_service.client import DynamoDbUtilsService
from common.response import ok, bad_request, internal_server_error
from libs.data_types import Model, CreateModelStatus, CheckPoint, CheckPointStatus, MultipartFileReq
from common.util import publish_msg
from libs.utils import get_permissions_by_username, get_user_roles, check_user_permissions
# Environment-driven configuration, resolved once per Lambda cold start.
bucket_name = os.environ.get('S3_BUCKET')  # S3 bucket holding model/checkpoint artifacts
model_table = os.environ.get('DYNAMODB_TABLE')  # DynamoDB table for model-creation jobs
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')  # DynamoDB table for checkpoints
endpoint_name = os.environ.get('SAGEMAKER_ENDPOINT_NAME')  # SageMaker async endpoint used to build models
user_table = os.environ.get('MULTI_USER_TABLE')  # DynamoDB table for users/roles/permissions
success_topic_arn = os.environ.get('SUCCESS_TOPIC_ARN')  # SNS topic SageMaker notifies on success
error_topic_arn = os.environ.get('ERROR_TOPIC_ARN')  # SNS topic SageMaker notifies on failure
user_topic_arn = os.environ.get('USER_TOPIC_ARN')  # SNS topic for user-facing notifications
logger = logging.getLogger(__name__)
# LOG_LEVEL env var wins; default to ERROR when unset or empty
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
ddb_service = DynamoDbUtilsService(logger=logger)
@dataclasses.dataclass
class Event:
    """Request body for POST /model (create_model_api).

    Either ``checkpoint_id`` (reuse an existing Active checkpoint) or
    ``filenames`` (start fresh multipart uploads) must be provided.
    """
    model_type: str
    name: str
    params: dict[str, Any]
    # fix: was `[MultipartFileReq]` — a list literal, not a valid type
    # annotation; the file already uses 3.9+ builtin generics (dict[str, Any])
    filenames: list[MultipartFileReq]
    creator: str
    checkpoint_id: Optional[str] = ""
# POST /model
def create_model_api(raw_event, context):
    """Create a model-creation job record (and, if needed, a new checkpoint).

    The Lambda proxy body must deserialize into ``Event``. Two paths:
      * no ``checkpoint_id``: open S3 multipart uploads for ``filenames`` and
        write an Initial checkpoint entry carrying the upload ids;
      * ``checkpoint_id`` given: reuse that checkpoint, which must exist, be
        Active, and be visible to one of the caller's roles.

    Returns ok() with the new job metadata (plus presigned part URLs on the
    fresh-upload path), bad_request on validation/permission failures, or
    internal_server_error on AWS ClientError.
    """
    logger.info(json.dumps(raw_event))
    # the request id doubles as both the new checkpoint id and the model job id
    request_id = context.aws_request_id
    event = Event(**json.loads(raw_event['body']))
    _type = event.model_type
    try:
        # check if roles has already linked to an endpoint?
        # caller must hold train:all or train:create
        creator_permissions = get_permissions_by_username(ddb_service, user_table, event.creator)
        if 'train' not in creator_permissions \
                or ('all' not in creator_permissions['train'] and 'create' not in creator_permissions['train']):
            return bad_request(message=f'user {event.creator} has not permission to create a train job')
        user_roles = get_user_roles(ddb_service, user_table, event.creator)
        # todo: check if duplicated name and new_model_name only for Completed and Model
        if not event.checkpoint_id and len(event.filenames) == 0:
            return bad_request(message='either checkpoint_id or filenames need to be provided')
        base_key = get_base_model_s3_key(_type, event.name, request_id)
        timestamp = datetime.datetime.now().timestamp()
        multiparts_resp = {}
        if not event.checkpoint_id:
            # fresh-upload path: presign one multipart upload per file
            checkpoint_base_key = get_base_checkpoint_s3_key(_type, event.name, request_id)
            presign_url_map = batch_get_s3_multipart_signed_urls(
                bucket_name=bucket_name,
                base_key=checkpoint_base_key,
                filenames=event.filenames
            )
            filenames_only = []
            for f in event.filenames:
                file = MultipartFileReq(**f)
                filenames_only.append(file.filename)
            checkpoint_params = {'created': str(datetime.datetime.now()), 'multipart_upload': {
            }}
            # record the upload ids so PUT /model can complete the uploads later,
            # and collect the signed part URLs for the response
            for key, val in presign_url_map.items():
                checkpoint_params['multipart_upload'][key] = {
                    'upload_id': val['upload_id'],
                    'bucket': val['bucket'],
                    'key': val['key'],
                }
                multiparts_resp[key] = val['s3_signed_urls']
            checkpoint = CheckPoint(
                id=request_id,
                checkpoint_type=event.model_type,
                s3_location=f's3://{bucket_name}/{get_base_checkpoint_s3_key(_type, event.name, request_id)}',
                checkpoint_names=filenames_only,
                checkpoint_status=CheckPointStatus.Initial,
                params=checkpoint_params,
                timestamp=timestamp,
                allowed_roles_or_users=user_roles,
            )
            ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)
            checkpoint_id = checkpoint.id
        else:
            # reuse path: checkpoint must exist, be Active, and be usable by the caller
            raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={
                'id': event.checkpoint_id,
            })
            if raw_checkpoint is None:
                return bad_request(message=f'create model ckpt with id {event.checkpoint_id} is not found')
            checkpoint = CheckPoint(**raw_checkpoint)
            if checkpoint.checkpoint_status != CheckPointStatus.Active:
                return bad_request(message=f'checkpoint with id ({checkpoint.id}) is not Active to use')
            checkpoint_id = checkpoint.id
            if checkpoint.allowed_roles_or_users:
                # '*' grants access to everyone; otherwise require a role overlap
                allowed_to_use = False
                for role in checkpoint.allowed_roles_or_users:
                    if role in user_roles or '*' == role:
                        allowed_to_use = True
                        break
                if not allowed_to_use:
                    return bad_request(
                        message=f'checkpoint with id ({checkpoint.id}) is not allowed to use by user {event.creator}')
        model_job = Model(
            id=request_id,
            name=event.name,
            output_s3_location=f's3://{bucket_name}/{base_key}/output',
            checkpoint_id=checkpoint_id,
            model_type=_type,
            job_status=CreateModelStatus.Initial,
            params=event.params,
            timestamp=timestamp,
            allowed_roles_or_users=[event.creator]
        )
        ddb_service.put_items(table=model_table, entries=model_job.__dict__)
    except ClientError as e:
        logger.error(e)
        return internal_server_error(message=str(e))
    data = {
        'job': {
            'id': model_job.id,
            'status': model_job.job_status.value,
            's3_base': checkpoint.s3_location,
            'model_type': model_job.model_type,
            'params': model_job.params  # not safe if not json serializable type
        },
        's3PresignUrl': multiparts_resp  # empty dict when reusing an existing checkpoint
    }
    return ok(data=data)
# GET /models
def list_all_models_api(event, context):
    """List model-creation jobs visible to the requesting user.

    Optional query parameters ``types`` and ``status`` filter the scan by
    model_type / job_status. A job is visible when the requester's roles (or
    username) match its allowed_roles_or_users; legacy entries with no ACL are
    visible only to superusers (user:all permission).

    Returns ok({'models': [...]}) on success, bad_request on permission
    failure or any processing error.
    """
    _filter = {}
    parameters = event['queryStringParameters']
    if parameters:
        if 'types' in parameters and len(parameters['types']) > 0:
            _filter['model_type'] = parameters['types']
        if 'status' in parameters and len(parameters['status']) > 0:
            _filter['job_status'] = parameters['status']
    try:
        requestor_name = event['requestContext']['authorizer']['username']
        requestor_permissions = get_permissions_by_username(ddb_service, user_table, requestor_name)
        requestor_roles = get_user_roles(ddb_service=ddb_service, user_table_name=user_table, username=requestor_name)
        # fix: authorize BEFORE scanning the table — previously the scan ran
        # first and an empty result short-circuited past the permission check,
        # so unauthorized callers received an ok() response
        if 'train' not in requestor_permissions or \
                ('all' not in requestor_permissions['train'] and 'list' not in requestor_permissions['train']):
            return bad_request(message='user has no permission to train')
        resp = ddb_service.scan(table=model_table, filters=_filter)
        if resp is None or len(resp) == 0:
            return ok(data={'models': []})
        models = []
        for r in resp:
            model = Model(**(ddb_service.deserialize(r)))
            model_dto = {
                'id': model.id,
                'model_name': model.name,
                'created': model.timestamp,
                'params': model.params,
                'status': model.job_status.value,
                'output_s3_location': model.output_s3_location
            }
            if model.allowed_roles_or_users and check_user_permissions(model.allowed_roles_or_users, requestor_roles,
                                                                       requestor_name):
                models.append(model_dto)
            elif not model.allowed_roles_or_users and \
                    'user' in requestor_permissions and \
                    'all' in requestor_permissions['user']:
                # superuser can view the legacy data
                models.append(model_dto)
        return ok(data={'models': models}, decimal=True)
    except Exception as e:
        logger.error(e)
        return bad_request(message=str(e))
@dataclasses.dataclass
class PutModelEvent:
    """Request body for PUT /model/{id} (update_model_job_api)."""
    status: str  # target CreateModelStatus name, e.g. 'Creating'
    multi_parts_tags: Dict[str, Any]  # per-file part ETags used to complete multipart uploads
# PUT /model
def update_model_job_api(raw_event, context):
    """Update a model-creation job's status.

    If the job's checkpoint is still Initial, first complete its S3 multipart
    uploads (using the ETags from the request body) and mark it Active. Then
    run the state transition via _exec (which may kick off the SageMaker
    async inference for 'Creating') and persist the new job status.

    Returns ok() with the _exec result, bad_request when the job or its
    checkpoint is missing, or internal_server_error on AWS ClientError.
    """
    logger.info(json.dumps(raw_event))
    event = PutModelEvent(**json.loads(raw_event['body']))
    model_id = raw_event['pathParameters']['id']
    try:
        raw_model_job = ddb_service.get_item(table=model_table, key_values={'id': model_id})
        if raw_model_job is None:
            return bad_request(message=f'create model with id {model_id} is not found')
        model_job = Model(**raw_model_job)
        raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={
            'id': model_job.checkpoint_id,
        })
        if raw_checkpoint is None:
            # fix: report the missing checkpoint's id, not the model id
            return bad_request(message=f'create model ckpt with id {model_job.checkpoint_id} is not found')
        ckpt = CheckPoint(**raw_checkpoint)
        # fix: compare against the enum class directly instead of reaching the
        # member through the instance attribute (ckpt.checkpoint_status.Initial)
        if ckpt.checkpoint_status == CheckPointStatus.Initial:
            # finish the pending S3 multipart uploads, then activate the checkpoint
            complete_multipart_upload(ckpt, event.multi_parts_tags)
            ddb_service.update_item(
                table=checkpoint_table,
                key={'id': ckpt.id},
                field_name='checkpoint_status',
                value=CheckPointStatus.Active.value
            )
        resp = _exec(model_job, CreateModelStatus[event.status])
        ddb_service.update_item(
            table=model_table,
            key={'id': model_job.id},
            field_name='job_status',
            value=event.status
        )
        return ok(data=resp)
    except ClientError as e:
        logger.error(e)
        return internal_server_error(message=str(e))
# SNS callback
def process_result(event, context):
    """Handle SageMaker async-inference SNS notifications for create-model jobs.

    For each SNS record: look up the model job by the notification's
    inferenceId (the jobs are submitted with inference_id == job id). On a
    success-topic record, fetch the inference output from S3, store it in the
    job params and mark the job Complete (or Fail if the output payload
    reports a non-200 status). On an error-topic record, mark the job Fail.
    Users are notified via SNS in every terminal case.
    """
    records = event['Records']
    for record in records:
        msg_str = record['Sns']['Message']
        print(msg_str)
        msg = json.loads(msg_str)
        inference_id = msg['inferenceId']
        model_job_raw = ddb_service.get_item(table=model_table, key_values={
            'id': inference_id
        })
        if model_job_raw is None:
            # NOTE(review): this returns mid-batch, silently skipping any
            # remaining SNS records; statusCode is a string '500' here but an
            # int 200 in the final return — confirm callers tolerate both
            return {
                'statusCode': '500',
                'error': f'id with {inference_id} not found'
            }
        model = Model(**model_job_raw)
        if record['Sns']['TopicArn'] == success_topic_arn:
            resp_location = msg['responseParameters']['outputLocation']
            bucket, key = split_s3_path(resp_location)
            content = get_object(bucket=bucket, key=key)
            if content['statusCode'] != 200:
                # the inference container reported a failure in its payload
                ddb_service.update_item(
                    table=model_table,
                    key={'id': model.id},
                    field_name='job_status',
                    value=CreateModelStatus.Fail.value
                )
                publish_msg(
                    topic_arn=user_topic_arn,
                    subject=f'Create Model Job {model.name}: {model.id} failed',
                    msg='to be done'
                )  # todo: find out msg
                # NOTE(review): mid-batch return — remaining records are skipped
                return
            # copy the container's response fields into the job params
            # (NOTE: the loop variable `key` shadows the S3 key from split_s3_path)
            msgs = content['message']
            model.params['resp'] = {}
            for key, val in msgs.items():
                model.params['resp'][key] = val
            ddb_service.update_item(
                table=model_table,
                key={'id': inference_id},
                field_name='job_status',
                value=CreateModelStatus.Complete.value
            )
            # presumably model.params aliases model_job_raw['params'] (Model is
            # built from this dict), so the 'resp' entries written above are
            # persisted here too — TODO confirm against the Model definition
            params = model_job_raw['params']
            params['resp']['s3_output_location'] = f'{bucket_name}/{model.model_type}/{model.name}.tar'
            ddb_service.update_item(
                table=model_table,
                key={'id': inference_id},
                field_name='params',
                value=params
            )
            publish_msg(
                topic_arn=user_topic_arn,
                subject=f'Create Model Job {model.name}: {model.id} success',
                msg=f'model {model.name}: {model.id} is ready to use'
            )  # todo: find out msg
        if record['Sns']['TopicArn'] == error_topic_arn:
            ddb_service.update_item(
                table=model_table,
                key={'id': inference_id},
                field_name='job_status',
                value=CreateModelStatus.Fail.value
            )
            publish_msg(
                topic_arn=user_topic_arn,
                subject=f'Create Model Job {model.name}: {model.id} failed',
                msg='to be done'
            )  # todo: find out msg
    return {
        'statusCode': 200,
        'msg': f'finished events {event}'
    }
def get_object(bucket: str, key: str):
    """Download an S3 object and parse its body as JSON."""
    response = boto3.client('s3').get_object(Bucket=bucket, Key=key)
    return json.load(response['Body'])
def _exec(model_job: Model, action: CreateModelStatus):
    """Apply a status transition to a model-creation job.

    While a job is Creating, only Fail/Complete transitions are allowed.
    A 'Creating' action activates the job's checkpoint and submits the
    SageMaker async inference; 'Initial' is rejected; anything else is
    not implemented yet.

    Returns a response dict (from create_sagemaker_inference, or an error
    dict when the checkpoint is missing); raises on disallowed transitions.
    """
    # fix: the original guard used `or` — (action != Fail or action != Complete)
    # is always True, so even the intended Fail/Complete updates were rejected
    if model_job.job_status == CreateModelStatus.Creating and \
            action not in (CreateModelStatus.Fail, CreateModelStatus.Complete):
        raise Exception('model creation job is currently under progress, so cannot be updated')
    if action == CreateModelStatus.Creating:
        model_job.job_status = action
        raw_chkpt = ddb_service.get_item(table=checkpoint_table, key_values={'id': model_job.checkpoint_id})
        if raw_chkpt is None:
            return {
                'statusCode': 200,
                'error': f'model related checkpoint with id {model_job.checkpoint_id} is not found'
            }
        checkpoint = CheckPoint(**raw_chkpt)
        # mark the checkpoint Active before kicking off the build
        checkpoint.checkpoint_status = CheckPointStatus.Active
        ddb_service.update_item(
            table=checkpoint_table,
            key={'id': checkpoint.id},
            field_name='checkpoint_status',
            value=CheckPointStatus.Active.value
        )
        return create_sagemaker_inference(job=model_job, checkpoint=checkpoint)
    elif action == CreateModelStatus.Initial:
        raise Exception('please create a new model creation job for this,'
                        ' not allowed overwrite old model creation job')
    else:
        # todo: other action (Fail/Complete currently fall through to here)
        # fix: `raise NotImplemented` raised a TypeError (NotImplemented is a
        # constant, not an exception) — use the exception class
        raise NotImplementedError
def create_sagemaker_inference(job: Model, checkpoint: CheckPoint):
    """Submit the model-build task to the SageMaker async endpoint.

    Wraps the job/checkpoint details in the container's expected payload,
    fires an async prediction keyed by the job id, and returns a response
    dict describing the submitted job.
    """
    inner_payload = json.dumps({
        "s3_output_path": job.output_s3_location,  # output object
        "s3_input_path": checkpoint.s3_location,
        "ckpt_names": checkpoint.checkpoint_names,
        "param": job.params,
        "job_id": job.id
    }, cls=DecimalEncoder)
    payload = {
        "task": "db-create-model",  # router
        "param_s3": "",
        "db_create_model_payload": inner_payload,
    }
    from sagemaker.serializers import JSONSerializer
    from sagemaker.deserializers import JSONDeserializer
    async_predictor = AsyncPredictor(Predictor(endpoint_name), name=job.id)
    async_predictor.serializer = JSONSerializer()
    async_predictor.deserializer = JSONDeserializer()
    prediction = async_predictor.predict_async(data=payload, inference_id=job.id)
    return {
        'statusCode': 200,
        'job': {
            'output_path': prediction.output_path,
            'id': job.id,
            'endpointName': endpoint_name,
            'jobStatus': job.job_status.value,
            'jobType': job.model_type
        }
    }