From 9a5a80fa8accc91bc4faedd5fc322064eccaf15e Mon Sep 17 00:00:00 2001 From: Jingyi Date: Tue, 19 Dec 2023 22:27:30 +0800 Subject: [PATCH 1/4] update create user api --- .../cloud_api_manager/api_manager.py | 6 +- aws_extension/sagemaker_ui_tab.py | 4 +- .../src/sd-users/multi-users-stack.ts | 4 +- .../src/sd-users/user-upsert-api.ts | 74 +++++++++++++++---- .../lambda/multi_users/multi_users_api.py | 57 ++++++-------- 5 files changed, 89 insertions(+), 56 deletions(-) diff --git a/aws_extension/cloud_api_manager/api_manager.py b/aws_extension/cloud_api_manager/api_manager.py index 806bbb3f..1aeb3276 100644 --- a/aws_extension/cloud_api_manager/api_manager.py +++ b/aws_extension/cloud_api_manager/api_manager.py @@ -229,14 +229,14 @@ class CloudApiManager: if initial: cloud_auth_manager.refresh() - raw_resp = requests.post(f'{cloud_auth_manager.api_url}user', + raw_resp = requests.post(f'{cloud_auth_manager.api_url}users', json=payload, headers=self._get_headers_by_user(user_token) ) raw_resp.raise_for_status() resp = raw_resp.json() - if resp['statusCode'] != 200: - raise Exception(resp['errMsg']) + if raw_resp.status_code != 200: + raise Exception(resp['message']) cloud_auth_manager.update_gradio_auth() return True diff --git a/aws_extension/sagemaker_ui_tab.py b/aws_extension/sagemaker_ui_tab.py index b5b53881..8cb8b3a5 100644 --- a/aws_extension/sagemaker_ui_tab.py +++ b/aws_extension/sagemaker_ui_tab.py @@ -335,7 +335,7 @@ def role_settings_tab(): if resp: return f'Role upsert complete "{role_name}"' except Exception as e: - return f'User upsert failed: {e}' + return f'Role upsert failed: {e}' upsert_role_button.click(fn=upsert_role, inputs=[rolename_textbox, permissions_dropdown], @@ -892,7 +892,7 @@ def update_connect_config(api_url, api_token, username=None, password=None, init initial=initial, user_token=username): return 'Initial Setup Failed' except Exception as e: - return f'User upsert failed: {e}' + return f'Initial Setup failed: {e}' return "Setting updated" diff --git a/infrastructure/src/sd-users/multi-users-stack.ts b/infrastructure/src/sd-users/multi-users-stack.ts index b6256e6d..6b6e0899 100644 --- a/infrastructure/src/sd-users/multi-users-stack.ts +++ b/infrastructure/src/sd-users/multi-users-stack.ts @@ -48,12 +48,12 @@ export class MultiUsersStack extends NestedStack { authorizer: props.authorizer, }); - new UserUpsertApi(scope, 'userUpsert', { + new UserUpsertApi(scope, 'CreateUser', { commonLayer: props.commonLayer, httpMethod: 'POST', multiUserTable: props.multiUserTable, passwordKey: props.passwordKeyAlias, - router: props.routers.user, + router: props.routers.users, srcRoot: this.srcRoot, authorizer: props.authorizer, }); diff --git a/infrastructure/src/sd-users/user-upsert-api.ts b/infrastructure/src/sd-users/user-upsert-api.ts index fa5fe32a..d7ea1621 100644 --- a/infrastructure/src/sd-users/user-upsert-api.ts +++ b/infrastructure/src/sd-users/user-upsert-api.ts @@ -11,6 +11,7 @@ import { MethodOptions } from 'aws-cdk-lib/aws-apigateway/lib/method'; import { Effect } from 'aws-cdk-lib/aws-iam'; import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda'; import { Construct } from 'constructs'; +import {JsonSchemaType, JsonSchemaVersion, Model, RequestValidator} from "aws-cdk-lib/aws-apigateway"; export interface UserUpsertApiProps { router: aws_apigateway.Resource; @@ -95,7 +96,6 @@ export class UserUpsertApi { private upsertUserApi() { const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, { - functionName: `${this.baseId}-upsert`, entry: `${this.src}/multi_users`, architecture: Architecture.X86_64, runtime: Runtime.PYTHON_3_9, @@ -110,28 +110,72 @@ export class UserUpsertApi { }, layers: [this.layer], }); + + const requestModel = new Model(this.scope, `${this.baseId}-model`,{ + restApi: this.router.api, + modelName: this.baseId, + description: `${this.baseId} Request Model`, + schema: { + schema: JsonSchemaVersion.DRAFT4, + title: this.baseId, + type: JsonSchemaType.OBJECT, + properties: { + username: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + password: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + creator: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + initial: { + type: JsonSchemaType.BOOLEAN, + default: false, + }, + roles: { + type: JsonSchemaType.ARRAY, + items: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + minItems: 1, + maxItems: 20, + }, + }, + required: [ + 'username', + 'creator', + ], + }, + contentType: 'application/json', + }); + + const requestValidator = new RequestValidator( + this.scope, + `${this.baseId}-validator`, + { + restApi: this.router.api, + requestValidatorName: this.baseId, + validateRequestBody: true, + }); + const upsertUserIntegration = new apigw.LambdaIntegration( lambdaFunction, { - proxy: false, - requestTemplates: { - 'application/json': '{\n' + - ' "body": $input.json("$"),' + - ' "x-auth": {\n' + - ' "username": "$context.authorizer.username",\n' + - ' "role": "$context.authorizer.role"\n' + - ' }\n' + - '}', - }, - integrationResponses: [{ statusCode: '200' }], + proxy: true, }, ); this.router.addMethod(this.httpMethod, upsertUserIntegration, { apiKeyRequired: true, authorizer: this.authorizer, - methodResponses: [{ - statusCode: '200', - }], + requestValidator, + requestModels: { + 'application/json': requestModel, + } }); } } diff --git a/middleware_api/lambda/multi_users/multi_users_api.py b/middleware_api/lambda/multi_users/multi_users_api.py index df43bb81..9e44690d 100644 --- a/middleware_api/lambda/multi_users/multi_users_api.py +++ b/middleware_api/lambda/multi_users/multi_users_api.py @@ -6,7 +6,7 @@ from typing import List, Optional from common.ddb_service.client import DynamoDbUtilsService from _types import User, PARTITION_KEYS, Role, Default_Role -from common.response import ok +from common.response import ok, bad_request from roles_api import upsert_role from utils import KeyEncryptService, check_user_existence, get_permissions_by_username, get_user_by_username @@ -31,8 +31,8 @@ class UpsertUserEvent: # POST /user def upsert_user(raw_event, ctx): - print(raw_event) - event = UpsertUserEvent(**raw_event['body']) + logger.info(json.dumps(raw_event)) + event = UpsertUserEvent(**json.loads(raw_event['body'])) if event.initial: rolenames = [Default_Role] @@ -63,13 +63,16 @@ def upsert_user(raw_event, ctx): aws_request_id: str from_sd_local: bool - resp = upsert_role(role_event, MockContext(aws_request_id='', from_sd_local=True)) + # todo will be remove, not use api + create_role_event = { + 'body': json.dumps(role_event) + } + resp = upsert_role(create_role_event, MockContext(aws_request_id='', from_sd_local=True)) if resp['statusCode'] != 200: return resp - return { - 'statusCode': 200, + data = { 'user': { 'username': event.username, 'roles': [rolenames[0]] @@ -77,6 +80,8 @@ def upsert_user(raw_event, ctx): 'all_roles': rolenames, } + return ok(data=data) + check_permission_resp = _check_action_permission(event.creator, event.username) if check_permission_resp: return check_permission_resp @@ -98,19 +103,13 @@ def upsert_user(raw_event, ctx): resource = permission_parts[0] action = permission_parts[1] if 'all' not in creator_permissions[resource] and action not in creator_permissions[resource]: - return { - 'statusCode': 400, - 'errMsg': f'creator has no permission to assign permission [{permission}] to others' - } + return bad_request(message=f'creator has no permission to assign permission [{permission}] to others') roles_pool.append(role.sort_key) for role in event.roles: if role not in roles_pool: - return { - 'statusCode': 400, - 'errMsg': f'user roles "{role}" not exist' - } + return bad_request(message=f'user roles "{role}" not exist') ddb_service.put_items(user_table, User( kind=PARTITION_KEYS.user, @@ -120,8 +119,7 @@ def upsert_user(raw_event, ctx): creator=event.creator, ).__dict__) - return { - 'statusCode': 200, + data = { 'user': { 'username': event.username, 'roles': event.roles, @@ -129,6 +127,8 @@ def upsert_user(raw_event, ctx): } } + return ok(data=data) + # DELETE /user/{username} def delete_user(event, ctx): @@ -155,10 +155,7 @@ def delete_user(event, ctx): def _check_action_permission(creator_username, target_username): # check if creator exist if check_user_existence(ddb_service=ddb_service, user_table=user_table, username=creator_username): - return { - 'statusCode': 400, - 'errMsg': f'creator {creator_username} not exist' - } + return bad_request(message=f'creator {creator_username} not exist') target_user = get_user_by_username(ddb_service, user_table, target_username) @@ -166,28 +163,20 @@ def _check_action_permission(creator_username, target_username): if 'user' not in creator_permissions or \ ('all' not in creator_permissions['user'] and 'create' not in creator_permissions['user']): - return { - 'statusCode': 400, - 'errMsg': f'creator {creator_username} does not have permission to manage the user' - } + return bad_request(message=f'creator {creator_username} does not have permission to manage the user') # if the creator have no permission (not created by creator), # make sure the creator doesn't change the existed user (created by others) # and only user with 'user:all' can do update any users if target_user and target_user.creator != creator_username and 'all' not in creator_permissions['user']: - return { - 'statusCode': 400, - 'errMsg': f'username {target_user.sort_key} has already exists, ' - f'creator {creator_username} does not have permissions to change it' - } + return bad_request(message=f'username {target_user.sort_key} has already exists, ' + f'creator {creator_username} does not have permissions to change it') if target_user and target_user.creator == creator_username and 'create' not in creator_permissions['user'] and 'all' \ not in creator_permissions['user']: - return { - 'statusCode': 400, - 'errMsg': f'username {target_user.sort_key} has already exists, ' - f'creator {creator_username} does not have permissions to change it' - } + return bad_request( + message=f'username {target_user.sort_key} has already exists, ' + f'creator {creator_username} does not have permissions to change it') return None From b83b6f7f77a0ce3a3b23bc098b9d4a97bc39f5f6 Mon Sep 17 00:00:00 2001 From: Jingyi Date: Tue, 19 Dec 2023 22:56:33 +0800 Subject: [PATCH 2/4] update checkpoint apis --- aws_extension/sagemaker_ui.py | 7 +- .../src/sd-train/chekpoint-create-api.ts | 88 +++++++++++++++---- .../src/sd-train/chekpoint-update-api.ts | 70 ++++++++++----- .../src/sd-train/sd-train-deploy-stack.ts | 8 +- .../lambda/model_and_train/checkpoint_api.py | 53 ++++------- 5 files changed, 140 insertions(+), 86 deletions(-) diff --git a/aws_extension/sagemaker_ui.py b/aws_extension/sagemaker_ui.py index 5bc2cdcb..0f424f39 100644 --- a/aws_extension/sagemaker_ui.py +++ b/aws_extension/sagemaker_ui.py @@ -534,7 +534,7 @@ def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_ api_key = get_variable_from_json('api_token') logger.info(f'!!!!!!api_gateway_url {api_gateway_url}') - url = str(api_gateway_url) + "checkpoint" + url = str(api_gateway_url) + "checkpoints" logger.debug(f"Post request for upload s3 presign url: {url}") @@ -542,7 +542,7 @@ def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_ try: response.raise_for_status() - json_response = response.json() + json_response = response.json()['data'] logger.debug(f"Response json {json_response}") s3_base = json_response["checkpoint"]["s3_location"] checkpoint_id = json_response["checkpoint"]["id"] @@ -577,12 +577,11 @@ def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_ ) payload = { - "checkpoint_id": checkpoint_id, "status": "Active", "multi_parts_tags": {local_tar_path: multiparts_tags} } # Start creating model on cloud. - response = requests.put(url=url, json=payload, headers={'x-api-key': api_key}) + response = requests.put(url=f"{url}/{checkpoint_id}", json=payload, headers={'x-api-key': api_key}) s3_input_path = s3_base logger.debug(response) diff --git a/infrastructure/src/sd-train/chekpoint-create-api.ts b/infrastructure/src/sd-train/chekpoint-create-api.ts index 8a5b8167..4395019d 100644 --- a/infrastructure/src/sd-train/chekpoint-create-api.ts +++ b/infrastructure/src/sd-train/chekpoint-create-api.ts @@ -11,6 +11,7 @@ import { MethodOptions } from 'aws-cdk-lib/aws-apigateway/lib/method'; import { Effect } from 'aws-cdk-lib/aws-iam'; import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda'; import { Construct } from 'constructs'; +import {JsonSchemaType, JsonSchemaVersion, Model, RequestValidator} from "aws-cdk-lib/aws-apigateway"; export interface CreateCheckPointApiProps { @@ -101,8 +102,7 @@ export class CreateCheckPointApi { } private createCheckpointApi() { - const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-create`, { - functionName: `${this.baseId}-create-checkpoint`, + const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, { entry: `${this.src}/model_and_train`, architecture: Architecture.X86_64, runtime: Runtime.PYTHON_3_9, @@ -118,30 +118,80 @@ export class CreateCheckPointApi { }, layers: [this.layer], }); + + const requestModel = new Model(this.scope, `${this.baseId}-model`,{ + restApi: this.router.api, + modelName: this.baseId, + description: `${this.baseId} Request Model`, + schema: { + schema: JsonSchemaVersion.DRAFT4, + title: this.baseId, + type: JsonSchemaType.OBJECT, + properties: { + checkpoint_type: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + filenames: { + type: JsonSchemaType.ARRAY, + items: { + type: JsonSchemaType.OBJECT, + properties: { + filename: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + parts_number: { + type: JsonSchemaType.INTEGER, + minimum: 1, + maximum: 100, + }, + }, + }, + minItems: 1, + maxItems: 20, + }, + params: { + type: JsonSchemaType.OBJECT, + properties: { + message: { + type: JsonSchemaType.STRING, + }, + creator: { + type: JsonSchemaType.STRING, + } + }, + }, + }, + required: [ + 'checkpoint_type', + 'filenames', + ], + }, + contentType: 'application/json', + }); + + const requestValidator = new RequestValidator( + this.scope, + `${this.baseId}-validator`, + { + restApi: this.router.api, + requestValidatorName: this.baseId, + validateRequestBody: true, + }); + const createCheckpointIntegration = new apigw.LambdaIntegration( lambdaFunction, { - proxy: false, - integrationResponses: [{ - statusCode: '200', - responseParameters: { - 'method.response.header.Access-Control-Allow-Headers': "'Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token,X-Amz-User-Agent'", - 'method.response.header.Access-Control-Allow-Methods': "'GET,POST,PUT,OPTIONS'", - 'method.response.header.Access-Control-Allow-Origin': "'*'", - }, - }], + proxy: true, }, ); this.router.addMethod(this.httpMethod, createCheckpointIntegration, { apiKeyRequired: true, - methodResponses: [{ - statusCode: '200', - responseParameters: { - 'method.response.header.Access-Control-Allow-Headers': true, - 'method.response.header.Access-Control-Allow-Methods': true, - 'method.response.header.Access-Control-Allow-Origin': true, - }, - }], + requestValidator, + requestModels: { + 'application/json': requestModel, + } }); } } diff --git a/infrastructure/src/sd-train/chekpoint-update-api.ts b/infrastructure/src/sd-train/chekpoint-update-api.ts index bb0524d9..810450f9 100644 --- a/infrastructure/src/sd-train/chekpoint-update-api.ts +++ b/infrastructure/src/sd-train/chekpoint-update-api.ts @@ -7,10 +7,10 @@ import { aws_lambda, aws_s3, Duration, } from 'aws-cdk-lib'; -import { MethodOptions } from 'aws-cdk-lib/aws-apigateway/lib/method'; import { Effect } from 'aws-cdk-lib/aws-iam'; import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda'; import { Construct } from 'constructs'; +import {JsonSchemaType, JsonSchemaVersion, Model, RequestValidator} from "aws-cdk-lib/aws-apigateway"; export interface UpdateCheckPointApiProps { @@ -93,8 +93,7 @@ export class UpdateCheckPointApi { } private updateCheckpointApi() { - const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-update`, { - functionName: `${this.baseId}-update-checkpoint`, + const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, { entry: `${this.src}/model_and_train`, architecture: Architecture.X86_64, runtime: Runtime.PYTHON_3_9, @@ -109,31 +108,56 @@ export class UpdateCheckPointApi { }, layers: [this.layer], }); + + const requestModel = new Model(this.scope, `${this.baseId}-model`,{ + restApi: this.router.api, + modelName: this.baseId, + description: `${this.baseId} Request Model`, + schema: { + schema: JsonSchemaVersion.DRAFT4, + title: this.baseId, + type: JsonSchemaType.OBJECT, + properties: { + status: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + multi_parts_tags: { + type: JsonSchemaType.OBJECT, + }, + }, + required: [ + 'status', + 'multi_parts_tags', + ], + }, + contentType: 'application/json', + }); + + const requestValidator = new RequestValidator( + this.scope, + `${this.baseId}-validator`, + { + restApi: this.router.api, + requestValidatorName: this.baseId, + validateRequestBody: true, + }); + const createModelIntegration = new apigw.LambdaIntegration( lambdaFunction, { - proxy: false, - integrationResponses: [{ - statusCode: '200', - responseParameters: { - 'method.response.header.Access-Control-Allow-Headers': "'Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token,X-Amz-User-Agent'", - 'method.response.header.Access-Control-Allow-Methods': "'GET,POST,PUT,OPTIONS'", - 'method.response.header.Access-Control-Allow-Origin': "'*'", - }, - }], + proxy: true, }, ); - this.router.addMethod(this.httpMethod, createModelIntegration, { - apiKeyRequired: true, - methodResponses: [{ - statusCode: '200', - responseParameters: { - 'method.response.header.Access-Control-Allow-Headers': true, - 'method.response.header.Access-Control-Allow-Methods': true, - 'method.response.header.Access-Control-Allow-Origin': true, - }, - }], - }); + this.router.addResource('{id}') + .addMethod(this.httpMethod, createModelIntegration, + { + apiKeyRequired: true, + requestValidator, + requestModels: { + 'application/json': requestModel, + } + }); } } diff --git a/infrastructure/src/sd-train/sd-train-deploy-stack.ts b/infrastructure/src/sd-train/sd-train-deploy-stack.ts index 294890f9..cfcb8403 100644 --- a/infrastructure/src/sd-train/sd-train-deploy-stack.ts +++ b/infrastructure/src/sd-train/sd-train-deploy-stack.ts @@ -172,22 +172,22 @@ export class SdTrainDeployStack extends NestedStack { // POST /checkpoint - new CreateCheckPointApi(this, 'sdExtn-createCkpt', { + new CreateCheckPointApi(this, 'CreateCheckPoint', { checkpointTable: checkPointTable, commonLayer: commonLayer, httpMethod: 'POST', - router: routers.checkpoint, + router: routers.checkpoints, s3Bucket: s3Bucket, srcRoot: this.srcRoot, multiUserTable: multiUserTable, }); // PUT /checkpoint - new UpdateCheckPointApi(this, 'sdExtn-updateCkpt', { + new UpdateCheckPointApi(this, 'UpdateCheckPoint', { checkpointTable: checkPointTable, commonLayer: commonLayer, httpMethod: 'PUT', - router: routers.checkpoint, + router: routers.checkpoints, s3Bucket: s3Bucket, srcRoot: this.srcRoot, }); diff --git a/middleware_api/lambda/model_and_train/checkpoint_api.py b/middleware_api/lambda/model_and_train/checkpoint_api.py index 511ae295..7ccf16b9 100644 --- a/middleware_api/lambda/model_and_train/checkpoint_api.py +++ b/middleware_api/lambda/model_and_train/checkpoint_api.py @@ -10,7 +10,7 @@ import json from _types import CheckPoint, CheckPointStatus, MultipartFileReq from common.ddb_service.client import DynamoDbUtilsService -from common.response import ok +from common.response import ok, bad_request, internal_server_error from common_tools import get_base_checkpoint_s3_key, \ batch_get_s3_multipart_signed_urls, complete_multipart_upload, multipart_upload_from_url from multi_users._types import PARTITION_KEYS, Role @@ -226,7 +226,7 @@ class CreateCheckPointEvent: # POST /checkpoint def create_checkpoint_api(raw_event, context): request_id = context.aws_request_id - event = CreateCheckPointEvent(**raw_event) + event = CreateCheckPointEvent(**json.loads(raw_event['body'])) _type = event.checkpoint_type headers = { 'Access-Control-Allow-Headers': 'Content-Type', @@ -262,11 +262,7 @@ def create_checkpoint_api(raw_event, context): filenames_only.append(file.filename) if len(filenames_only) == 0: - return { - 'statusCode': 400, - 'headers': headers, - 'errorMsg': 'no checkpoint name (file names) detected' - } + return bad_request(message='no checkpoint name (file names) detected', headers=headers) user_roles = ['*'] creator_permissions = {} @@ -276,11 +272,7 @@ def create_checkpoint_api(raw_event, context): if 'checkpoint' not in creator_permissions or \ ('all' not in creator_permissions['checkpoint'] and 'create' not in creator_permissions['checkpoint']): - return { - 'statusCode': 400, - 'headers': headers, - 'error': f"user has no permissions to create a model" - } + return bad_request(message='user has no permissions to create a model', headers=headers) checkpoint = CheckPoint( id=request_id, @@ -293,9 +285,7 @@ def create_checkpoint_api(raw_event, context): allowed_roles_or_users=user_roles ) ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__) - return { - 'statusCode': 200, - 'headers': headers, + data = { 'checkpoint': { 'id': request_id, 'type': _type, @@ -305,25 +295,22 @@ def create_checkpoint_api(raw_event, context): }, 's3PresignUrl': multiparts_resp } + return ok(data=data, headers=headers) except Exception as e: logger.error(e) - return { - 'statusCode': 500, - 'headers': headers, - 'error': str(e) - } + return internal_server_error(headers=headers, message=str(e)) @dataclass class UpdateCheckPointEvent: - checkpoint_id: str status: str multi_parts_tags: Dict[str, Any] # PUT /checkpoint def update_checkpoint_api(raw_event, context): - event = UpdateCheckPointEvent(**raw_event) + event = UpdateCheckPointEvent(**json.loads(raw_event['body'])) + checkpoint_id = raw_event['pathParameters']['id'] headers = { 'Access-Control-Allow-Headers': 'Content-Type', 'Access-Control-Allow-Origin': '*', @@ -331,14 +318,13 @@ def update_checkpoint_api(raw_event, context): } try: raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={ - 'id': event.checkpoint_id + 'id': checkpoint_id }) if raw_checkpoint is None or len(raw_checkpoint) == 0: - return { - 'statusCode': 500, - 'headers': headers, - 'error': f'checkpoint not found with id {event.checkpoint_id}' - } + return bad_request( + message=f'checkpoint not found with id {checkpoint_id}', + headers=headers + ) checkpoint = CheckPoint(**raw_checkpoint) new_status = CheckPointStatus[event.status] @@ -352,9 +338,7 @@ def update_checkpoint_api(raw_event, context): field_name='checkpoint_status', value=new_status ) - return { - 'statusCode': 200, - 'headers': headers, + data = { 'checkpoint': { 'id': checkpoint.id, 'type': checkpoint.checkpoint_type, @@ -363,10 +347,7 @@ def update_checkpoint_api(raw_event, context): 'params': checkpoint.params } } + return ok(data=data, headers=headers) except Exception as e: logger.error(e) - return { - 'statusCode': 500, - 'headers': headers, - 'msg': str(e) - } + return internal_server_error(headers=headers, message=str(e)) From 1ab6aa7a2fb09a1e0a335e69d483f8027e8a4528 Mon Sep 17 00:00:00 2001 From: Jingyi Date: Tue, 19 Dec 2023 23:47:44 +0800 Subject: [PATCH 3/4] update endpoints apis --- .../sagemaker-endpoints-create.ts | 21 +++++---------- .../sagemaker-endpoints-delete.ts | 15 +++++------ .../sd-inference/sd-async-inference-stack.ts | 4 +-- .../src/sd-users/multi-users-stack.ts | 2 +- .../inference_v2/sagemaker_endpoint_api.py | 27 +++++++------------ 5 files changed, 26 insertions(+), 43 deletions(-) diff --git a/infrastructure/src/sd-inference/sagemaker-endpoints-create.ts b/infrastructure/src/sd-inference/sagemaker-endpoints-create.ts index a36938e7..285ae0c5 100644 --- a/infrastructure/src/sd-inference/sagemaker-endpoints-create.ts +++ b/infrastructure/src/sd-inference/sagemaker-endpoints-create.ts @@ -192,7 +192,6 @@ export class CreateSagemakerEndpointsApi { const role = this.iamRole(); const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, { - functionName: `${this.baseId}-api`, entry: `${this.src}/inference_v2`, architecture: Architecture.X86_64, runtime: Runtime.PYTHON_3_9, @@ -213,14 +212,14 @@ export class CreateSagemakerEndpointsApi { layers: [this.layer], }); - const model = new Model(this.scope, 'CreateEndpointModel', { + const model = new Model(this.scope, `${this.baseId}-model`, { restApi: this.router.api, contentType: 'application/json', - modelName: 'CreateEndpointModel', - description: 'Create Endpoint Model', + modelName: this.baseId, + description: `${this.baseId} Request Model`, schema: { schema: JsonSchemaVersion.DRAFT4, - title: 'createEndpointSchema', + title: this.baseId, type: JsonSchemaType.OBJECT, properties: { endpoint_name: { @@ -262,14 +261,13 @@ export class CreateSagemakerEndpointsApi { const integration = new LambdaIntegration( lambdaFunction, { - proxy: false, - integrationResponses: [{ statusCode: '200' }], + proxy: true, }, ); - const requestValidator = new RequestValidator(this.scope, 'CreateEndpointRequestValidator', { + const requestValidator = new RequestValidator(this.scope, `${this.baseId}-validator`, { restApi: this.router.api, - requestValidatorName: 'CreateEndpointRequestValidator', + requestValidatorName: this.baseId, validateRequestBody: true, validateRequestParameters: false, }); @@ -281,11 +279,6 @@ export class CreateSagemakerEndpointsApi { requestModels: { 'application/json': model, }, - methodResponses: [ - { - statusCode: '200', - }, { statusCode: '500' }, - ], }); } diff --git a/infrastructure/src/sd-inference/sagemaker-endpoints-delete.ts b/infrastructure/src/sd-inference/sagemaker-endpoints-delete.ts index 60c95ae0..6c36cad2 100644 --- a/infrastructure/src/sd-inference/sagemaker-endpoints-delete.ts +++ b/infrastructure/src/sd-inference/sagemaker-endpoints-delete.ts @@ -111,8 +111,7 @@ export class DeleteSagemakerEndpointsApi { } private deleteEndpointsApi() { - const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-delete-endpoints`, { - functionName: `${this.baseId}-function`, + const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, { entry: `${this.src}/inference_v2`, architecture: Architecture.X86_64, runtime: Runtime.PYTHON_3_9, @@ -128,13 +127,13 @@ export class DeleteSagemakerEndpointsApi { layers: [this.layer], }); - const model = new Model(this.scope, 'DeleteEndpointsModel', { + const model = new Model(this.scope, `${this.baseId}-model`, { restApi: this.router.api, - modelName: 'DeleteEndpointsModel', - description: 'Delete Endpoint Model', + modelName: this.baseId, + description: `${this.baseId} Request Model`, schema: { schema: JsonSchemaVersion.DRAFT4, - title: 'deleteEndpointSchema', + title: this.baseId, type: JsonSchemaType.OBJECT, properties: { endpoint_name_list: { @@ -165,9 +164,9 @@ export class DeleteSagemakerEndpointsApi { }, ); - const requestValidator = new RequestValidator(this.scope, 'DeleteEndpointRequestValidator', { + const requestValidator = new RequestValidator(this.scope, `${this.baseId}-validator`, { restApi: this.router.api, - requestValidatorName: 'DeleteEndpointRequestValidator', + requestValidatorName: this.baseId, validateRequestBody: true, }); diff --git a/infrastructure/src/sd-inference/sd-async-inference-stack.ts b/infrastructure/src/sd-inference/sd-async-inference-stack.ts index 737ddf05..3921fe99 100644 --- a/infrastructure/src/sd-inference/sd-async-inference-stack.ts +++ b/infrastructure/src/sd-inference/sd-async-inference-stack.ts @@ -122,7 +122,7 @@ export class SDAsyncInferenceStack extends NestedStack { ); new DeleteSagemakerEndpointsApi( - this, 'sd-infer-v2-deleteEndpoints', + this, 'DeleteEndpoints', { router: props.routers.endpoints, commonLayer: props.commonLayer, @@ -164,7 +164,7 @@ export class SDAsyncInferenceStack extends NestedStack { const inference_result_error_topic = aws_sns.Topic.fromTopicArn(scope, `${id}-infer-result-err-tp`, props.inferenceErrorTopic.topicArn); new CreateSagemakerEndpointsApi( - this, 'sd-infer-v2-createEndpoint', + this, 'CreateEndpoint', { router: props.routers.endpoints, commonLayer: props.commonLayer, diff --git a/infrastructure/src/sd-users/multi-users-stack.ts b/infrastructure/src/sd-users/multi-users-stack.ts index 6b6e0899..294f63c6 100644 --- a/infrastructure/src/sd-users/multi-users-stack.ts +++ b/infrastructure/src/sd-users/multi-users-stack.ts @@ -31,7 +31,7 @@ export class MultiUsersStack extends NestedStack { constructor(scope: Construct, id: string, props: MultiUsersStackProps) { super(scope, id, props); - new RoleUpsertApi(scope, 'roleUpsert', { + new RoleUpsertApi(scope, 'CreateRole', { commonLayer: props.commonLayer, httpMethod: 'POST', multiUserTable: props.multiUserTable, diff --git a/middleware_api/lambda/inference_v2/sagemaker_endpoint_api.py b/middleware_api/lambda/inference_v2/sagemaker_endpoint_api.py index 470582a1..12beadb3 100644 --- a/middleware_api/lambda/inference_v2/sagemaker_endpoint_api.py +++ b/middleware_api/lambda/inference_v2/sagemaker_endpoint_api.py @@ -147,7 +147,7 @@ class CreateEndpointEvent: def sagemaker_endpoint_create_api(raw_event, ctx): logger.info(f"Received event: {raw_event}") logger.info(f"Received ctx: {ctx}") - event = CreateEndpointEvent(**raw_event) + event = CreateEndpointEvent(**json.loads(raw_event['body'])) try: endpoint_deployment_id = str(uuid.uuid4()) @@ -172,10 +172,7 @@ def sagemaker_endpoint_create_api(raw_event, ctx): creator_permissions = get_permissions_by_username(ddb_service, user_table, event.creator) if 'sagemaker_endpoint' not in creator_permissions or \ ('all' not in creator_permissions['sagemaker_endpoint'] and 'create' not in creator_permissions['sagemaker_endpoint']): - return { - 'statusCode': 400, - 'message': f"Creator {event.creator} has no permission to create Sagemaker", - } + return bad_request(message=f"Creator {event.creator} has no permission to create Sagemaker") endpoint_rows = ddb_service.scan(sagemaker_endpoint_table, filters=None) for endpoint_row in endpoint_rows: @@ -184,10 +181,8 @@ def sagemaker_endpoint_create_api(raw_event, ctx): if endpoint.endpoint_status != EndpointStatus.DELETED.value and endpoint.status != 'deleted': for role in event.assign_to_roles: if role in endpoint.owner_group_or_role: - return { - 'statusCode': 400, - 'message': f"role [{role}] has a valid endpoint already, not allow to have another one", - } + return bad_request( + message=f"role [{role}] has a valid endpoint already, not allow to have another one") _create_sagemaker_model(sagemaker_model_name, image_url, model_data_url) @@ -216,17 +211,13 @@ def sagemaker_endpoint_create_api(raw_event, ctx): ddb_service.put_items(table=sagemaker_endpoint_table, entries=raw.__dict__) logger.info(f"Successfully created endpoint deployment: {raw.__dict__}") - return { - 'statusCode': 200, - 'message': f"Endpoint deployment started: {sagemaker_endpoint_name}", - 'data': raw.__dict__ - } + return ok( + message=f"Endpoint deployment started: {sagemaker_endpoint_name}", + data=raw.__dict__ + ) except Exception as e: logger.error(e) - return { - 'statusCode': 200, - 'message': str(e), - } + return ok(message=str(e)) # lambda: handle sagemaker events From c41534d494c1ffbe72dcd3b287a2fa5539b31f2a Mon Sep 17 00:00:00 2001 From: Jingyi Date: Wed, 20 Dec 2023 11:03:02 +0800 Subject: [PATCH 4/4] update datasets apis --- .../cloud_api_manager/api_manager.py | 4 +- aws_extension/sagemaker_ui_tab.py | 9 +- .../src/sd-train/dataset-create-api.ts | 84 +++++++++++++++++-- .../src/sd-train/dataset-update-api.ts | 54 +++++++++--- .../src/sd-train/datasets-item-listall-api.ts | 40 ++------- .../src/sd-train/sd-train-deploy-stack.ts | 12 +-- .../lambda/dataset_service/dataset_api.py | 83 ++++++------------ 7 files changed, 166 insertions(+), 120 deletions(-) diff --git a/aws_extension/cloud_api_manager/api_manager.py b/aws_extension/cloud_api_manager/api_manager.py index 1aeb3276..ac1f0a1b 100644 --- a/aws_extension/cloud_api_manager/api_manager.py +++ b/aws_extension/cloud_api_manager/api_manager.py @@ -304,10 +304,10 @@ class CloudApiManager: if not self.auth_manger.enableAuth: return [] - raw_response = requests.get(url=f'{self.auth_manger.api_url}dataset/{dataset_name}/data', headers=self._get_headers_by_user(user_token)) + raw_response = requests.get(url=f'{self.auth_manger.api_url}datasets/{dataset_name}', headers=self._get_headers_by_user(user_token)) raw_response.raise_for_status() # todo: the s3 presign url is not ready as content type to img - resp = raw_response.json() + resp = raw_response.json()['data'] return resp diff --git a/aws_extension/sagemaker_ui_tab.py b/aws_extension/sagemaker_ui_tab.py index 8cb8b3a5..a6e1b883 100644 --- a/aws_extension/sagemaker_ui_tab.py +++ b/aws_extension/sagemaker_ui_tab.py @@ -787,12 +787,14 @@ def dataset_tab(): "creator": pr.username } - url = get_variable_from_json('api_gateway_url') + '/dataset' + url = get_variable_from_json('api_gateway_url') + 'datasets' api_key = get_variable_from_json('api_token') raw_response = requests.post(url=url, json=payload, headers={'x-api-key': api_key}) + logger.info(raw_response.json()) + raw_response.raise_for_status() - response = raw_response.json() + response = raw_response.json()['data'] logger.info(f"Start upload sample files response:\n{response}") for filename, presign_url in response['s3PresignUrl'].items(): @@ -803,11 +805,10 @@ def dataset_tab(): response.raise_for_status() payload = { - "dataset_name": dataset_name, "status": "Enabled" } - raw_response = requests.put(url=url, json=payload, headers={'x-api-key': api_key}) + raw_response = requests.put(url=f"{url}/{dataset_name}", json=payload, headers={'x-api-key': api_key}) raw_response.raise_for_status() logger.debug(raw_response.json()) return f'Complete Dataset {dataset_name} creation', None, None, None, None diff --git a/infrastructure/src/sd-train/dataset-create-api.ts b/infrastructure/src/sd-train/dataset-create-api.ts index aac87a81..850a7f76 100644 --- a/infrastructure/src/sd-train/dataset-create-api.ts +++ b/infrastructure/src/sd-train/dataset-create-api.ts @@ -11,6 +11,7 @@ import { MethodOptions } from 'aws-cdk-lib/aws-apigateway/lib/method'; import { Effect } from 'aws-cdk-lib/aws-iam'; import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda'; import { Construct } from 'constructs'; +import {JsonSchemaType, JsonSchemaVersion, Model, RequestValidator} from "aws-cdk-lib/aws-apigateway"; export interface CreateDatasetApiProps { @@ -108,8 +109,7 @@ export class CreateDatasetApi { } private createDatasetApi() { - const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-create`, { - functionName: `${this.baseId}-create`, + const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, { entry: `${this.src}/dataset_service`, architecture: Architecture.X86_64, runtime: Runtime.PYTHON_3_9, @@ -126,18 +126,88 @@ export class CreateDatasetApi { }, layers: [this.layer], }); + + const requestModel = new Model(this.scope, `${this.baseId}-model`, { + restApi: this.router.api, + modelName: this.baseId, + description: `${this.baseId} Request Model`, + schema: { + schema: JsonSchemaVersion.DRAFT4, + title: this.baseId, + type: JsonSchemaType.OBJECT, + properties: { + dataset_name: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + content: { + type: JsonSchemaType.ARRAY, + items: { + type: JsonSchemaType.OBJECT, + properties: { + filename: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + name: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + type: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + params: { + type: JsonSchemaType.OBJECT, + }, + }, + }, + minItems: 1, + maxItems: 100, + }, + creator: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + params: { + type: JsonSchemaType.OBJECT, + properties: { + description: { + type: JsonSchemaType.STRING, + }, + } + }, + }, + required: [ + 'dataset_name', + 'content', + 'creator', + ], + }, + contentType: 'application/json', + }); + + const requestValidator = new RequestValidator( + this.scope, + `${this.baseId}-validator`, + { + restApi: this.router.api, + requestValidatorName: this.baseId, + validateRequestBody: true, + }); + const createDatasetIntegration = new apigw.LambdaIntegration( lambdaFunction, { - proxy: false, - integrationResponses: [{ statusCode: '200' }], + proxy: true, }, ); this.router.addMethod(this.httpMethod, createDatasetIntegration, { apiKeyRequired: true, - methodResponses: [{ - statusCode: '200', - }], + requestValidator, + requestModels: { + 'application/json': requestModel, + } }); } } diff --git a/infrastructure/src/sd-train/dataset-update-api.ts b/infrastructure/src/sd-train/dataset-update-api.ts index 645b4522..431b39c9 100644 --- a/infrastructure/src/sd-train/dataset-update-api.ts +++ b/infrastructure/src/sd-train/dataset-update-api.ts @@ -11,6 +11,7 @@ import { MethodOptions } from 'aws-cdk-lib/aws-apigateway/lib/method'; import { Effect } from 'aws-cdk-lib/aws-iam'; import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda'; import { Construct } from 'constructs'; +import {JsonSchemaType, JsonSchemaVersion, Model, RequestValidator} from "aws-cdk-lib/aws-apigateway"; export interface UpdateDatasetApiProps { @@ -25,7 +26,7 @@ export interface UpdateDatasetApiProps { export class UpdateDatasetApi { private readonly src; - private readonly router: aws_apigateway.Resource; + public readonly router: aws_apigateway.Resource; private readonly httpMethod: string; private readonly scope: Construct; private readonly datasetInfoTable: aws_dynamodb.Table; @@ -100,8 +101,7 @@ export class UpdateDatasetApi { } private updateDatasetApi() { - const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-update`, { - functionName: `${this.baseId}-update`, + const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, { entry: `${this.src}/dataset_service`, architecture: Architecture.X86_64, runtime: Runtime.PYTHON_3_9, @@ -117,19 +117,51 @@ export class UpdateDatasetApi { }, layers: [this.layer], }); + + const requestModel = new Model(this.scope, `${this.baseId}-model`, { + restApi: this.router.api, + modelName: this.baseId, + description: `${this.baseId} Request Model`, + schema: { + schema: JsonSchemaVersion.DRAFT4, + title: this.baseId, + type: JsonSchemaType.OBJECT, + properties: { + status: { + type: JsonSchemaType.STRING, + minLength: 1, + }, + }, + required: [ + 'status', + ], + }, + contentType: 'application/json', + }); + + const requestValidator = new RequestValidator( + this.scope, + `${this.baseId}-validator`, + { + restApi: this.router.api, + requestValidatorName: this.baseId, + validateRequestBody: true, + }); + const createModelIntegration = new apigw.LambdaIntegration( lambdaFunction, { - proxy: false, - integrationResponses: [{ statusCode: '200' }], + proxy: true, }, ); - this.router.addMethod(this.httpMethod, createModelIntegration, { - apiKeyRequired: true, - methodResponses: [{ - statusCode: '200', - }], - }); + this.router.addResource('{id}') + .addMethod(this.httpMethod, createModelIntegration, { + apiKeyRequired: true, + requestValidator, + requestModels: { + 'application/json': requestModel, + } + }); } } diff --git a/infrastructure/src/sd-train/datasets-item-listall-api.ts b/infrastructure/src/sd-train/datasets-item-listall-api.ts index 807e6e36..bc8cb1a9 100644 --- a/infrastructure/src/sd-train/datasets-item-listall-api.ts +++ b/infrastructure/src/sd-train/datasets-item-listall-api.ts @@ -100,8 +100,7 @@ export class ListAllDatasetItemsApi { } private listAllDatasetApi() { - const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-listall`, { - functionName: `${this.baseId}-listall`, + const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, { entry: `${this.src}/dataset_service`, architecture: Architecture.X86_64, runtime: Runtime.PYTHON_3_9, @@ -122,38 +121,15 @@ export class ListAllDatasetItemsApi { const listDatasetItemsIntegration = new apigw.LambdaIntegration( lambdaFunction, { - proxy: false, - requestParameters: { - 'integration.request.path.dataset_name': 'method.request.path.dataset_name', - }, - requestTemplates: { - 'application/json': '{\n' + - ' "pathStringParameters": {\n' + - ' #foreach($pathParam in $input.params().path.keySet())\n' + - ' "$pathParam": "$util.escapeJavaScript($input.params().path.get($pathParam))"\n' + - ' #if($foreach.hasNext),#end\n' + - ' #end\n' + - ' }, \n' + - ' "x-auth": {\n' + - ' "username": "$context.authorizer.username",\n' + - ' "role": "$context.authorizer.role"\n' + - ' }\n' + - '}', - }, - integrationResponses: [{ statusCode: '200' }], + proxy: true, }, ); - const dataItemRouter = this.router.addResource('{dataset_name}'); - dataItemRouter.addResource('data').addMethod(this.httpMethod, listDatasetItemsIntegration, { - apiKeyRequired: true, - authorizer: this.authorizer, - requestParameters: { - 'method.request.path.dataset_name': true, - }, - methodResponses: [{ - statusCode: '200', - }, { statusCode: '500' }], - }); + + this.router.getResource('{id}') + ?.addMethod(this.httpMethod, listDatasetItemsIntegration, { + apiKeyRequired: true, + authorizer: this.authorizer, + }); } } diff --git a/infrastructure/src/sd-train/sd-train-deploy-stack.ts b/infrastructure/src/sd-train/sd-train-deploy-stack.ts index cfcb8403..eed2e03d 100644 --- a/infrastructure/src/sd-train/sd-train-deploy-stack.ts +++ b/infrastructure/src/sd-train/sd-train-deploy-stack.ts @@ -193,24 +193,24 @@ export class SdTrainDeployStack extends NestedStack { }); // POST /dataset - new CreateDatasetApi(this, 'sdExtn-createDataset', { + new CreateDatasetApi(this, 'CreateDataset', { commonLayer: commonLayer, datasetInfoTable: props.database.datasetInfoTable, datasetItemTable: props.database.datasetItemTable, httpMethod: 'POST', - router: routers.dataset, + router: routers.datasets, s3Bucket: s3Bucket, srcRoot: this.srcRoot, multiUserTable: multiUserTable, }); // PUT /dataset - new UpdateDatasetApi(this, 'sdExtn-updateDataset', { + const updateDataset = new UpdateDatasetApi(this, 'UpdateDataset', { commonLayer: commonLayer, datasetInfoTable: props.database.datasetInfoTable, datasetItemTable: props.database.datasetItemTable, httpMethod: 'PUT', - router: routers.dataset, + router: routers.datasets, s3Bucket: s3Bucket, srcRoot: this.srcRoot, }); @@ -228,13 +228,13 @@ export class SdTrainDeployStack extends NestedStack { }); // GET /dataset/{dataset_name}/data - new ListAllDatasetItemsApi(this, 'sdExtn-listallDsItems', { + new ListAllDatasetItemsApi(this, 'GetDataset', { commonLayer: commonLayer, datasetInfoTable: props.database.datasetInfoTable, datasetItemsTable: props.database.datasetItemTable, multiUserTable: multiUserTable, httpMethod: 'GET', - router: routers.dataset, + router: updateDataset.router, s3Bucket: s3Bucket, srcRoot: this.srcRoot, authorizer: props.authorizer, diff --git a/middleware_api/lambda/dataset_service/dataset_api.py b/middleware_api/lambda/dataset_service/dataset_api.py index c185a0d8..0f2f0f85 100644 --- a/middleware_api/lambda/dataset_service/dataset_api.py +++ b/middleware_api/lambda/dataset_service/dataset_api.py @@ -1,3 +1,4 @@ +import json import logging import os from dataclasses import dataclass @@ -6,7 +7,7 @@ from typing import Any, List from common.ddb_service.client import DynamoDbUtilsService from _types import DatasetItem, DatasetInfo, DatasetStatus, DataStatus -from common.response import ok, bad_request +from common.response import ok, bad_request, internal_server_error, not_found, forbidden from common.util import get_s3_presign_urls, generate_presign_url from multi_users.utils import get_permissions_by_username, get_user_roles, check_user_permissions @@ -47,16 +48,13 @@ class DatasetCreateEvent: # POST /dataset def create_dataset_api(raw_event, context): - event = DatasetCreateEvent(**raw_event) + event = DatasetCreateEvent(**json.loads(raw_event['body'])) try: creator_permissions = get_permissions_by_username(ddb_service, user_table, event.creator) if 'train' not in creator_permissions \ or ('all' not in creator_permissions['train'] and 'create' not in creator_permissions['train']): - return { - 'statusCode': 400, - 'errMsg': f'user {event.creator} has not permission to create a train job' - } + return bad_request(message=f'user {event.creator} has not permission to create a train job') user_roles = get_user_roles(ddb_service, user_table, event.creator) timestamp = datetime.now().timestamp() @@ -94,17 +92,16 @@ def create_dataset_api(raw_event, context): dataset_item_table: dataset, dataset_info_table: [new_dataset_info.__dict__] }) - return { - 'statusCode': 200, + + data = { 'datasetName': new_dataset_info.dataset_name, 's3PresignUrl': presign_url_map } + + return ok(data=data) except Exception as e: logger.error(e) - return { - 'statusCode': 500, - 'error': str(e) - } + return internal_server_error(message=str(e)) # GET /datasets @@ -158,49 +155,27 @@ def list_datasets_api(event, context): # GET /dataset/{dataset_name}/data def list_data_by_dataset(event, context): _filter = {} - if 'pathStringParameters' not in event: - return { - 'statusCode': 500, - 'error': 'path parameter /dataset/{dataset_name}/ are needed' - } - dataset_name = event['pathStringParameters']['dataset_name'] - if not dataset_name or len(dataset_name) == 0: - return { - 'statusCode': 500, - 'error': 'path parameter /dataset/{dataset_name}/ are needed' - } + dataset_name = event['pathParameters']['id'] dataset_info_rows = ddb_service.get_item(table=dataset_info_table, key_values={ 'dataset_name': dataset_name }) if not dataset_info_rows or len(dataset_info_rows) == 0: - return { - 'statusCode': 500, - 'error': 'path parameter /dataset/{dataset_name}/ are not found' - } + return not_found(message=f'dataset {dataset_name} is not found') dataset_info = DatasetInfo(**dataset_info_rows) - if 'x-auth' not in event or not event['x-auth']['username']: - return { - 'statusCode': 400, - 'error': 'no auth user provided' - } - - requestor_name = event['x-auth']['username'] - requestor_permissions = get_permissions_by_username(ddb_service, user_table, requestor_name) - requestor_roles = get_user_roles(ddb_service=ddb_service, user_table_name=user_table, username=requestor_name) + requester_name = event['requestContext']['authorizer']['username'] + requestor_permissions = get_permissions_by_username(ddb_service, user_table, requester_name) + requestor_roles = get_user_roles(ddb_service=ddb_service, user_table_name=user_table, username=requester_name) if not ( - (dataset_info.allowed_roles_or_users and check_user_permissions(dataset_info.allowed_roles_or_users, requestor_roles, requestor_name)) or # permission in dataset + (dataset_info.allowed_roles_or_users and check_user_permissions(dataset_info.allowed_roles_or_users, requestor_roles, requester_name)) or # permission in dataset (not dataset_info.allowed_roles_or_users and 'user' in requestor_permissions and 'all' in requestor_permissions['user']) # legacy data for super admin ): - return { - 'statusCode': 400, - 'error': 'no permission to view dataset' - } + return forbidden(message='no permission to view dataset') rows = ddb_service.query_items(table=dataset_item_table, key_values={ 'dataset_name': dataset_name @@ -218,8 +193,7 @@ def list_data_by_dataset(event, context): **item.params }) - return { - 'statusCode': 200, + return ok(data={ 'dataset_name': dataset_name, 'datasetName': dataset_info.dataset_name, 's3': f's3://{bucket_name}/{dataset_info.get_s3_key()}', @@ -227,27 +201,24 @@ def list_data_by_dataset(event, context): 'timestamp': dataset_info.timestamp, 'data': resp, **dataset_info.params - } + }, decimal=True) @dataclass class UpdateDatasetStatusEvent: - dataset_name: str status: str # PUT /dataset def update_dataset_status(raw_event, context): - event = UpdateDatasetStatusEvent(**raw_event) + event = UpdateDatasetStatusEvent(**json.loads(raw_event['body'])) + dataset_id = raw_event['pathParameters']['id'] try: raw_dataset_info = ddb_service.get_item(table=dataset_info_table, key_values={ - 'dataset_name': event.dataset_name + 'dataset_name': dataset_id }) if not raw_dataset_info or len(raw_dataset_info) == 0: - return { - 'statusCode': 404, - 'errorMsg': f'dataset {event.dataset_name} is not found' - } + return not_found(message=f'dataset {dataset_id} is not found') dataset_info = DatasetInfo(**raw_dataset_info) new_status = DatasetStatus[event.status] @@ -269,15 +240,11 @@ def update_dataset_status(raw_event, context): ddb_service.batch_put_items(table_items={ dataset_item_table: updates_items }) - return { - 'statusCode': 200, + return ok(data={ 'datasetName': dataset_info.dataset_name, 'status': dataset_info.dataset_status.value, - } + }) except Exception as e: logger.error(e) - return { - 'statusCode': 500, - 'error': str(e) - } + return internal_server_error(message=str(e))