148 lines
5.7 KiB
Python
148 lines
5.7 KiB
Python
import dataclasses
|
|
import datetime
|
|
import json
|
|
import logging
|
|
import os
|
|
from typing import Any, Optional
|
|
|
|
from botocore.exceptions import ClientError
|
|
|
|
from common.ddb_service.client import DynamoDbUtilsService
|
|
from common.response import bad_request, internal_server_error, created
|
|
from libs.common_tools import get_base_model_s3_key, get_base_checkpoint_s3_key, \
|
|
batch_get_s3_multipart_signed_urls
|
|
from libs.data_types import Model, CreateModelStatus, CheckPoint, CheckPointStatus, MultipartFileReq
|
|
from libs.utils import get_permissions_by_username, get_user_roles
|
|
|
|
# Resource names are read from the environment once, at import time.
# Each value is None when the corresponding variable is not set.
bucket_name = os.getenv('S3_BUCKET')
model_table = os.getenv('DYNAMODB_TABLE')
checkpoint_table = os.getenv('CHECKPOINT_TABLE')

user_table = os.getenv('MULTI_USER_TABLE')

logger = logging.getLogger(__name__)
# LOG_LEVEL wins when it is set and non-empty; otherwise only errors are logged.
logger.setLevel(os.getenv('LOG_LEVEL') or logging.ERROR)

# Shared DynamoDB helper used by the handler for every table access.
ddb_service = DynamoDbUtilsService(logger=logger)
|
|
|
|
|
|
@dataclasses.dataclass
class Event:
    """Request payload of ``POST /model``, built from the decoded JSON body.

    The handler constructs this with ``Event(**json.loads(body))``, so the
    field names are the JSON keys the client must send.
    """

    # Kind of model to create; also used as the checkpoint type.
    model_type: str
    # Name of the model; used when building the S3 keys.
    name: str
    # Free-form job parameters, stored on the model record as-is.
    params: dict[str, Any]
    # Files to upload for a new checkpoint. Entries arrive as plain dicts and
    # are unpacked into ``MultipartFileReq(**entry)`` by the handler.
    filenames: list[dict[str, Any]]
    # Username of the requester; permissions and roles are looked up by it.
    creator: str
    # Id of an existing checkpoint to reuse; empty means "create a new one".
    checkpoint_id: Optional[str] = ""
|
|
|
|
|
|
# POST /model
def handler(raw_event, context):
    """Create a model job record, creating or reusing a checkpoint for it.

    The JSON in ``raw_event['body']`` is decoded into an ``Event``. When no
    ``checkpoint_id`` is supplied, a new checkpoint record is written and
    multipart pre-signed upload URLs are generated for every requested file.
    Otherwise the referenced checkpoint is loaded and must be ``Active`` and
    usable by one of the creator's roles. In both cases a model record keyed
    by the Lambda request id is written to the model table.

    Args:
        raw_event: Invocation event; must carry the JSON request in ``body``.
        context: Lambda context; ``aws_request_id`` becomes the model id (and
            the id of a newly created checkpoint).

    Returns:
        ``created(...)`` with the job summary and pre-signed URLs on success,
        ``bad_request(...)`` for invalid input or missing permissions, or
        ``internal_server_error(...)`` when a botocore ``ClientError`` occurs.
    """
    logger.info(json.dumps(raw_event))
    request_id = context.aws_request_id

    # A missing/undecodable body or wrong field set is the caller's fault:
    # answer with a 400 instead of letting the invocation crash.
    try:
        event = Event(**json.loads(raw_event['body']))
    except (KeyError, TypeError, json.JSONDecodeError) as e:
        logger.error(e)
        return bad_request(message=f'invalid request body: {e}')

    _type = event.model_type

    try:
        # The creator needs the 'all' or 'create' permission on 'train'.
        creator_permissions = get_permissions_by_username(ddb_service, user_table, event.creator)
        if 'train' not in creator_permissions \
                or ('all' not in creator_permissions['train'] and 'create' not in creator_permissions['train']):
            return bad_request(message=f'user {event.creator} has not permission to create a train job')

        user_roles = get_user_roles(ddb_service, user_table, event.creator)

        # todo: check if duplicated name and new_model_name only for Completed and Model
        # Truthiness also covers a null `filenames`, which len() would reject.
        if not event.checkpoint_id and not event.filenames:
            return bad_request(message='either checkpoint_id or filenames need to be provided')

        base_key = get_base_model_s3_key(_type, event.name, request_id)
        timestamp = datetime.datetime.now().timestamp()
        multiparts_resp = {}

        if not event.checkpoint_id:
            # No checkpoint given: register a new one and hand back upload URLs.
            checkpoint_base_key = get_base_checkpoint_s3_key(_type, event.name, request_id)
            presign_url_map = batch_get_s3_multipart_signed_urls(
                bucket_name=bucket_name,
                base_key=checkpoint_base_key,
                filenames=event.filenames
            )
            filenames_only = [MultipartFileReq(**f).filename for f in event.filenames]

            checkpoint_params = {'created': str(datetime.datetime.now()), 'multipart_upload': {}}
            for key, val in presign_url_map.items():
                # Upload bookkeeping is stored on the checkpoint; the signed
                # URLs themselves are only returned to the caller.
                checkpoint_params['multipart_upload'][key] = {
                    'upload_id': val['upload_id'],
                    'bucket': val['bucket'],
                    'key': val['key'],
                }
                multiparts_resp[key] = val['s3_signed_urls']

            checkpoint = CheckPoint(
                id=request_id,
                checkpoint_type=event.model_type,
                # Reuse the key computed above rather than rebuilding it.
                s3_location=f's3://{bucket_name}/{checkpoint_base_key}',
                checkpoint_names=filenames_only,
                checkpoint_status=CheckPointStatus.Initial,
                params=checkpoint_params,
                timestamp=timestamp,
                allowed_roles_or_users=user_roles,
            )
            ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)
            checkpoint_id = checkpoint.id
        else:
            raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={
                'id': event.checkpoint_id,
            })
            # Treat an empty result like a missing one so that an empty item
            # cannot reach CheckPoint(**...) below.
            if not raw_checkpoint:
                return bad_request(message=f'create model ckpt with id {event.checkpoint_id} is not found')

            checkpoint = CheckPoint(**raw_checkpoint)
            if checkpoint.checkpoint_status != CheckPointStatus.Active:
                return bad_request(message=f'checkpoint with id ({checkpoint.id}) is not Active to use')
            checkpoint_id = checkpoint.id

            # An empty allow-list means the checkpoint is unrestricted;
            # '*' in the list grants access to everyone.
            if checkpoint.allowed_roles_or_users:
                allowed_to_use = any(
                    role in user_roles or '*' == role
                    for role in checkpoint.allowed_roles_or_users
                )
                if not allowed_to_use:
                    return bad_request(
                        message=f'checkpoint with id ({checkpoint.id}) is not allowed to use by user {event.creator}')

        model_job = Model(
            id=request_id,
            name=event.name,
            output_s3_location=f's3://{bucket_name}/{base_key}/output',
            checkpoint_id=checkpoint_id,
            model_type=_type,
            job_status=CreateModelStatus.Initial,
            params=event.params,
            timestamp=timestamp,
            allowed_roles_or_users=[event.creator]
        )
        ddb_service.put_items(table=model_table, entries=model_job.__dict__)

    except ClientError as e:
        logger.error(e)
        return internal_server_error(message=str(e))

    data = {
        'job': {
            'id': model_job.id,
            'status': model_job.job_status.value,
            's3_base': checkpoint.s3_location,
            'model_type': model_job.model_type,
            'params': model_job.params  # not safe if not json serializable type
        },
        's3PresignUrl': multiparts_resp
    }

    return created(data=data)
|