Merge pull request #320 from elonniu/openapi

feat: update checkpoints/endpoints/datasets apis
pull/324/head
Elon Niu 2023-12-20 12:04:10 +08:00 committed by GitHub
commit b8174d31e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 421 additions and 305 deletions

View File

@ -229,14 +229,14 @@ class CloudApiManager:
if initial: if initial:
cloud_auth_manager.refresh() cloud_auth_manager.refresh()
raw_resp = requests.post(f'{cloud_auth_manager.api_url}user', raw_resp = requests.post(f'{cloud_auth_manager.api_url}users',
json=payload, json=payload,
headers=self._get_headers_by_user(user_token) headers=self._get_headers_by_user(user_token)
) )
raw_resp.raise_for_status() raw_resp.raise_for_status()
resp = raw_resp.json() resp = raw_resp.json()
if resp['statusCode'] != 200: if raw_resp.status_code != 200:
raise Exception(resp['errMsg']) raise Exception(resp['message'])
cloud_auth_manager.update_gradio_auth() cloud_auth_manager.update_gradio_auth()
return True return True
@ -304,10 +304,10 @@ class CloudApiManager:
if not self.auth_manger.enableAuth: if not self.auth_manger.enableAuth:
return [] return []
raw_response = requests.get(url=f'{self.auth_manger.api_url}dataset/{dataset_name}/data', headers=self._get_headers_by_user(user_token)) raw_response = requests.get(url=f'{self.auth_manger.api_url}datasets/{dataset_name}', headers=self._get_headers_by_user(user_token))
raw_response.raise_for_status() raw_response.raise_for_status()
# todo: the s3 presign url is not ready as content type to img # todo: the s3 presign url is not ready as content type to img
resp = raw_response.json() resp = raw_response.json()['data']
return resp return resp

View File

@ -534,7 +534,7 @@ def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_
api_key = get_variable_from_json('api_token') api_key = get_variable_from_json('api_token')
logger.info(f'!!!!!!api_gateway_url {api_gateway_url}') logger.info(f'!!!!!!api_gateway_url {api_gateway_url}')
url = str(api_gateway_url) + "checkpoint" url = str(api_gateway_url) + "checkpoints"
logger.debug(f"Post request for upload s3 presign url: {url}") logger.debug(f"Post request for upload s3 presign url: {url}")
@ -542,7 +542,7 @@ def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_
try: try:
response.raise_for_status() response.raise_for_status()
json_response = response.json() json_response = response.json()['data']
logger.debug(f"Response json {json_response}") logger.debug(f"Response json {json_response}")
s3_base = json_response["checkpoint"]["s3_location"] s3_base = json_response["checkpoint"]["s3_location"]
checkpoint_id = json_response["checkpoint"]["id"] checkpoint_id = json_response["checkpoint"]["id"]
@ -577,12 +577,11 @@ def sagemaker_upload_model_s3(sd_checkpoints_path, textual_inversion_path, lora_
) )
payload = { payload = {
"checkpoint_id": checkpoint_id,
"status": "Active", "status": "Active",
"multi_parts_tags": {local_tar_path: multiparts_tags} "multi_parts_tags": {local_tar_path: multiparts_tags}
} }
# Start creating model on cloud. # Start creating model on cloud.
response = requests.put(url=url, json=payload, headers={'x-api-key': api_key}) response = requests.put(url=f"{url}/{checkpoint_id}", json=payload, headers={'x-api-key': api_key})
s3_input_path = s3_base s3_input_path = s3_base
logger.debug(response) logger.debug(response)

View File

@ -340,7 +340,7 @@ def role_settings_tab():
if resp: if resp:
return f'Role upsert complete "{role_name}"' return f'Role upsert complete "{role_name}"'
except Exception as e: except Exception as e:
return f'User upsert failed: {e}' return f'Role upsert failed: {e}'
upsert_role_button.click(fn=upsert_role, upsert_role_button.click(fn=upsert_role,
inputs=[rolename_textbox, permissions_dropdown], inputs=[rolename_textbox, permissions_dropdown],
@ -792,12 +792,14 @@ def dataset_tab():
"creator": pr.username "creator": pr.username
} }
url = get_variable_from_json('api_gateway_url') + '/dataset' url = get_variable_from_json('api_gateway_url') + 'datasets'
api_key = get_variable_from_json('api_token') api_key = get_variable_from_json('api_token')
raw_response = requests.post(url=url, json=payload, headers={'x-api-key': api_key}) raw_response = requests.post(url=url, json=payload, headers={'x-api-key': api_key})
logger.info(raw_response.json())
raw_response.raise_for_status() raw_response.raise_for_status()
response = raw_response.json() response = raw_response.json()['data']
logger.info(f"Start upload sample files response:\n{response}") logger.info(f"Start upload sample files response:\n{response}")
for filename, presign_url in response['s3PresignUrl'].items(): for filename, presign_url in response['s3PresignUrl'].items():
@ -808,11 +810,10 @@ def dataset_tab():
response.raise_for_status() response.raise_for_status()
payload = { payload = {
"dataset_name": dataset_name,
"status": "Enabled" "status": "Enabled"
} }
raw_response = requests.put(url=url, json=payload, headers={'x-api-key': api_key}) raw_response = requests.put(url=f"{url}/{dataset_name}", json=payload, headers={'x-api-key': api_key})
raw_response.raise_for_status() raw_response.raise_for_status()
logger.debug(raw_response.json()) logger.debug(raw_response.json())
return f'Complete Dataset {dataset_name} creation', None, None, None, None return f'Complete Dataset {dataset_name} creation', None, None, None, None
@ -897,7 +898,7 @@ def update_connect_config(api_url, api_token, username=None, password=None, init
initial=initial, user_token=username): initial=initial, user_token=username):
return 'Initial Setup Failed' return 'Initial Setup Failed'
except Exception as e: except Exception as e:
return f'User upsert failed: {e}' return f'Initial Setup failed: {e}'
return "Setting updated" return "Setting updated"

View File

@ -192,7 +192,6 @@ export class CreateSagemakerEndpointsApi {
const role = this.iamRole(); const role = this.iamRole();
const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, <PythonFunctionProps>{ const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, <PythonFunctionProps>{
functionName: `${this.baseId}-api`,
entry: `${this.src}/inference_v2`, entry: `${this.src}/inference_v2`,
architecture: Architecture.X86_64, architecture: Architecture.X86_64,
runtime: Runtime.PYTHON_3_9, runtime: Runtime.PYTHON_3_9,
@ -213,14 +212,14 @@ export class CreateSagemakerEndpointsApi {
layers: [this.layer], layers: [this.layer],
}); });
const model = new Model(this.scope, 'CreateEndpointModel', { const model = new Model(this.scope, `${this.baseId}-model`, {
restApi: this.router.api, restApi: this.router.api,
contentType: 'application/json', contentType: 'application/json',
modelName: 'CreateEndpointModel', modelName: this.baseId,
description: 'Create Endpoint Model', description: `${this.baseId} Request Model`,
schema: { schema: {
schema: JsonSchemaVersion.DRAFT4, schema: JsonSchemaVersion.DRAFT4,
title: 'createEndpointSchema', title: this.baseId,
type: JsonSchemaType.OBJECT, type: JsonSchemaType.OBJECT,
properties: { properties: {
endpoint_name: { endpoint_name: {
@ -262,14 +261,13 @@ export class CreateSagemakerEndpointsApi {
const integration = new LambdaIntegration( const integration = new LambdaIntegration(
lambdaFunction, lambdaFunction,
{ {
proxy: false, proxy: true,
integrationResponses: [{ statusCode: '200' }],
}, },
); );
const requestValidator = new RequestValidator(this.scope, 'CreateEndpointRequestValidator', { const requestValidator = new RequestValidator(this.scope, `${this.baseId}-validator`, {
restApi: this.router.api, restApi: this.router.api,
requestValidatorName: 'CreateEndpointRequestValidator', requestValidatorName: this.baseId,
validateRequestBody: true, validateRequestBody: true,
validateRequestParameters: false, validateRequestParameters: false,
}); });
@ -281,11 +279,6 @@ export class CreateSagemakerEndpointsApi {
requestModels: { requestModels: {
'application/json': model, 'application/json': model,
}, },
methodResponses: [
{
statusCode: '200',
}, { statusCode: '500' },
],
}); });
} }

View File

@ -111,8 +111,7 @@ export class DeleteSagemakerEndpointsApi {
} }
private deleteEndpointsApi() { private deleteEndpointsApi() {
const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-delete-endpoints`, <PythonFunctionProps>{ const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, <PythonFunctionProps>{
functionName: `${this.baseId}-function`,
entry: `${this.src}/inference_v2`, entry: `${this.src}/inference_v2`,
architecture: Architecture.X86_64, architecture: Architecture.X86_64,
runtime: Runtime.PYTHON_3_9, runtime: Runtime.PYTHON_3_9,
@ -128,13 +127,13 @@ export class DeleteSagemakerEndpointsApi {
layers: [this.layer], layers: [this.layer],
}); });
const model = new Model(this.scope, 'DeleteEndpointsModel', { const model = new Model(this.scope, `${this.baseId}-model`, {
restApi: this.router.api, restApi: this.router.api,
modelName: 'DeleteEndpointsModel', modelName: this.baseId,
description: 'Delete Endpoint Model', description: `${this.baseId} Request Model`,
schema: { schema: {
schema: JsonSchemaVersion.DRAFT4, schema: JsonSchemaVersion.DRAFT4,
title: 'deleteEndpointSchema', title: this.baseId,
type: JsonSchemaType.OBJECT, type: JsonSchemaType.OBJECT,
properties: { properties: {
endpoint_name_list: { endpoint_name_list: {
@ -165,9 +164,9 @@ export class DeleteSagemakerEndpointsApi {
}, },
); );
const requestValidator = new RequestValidator(this.scope, 'DeleteEndpointRequestValidator', { const requestValidator = new RequestValidator(this.scope, `${this.baseId}-validator`, {
restApi: this.router.api, restApi: this.router.api,
requestValidatorName: 'DeleteEndpointRequestValidator', requestValidatorName: this.baseId,
validateRequestBody: true, validateRequestBody: true,
}); });

View File

@ -122,7 +122,7 @@ export class SDAsyncInferenceStack extends NestedStack {
); );
new DeleteSagemakerEndpointsApi( new DeleteSagemakerEndpointsApi(
this, 'sd-infer-v2-deleteEndpoints', this, 'DeleteEndpoints',
<DeleteSagemakerEndpointsApiProps>{ <DeleteSagemakerEndpointsApiProps>{
router: props.routers.endpoints, router: props.routers.endpoints,
commonLayer: props.commonLayer, commonLayer: props.commonLayer,
@ -164,7 +164,7 @@ export class SDAsyncInferenceStack extends NestedStack {
const inference_result_error_topic = aws_sns.Topic.fromTopicArn(scope, `${id}-infer-result-err-tp`, props.inferenceErrorTopic.topicArn); const inference_result_error_topic = aws_sns.Topic.fromTopicArn(scope, `${id}-infer-result-err-tp`, props.inferenceErrorTopic.topicArn);
new CreateSagemakerEndpointsApi( new CreateSagemakerEndpointsApi(
this, 'sd-infer-v2-createEndpoint', this, 'CreateEndpoint',
<CreateSagemakerEndpointsApiProps>{ <CreateSagemakerEndpointsApiProps>{
router: props.routers.endpoints, router: props.routers.endpoints,
commonLayer: props.commonLayer, commonLayer: props.commonLayer,

View File

@ -11,6 +11,7 @@ import { MethodOptions } from 'aws-cdk-lib/aws-apigateway/lib/method';
import { Effect } from 'aws-cdk-lib/aws-iam'; import { Effect } from 'aws-cdk-lib/aws-iam';
import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda'; import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda';
import { Construct } from 'constructs'; import { Construct } from 'constructs';
import {JsonSchemaType, JsonSchemaVersion, Model, RequestValidator} from "aws-cdk-lib/aws-apigateway";
export interface CreateCheckPointApiProps { export interface CreateCheckPointApiProps {
@ -101,8 +102,7 @@ export class CreateCheckPointApi {
} }
private createCheckpointApi() { private createCheckpointApi() {
const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-create`, <PythonFunctionProps>{ const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, <PythonFunctionProps>{
functionName: `${this.baseId}-create-checkpoint`,
entry: `${this.src}/model_and_train`, entry: `${this.src}/model_and_train`,
architecture: Architecture.X86_64, architecture: Architecture.X86_64,
runtime: Runtime.PYTHON_3_9, runtime: Runtime.PYTHON_3_9,
@ -118,30 +118,80 @@ export class CreateCheckPointApi {
}, },
layers: [this.layer], layers: [this.layer],
}); });
const requestModel = new Model(this.scope, `${this.baseId}-model`,{
restApi: this.router.api,
modelName: this.baseId,
description: `${this.baseId} Request Model`,
schema: {
schema: JsonSchemaVersion.DRAFT4,
title: this.baseId,
type: JsonSchemaType.OBJECT,
properties: {
checkpoint_type: {
type: JsonSchemaType.STRING,
minLength: 1,
},
filenames: {
type: JsonSchemaType.ARRAY,
items: {
type: JsonSchemaType.OBJECT,
properties: {
filename: {
type: JsonSchemaType.STRING,
minLength: 1,
},
parts_number: {
type: JsonSchemaType.INTEGER,
minimum: 1,
maximum: 100,
},
},
},
minItems: 1,
maxItems: 20,
},
params: {
type: JsonSchemaType.OBJECT,
properties: {
message: {
type: JsonSchemaType.STRING,
},
creator: {
type: JsonSchemaType.STRING,
}
},
},
},
required: [
'checkpoint_type',
'filenames',
],
},
contentType: 'application/json',
});
const requestValidator = new RequestValidator(
this.scope,
`${this.baseId}-validator`,
{
restApi: this.router.api,
requestValidatorName: this.baseId,
validateRequestBody: true,
});
const createCheckpointIntegration = new apigw.LambdaIntegration( const createCheckpointIntegration = new apigw.LambdaIntegration(
lambdaFunction, lambdaFunction,
{ {
proxy: false, proxy: true,
integrationResponses: [{
statusCode: '200',
responseParameters: {
'method.response.header.Access-Control-Allow-Headers': "'Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token,X-Amz-User-Agent'",
'method.response.header.Access-Control-Allow-Methods': "'GET,POST,PUT,OPTIONS'",
'method.response.header.Access-Control-Allow-Origin': "'*'",
},
}],
}, },
); );
this.router.addMethod(this.httpMethod, createCheckpointIntegration, <MethodOptions>{ this.router.addMethod(this.httpMethod, createCheckpointIntegration, <MethodOptions>{
apiKeyRequired: true, apiKeyRequired: true,
methodResponses: [{ requestValidator,
statusCode: '200', requestModels: {
responseParameters: { 'application/json': requestModel,
'method.response.header.Access-Control-Allow-Headers': true, }
'method.response.header.Access-Control-Allow-Methods': true,
'method.response.header.Access-Control-Allow-Origin': true,
},
}],
}); });
} }
} }

View File

@ -7,10 +7,10 @@ import {
aws_lambda, aws_s3, aws_lambda, aws_s3,
Duration, Duration,
} from 'aws-cdk-lib'; } from 'aws-cdk-lib';
import { MethodOptions } from 'aws-cdk-lib/aws-apigateway/lib/method';
import { Effect } from 'aws-cdk-lib/aws-iam'; import { Effect } from 'aws-cdk-lib/aws-iam';
import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda'; import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda';
import { Construct } from 'constructs'; import { Construct } from 'constructs';
import {JsonSchemaType, JsonSchemaVersion, Model, RequestValidator} from "aws-cdk-lib/aws-apigateway";
export interface UpdateCheckPointApiProps { export interface UpdateCheckPointApiProps {
@ -93,8 +93,7 @@ export class UpdateCheckPointApi {
} }
private updateCheckpointApi() { private updateCheckpointApi() {
const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-update`, <PythonFunctionProps>{ const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, <PythonFunctionProps>{
functionName: `${this.baseId}-update-checkpoint`,
entry: `${this.src}/model_and_train`, entry: `${this.src}/model_and_train`,
architecture: Architecture.X86_64, architecture: Architecture.X86_64,
runtime: Runtime.PYTHON_3_9, runtime: Runtime.PYTHON_3_9,
@ -109,31 +108,56 @@ export class UpdateCheckPointApi {
}, },
layers: [this.layer], layers: [this.layer],
}); });
const requestModel = new Model(this.scope, `${this.baseId}-model`,{
restApi: this.router.api,
modelName: this.baseId,
description: `${this.baseId} Request Model`,
schema: {
schema: JsonSchemaVersion.DRAFT4,
title: this.baseId,
type: JsonSchemaType.OBJECT,
properties: {
status: {
type: JsonSchemaType.STRING,
minLength: 1,
},
multi_parts_tags: {
type: JsonSchemaType.OBJECT,
},
},
required: [
'status',
'multi_parts_tags',
],
},
contentType: 'application/json',
});
const requestValidator = new RequestValidator(
this.scope,
`${this.baseId}-validator`,
{
restApi: this.router.api,
requestValidatorName: this.baseId,
validateRequestBody: true,
});
const createModelIntegration = new apigw.LambdaIntegration( const createModelIntegration = new apigw.LambdaIntegration(
lambdaFunction, lambdaFunction,
{ {
proxy: false, proxy: true,
integrationResponses: [{
statusCode: '200',
responseParameters: {
'method.response.header.Access-Control-Allow-Headers': "'Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token,X-Amz-User-Agent'",
'method.response.header.Access-Control-Allow-Methods': "'GET,POST,PUT,OPTIONS'",
'method.response.header.Access-Control-Allow-Origin': "'*'",
},
}],
}, },
); );
this.router.addMethod(this.httpMethod, createModelIntegration, <MethodOptions>{ this.router.addResource('{id}')
apiKeyRequired: true, .addMethod(this.httpMethod, createModelIntegration,
methodResponses: [{ {
statusCode: '200', apiKeyRequired: true,
responseParameters: { requestValidator,
'method.response.header.Access-Control-Allow-Headers': true, requestModels: {
'method.response.header.Access-Control-Allow-Methods': true, 'application/json': requestModel,
'method.response.header.Access-Control-Allow-Origin': true, }
}, });
}],
});
} }
} }

View File

@ -11,6 +11,7 @@ import { MethodOptions } from 'aws-cdk-lib/aws-apigateway/lib/method';
import { Effect } from 'aws-cdk-lib/aws-iam'; import { Effect } from 'aws-cdk-lib/aws-iam';
import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda'; import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda';
import { Construct } from 'constructs'; import { Construct } from 'constructs';
import {JsonSchemaType, JsonSchemaVersion, Model, RequestValidator} from "aws-cdk-lib/aws-apigateway";
export interface CreateDatasetApiProps { export interface CreateDatasetApiProps {
@ -108,8 +109,7 @@ export class CreateDatasetApi {
} }
private createDatasetApi() { private createDatasetApi() {
const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-create`, <PythonFunctionProps>{ const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, <PythonFunctionProps>{
functionName: `${this.baseId}-create`,
entry: `${this.src}/dataset_service`, entry: `${this.src}/dataset_service`,
architecture: Architecture.X86_64, architecture: Architecture.X86_64,
runtime: Runtime.PYTHON_3_9, runtime: Runtime.PYTHON_3_9,
@ -126,18 +126,88 @@ export class CreateDatasetApi {
}, },
layers: [this.layer], layers: [this.layer],
}); });
const requestModel = new Model(this.scope, `${this.baseId}-model`, {
restApi: this.router.api,
modelName: this.baseId,
description: `${this.baseId} Request Model`,
schema: {
schema: JsonSchemaVersion.DRAFT4,
title: this.baseId,
type: JsonSchemaType.OBJECT,
properties: {
dataset_name: {
type: JsonSchemaType.STRING,
minLength: 1,
},
content: {
type: JsonSchemaType.ARRAY,
items: {
type: JsonSchemaType.OBJECT,
properties: {
filename: {
type: JsonSchemaType.STRING,
minLength: 1,
},
name: {
type: JsonSchemaType.STRING,
minLength: 1,
},
type: {
type: JsonSchemaType.STRING,
minLength: 1,
},
params: {
type: JsonSchemaType.OBJECT,
},
},
},
minItems: 1,
maxItems: 100,
},
creator: {
type: JsonSchemaType.STRING,
minLength: 1,
},
params: {
type: JsonSchemaType.OBJECT,
properties: {
description: {
type: JsonSchemaType.STRING,
},
}
},
},
required: [
'dataset_name',
'content',
'creator',
],
},
contentType: 'application/json',
});
const requestValidator = new RequestValidator(
this.scope,
`${this.baseId}-validator`,
{
restApi: this.router.api,
requestValidatorName: this.baseId,
validateRequestBody: true,
});
const createDatasetIntegration = new apigw.LambdaIntegration( const createDatasetIntegration = new apigw.LambdaIntegration(
lambdaFunction, lambdaFunction,
{ {
proxy: false, proxy: true,
integrationResponses: [{ statusCode: '200' }],
}, },
); );
this.router.addMethod(this.httpMethod, createDatasetIntegration, <MethodOptions>{ this.router.addMethod(this.httpMethod, createDatasetIntegration, <MethodOptions>{
apiKeyRequired: true, apiKeyRequired: true,
methodResponses: [{ requestValidator,
statusCode: '200', requestModels: {
}], 'application/json': requestModel,
}
}); });
} }
} }

View File

@ -11,6 +11,7 @@ import { MethodOptions } from 'aws-cdk-lib/aws-apigateway/lib/method';
import { Effect } from 'aws-cdk-lib/aws-iam'; import { Effect } from 'aws-cdk-lib/aws-iam';
import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda'; import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda';
import { Construct } from 'constructs'; import { Construct } from 'constructs';
import {JsonSchemaType, JsonSchemaVersion, Model, RequestValidator} from "aws-cdk-lib/aws-apigateway";
export interface UpdateDatasetApiProps { export interface UpdateDatasetApiProps {
@ -25,7 +26,7 @@ export interface UpdateDatasetApiProps {
export class UpdateDatasetApi { export class UpdateDatasetApi {
private readonly src; private readonly src;
private readonly router: aws_apigateway.Resource; public readonly router: aws_apigateway.Resource;
private readonly httpMethod: string; private readonly httpMethod: string;
private readonly scope: Construct; private readonly scope: Construct;
private readonly datasetInfoTable: aws_dynamodb.Table; private readonly datasetInfoTable: aws_dynamodb.Table;
@ -100,8 +101,7 @@ export class UpdateDatasetApi {
} }
private updateDatasetApi() { private updateDatasetApi() {
const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-update`, <PythonFunctionProps>{ const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, <PythonFunctionProps>{
functionName: `${this.baseId}-update`,
entry: `${this.src}/dataset_service`, entry: `${this.src}/dataset_service`,
architecture: Architecture.X86_64, architecture: Architecture.X86_64,
runtime: Runtime.PYTHON_3_9, runtime: Runtime.PYTHON_3_9,
@ -117,19 +117,51 @@ export class UpdateDatasetApi {
}, },
layers: [this.layer], layers: [this.layer],
}); });
const requestModel = new Model(this.scope, `${this.baseId}-model`, {
restApi: this.router.api,
modelName: this.baseId,
description: `${this.baseId} Request Model`,
schema: {
schema: JsonSchemaVersion.DRAFT4,
title: this.baseId,
type: JsonSchemaType.OBJECT,
properties: {
status: {
type: JsonSchemaType.STRING,
minLength: 1,
},
},
required: [
'status',
],
},
contentType: 'application/json',
});
const requestValidator = new RequestValidator(
this.scope,
`${this.baseId}-validator`,
{
restApi: this.router.api,
requestValidatorName: this.baseId,
validateRequestBody: true,
});
const createModelIntegration = new apigw.LambdaIntegration( const createModelIntegration = new apigw.LambdaIntegration(
lambdaFunction, lambdaFunction,
{ {
proxy: false, proxy: true,
integrationResponses: [{ statusCode: '200' }],
}, },
); );
this.router.addMethod(this.httpMethod, createModelIntegration, <MethodOptions>{ this.router.addResource('{id}')
apiKeyRequired: true, .addMethod(this.httpMethod, createModelIntegration, <MethodOptions>{
methodResponses: [{ apiKeyRequired: true,
statusCode: '200', requestValidator,
}], requestModels: {
}); 'application/json': requestModel,
}
});
} }
} }

View File

@ -100,8 +100,7 @@ export class ListAllDatasetItemsApi {
} }
private listAllDatasetApi() { private listAllDatasetApi() {
const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-listall`, <PythonFunctionProps>{ const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, <PythonFunctionProps>{
functionName: `${this.baseId}-listall`,
entry: `${this.src}/dataset_service`, entry: `${this.src}/dataset_service`,
architecture: Architecture.X86_64, architecture: Architecture.X86_64,
runtime: Runtime.PYTHON_3_9, runtime: Runtime.PYTHON_3_9,
@ -122,38 +121,15 @@ export class ListAllDatasetItemsApi {
const listDatasetItemsIntegration = new apigw.LambdaIntegration( const listDatasetItemsIntegration = new apigw.LambdaIntegration(
lambdaFunction, lambdaFunction,
{ {
proxy: false, proxy: true,
requestParameters: {
'integration.request.path.dataset_name': 'method.request.path.dataset_name',
},
requestTemplates: {
'application/json': '{\n' +
' "pathStringParameters": {\n' +
' #foreach($pathParam in $input.params().path.keySet())\n' +
' "$pathParam": "$util.escapeJavaScript($input.params().path.get($pathParam))"\n' +
' #if($foreach.hasNext),#end\n' +
' #end\n' +
' }, \n' +
' "x-auth": {\n' +
' "username": "$context.authorizer.username",\n' +
' "role": "$context.authorizer.role"\n' +
' }\n' +
'}',
},
integrationResponses: [{ statusCode: '200' }],
}, },
); );
const dataItemRouter = this.router.addResource('{dataset_name}');
dataItemRouter.addResource('data').addMethod(this.httpMethod, listDatasetItemsIntegration, <MethodOptions>{ this.router.getResource('{id}')
apiKeyRequired: true, ?.addMethod(this.httpMethod, listDatasetItemsIntegration, <MethodOptions>{
authorizer: this.authorizer, apiKeyRequired: true,
requestParameters: { authorizer: this.authorizer,
'method.request.path.dataset_name': true, });
},
methodResponses: [{
statusCode: '200',
}, { statusCode: '500' }],
});
} }
} }

View File

@ -172,45 +172,45 @@ export class SdTrainDeployStack extends NestedStack {
// POST /checkpoint // POST /checkpoint
new CreateCheckPointApi(this, 'sdExtn-createCkpt', { new CreateCheckPointApi(this, 'CreateCheckPoint', {
checkpointTable: checkPointTable, checkpointTable: checkPointTable,
commonLayer: commonLayer, commonLayer: commonLayer,
httpMethod: 'POST', httpMethod: 'POST',
router: routers.checkpoint, router: routers.checkpoints,
s3Bucket: s3Bucket, s3Bucket: s3Bucket,
srcRoot: this.srcRoot, srcRoot: this.srcRoot,
multiUserTable: multiUserTable, multiUserTable: multiUserTable,
}); });
// PUT /checkpoint // PUT /checkpoint
new UpdateCheckPointApi(this, 'sdExtn-updateCkpt', { new UpdateCheckPointApi(this, 'UpdateCheckPoint', {
checkpointTable: checkPointTable, checkpointTable: checkPointTable,
commonLayer: commonLayer, commonLayer: commonLayer,
httpMethod: 'PUT', httpMethod: 'PUT',
router: routers.checkpoint, router: routers.checkpoints,
s3Bucket: s3Bucket, s3Bucket: s3Bucket,
srcRoot: this.srcRoot, srcRoot: this.srcRoot,
}); });
// POST /dataset // POST /dataset
new CreateDatasetApi(this, 'sdExtn-createDataset', { new CreateDatasetApi(this, 'CreateDataset', {
commonLayer: commonLayer, commonLayer: commonLayer,
datasetInfoTable: props.database.datasetInfoTable, datasetInfoTable: props.database.datasetInfoTable,
datasetItemTable: props.database.datasetItemTable, datasetItemTable: props.database.datasetItemTable,
httpMethod: 'POST', httpMethod: 'POST',
router: routers.dataset, router: routers.datasets,
s3Bucket: s3Bucket, s3Bucket: s3Bucket,
srcRoot: this.srcRoot, srcRoot: this.srcRoot,
multiUserTable: multiUserTable, multiUserTable: multiUserTable,
}); });
// PUT /dataset // PUT /dataset
new UpdateDatasetApi(this, 'sdExtn-updateDataset', { const updateDataset = new UpdateDatasetApi(this, 'UpdateDataset', {
commonLayer: commonLayer, commonLayer: commonLayer,
datasetInfoTable: props.database.datasetInfoTable, datasetInfoTable: props.database.datasetInfoTable,
datasetItemTable: props.database.datasetItemTable, datasetItemTable: props.database.datasetItemTable,
httpMethod: 'PUT', httpMethod: 'PUT',
router: routers.dataset, router: routers.datasets,
s3Bucket: s3Bucket, s3Bucket: s3Bucket,
srcRoot: this.srcRoot, srcRoot: this.srcRoot,
}); });
@ -228,13 +228,13 @@ export class SdTrainDeployStack extends NestedStack {
}); });
// GET /dataset/{dataset_name}/data // GET /dataset/{dataset_name}/data
new ListAllDatasetItemsApi(this, 'sdExtn-listallDsItems', { new ListAllDatasetItemsApi(this, 'GetDataset', {
commonLayer: commonLayer, commonLayer: commonLayer,
datasetInfoTable: props.database.datasetInfoTable, datasetInfoTable: props.database.datasetInfoTable,
datasetItemsTable: props.database.datasetItemTable, datasetItemsTable: props.database.datasetItemTable,
multiUserTable: multiUserTable, multiUserTable: multiUserTable,
httpMethod: 'GET', httpMethod: 'GET',
router: routers.dataset, router: updateDataset.router,
s3Bucket: s3Bucket, s3Bucket: s3Bucket,
srcRoot: this.srcRoot, srcRoot: this.srcRoot,
authorizer: props.authorizer, authorizer: props.authorizer,

View File

@ -31,7 +31,7 @@ export class MultiUsersStack extends NestedStack {
constructor(scope: Construct, id: string, props: MultiUsersStackProps) { constructor(scope: Construct, id: string, props: MultiUsersStackProps) {
super(scope, id, props); super(scope, id, props);
new RoleUpsertApi(scope, 'roleUpsert', { new RoleUpsertApi(scope, 'CreateRole', {
commonLayer: props.commonLayer, commonLayer: props.commonLayer,
httpMethod: 'POST', httpMethod: 'POST',
multiUserTable: props.multiUserTable, multiUserTable: props.multiUserTable,
@ -48,12 +48,12 @@ export class MultiUsersStack extends NestedStack {
authorizer: props.authorizer, authorizer: props.authorizer,
}); });
new UserUpsertApi(scope, 'userUpsert', { new UserUpsertApi(scope, 'CreateUser', {
commonLayer: props.commonLayer, commonLayer: props.commonLayer,
httpMethod: 'POST', httpMethod: 'POST',
multiUserTable: props.multiUserTable, multiUserTable: props.multiUserTable,
passwordKey: props.passwordKeyAlias, passwordKey: props.passwordKeyAlias,
router: props.routers.user, router: props.routers.users,
srcRoot: this.srcRoot, srcRoot: this.srcRoot,
authorizer: props.authorizer, authorizer: props.authorizer,
}); });

View File

@ -11,6 +11,7 @@ import { MethodOptions } from 'aws-cdk-lib/aws-apigateway/lib/method';
import { Effect } from 'aws-cdk-lib/aws-iam'; import { Effect } from 'aws-cdk-lib/aws-iam';
import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda'; import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda';
import { Construct } from 'constructs'; import { Construct } from 'constructs';
import {JsonSchemaType, JsonSchemaVersion, Model, RequestValidator} from "aws-cdk-lib/aws-apigateway";
export interface UserUpsertApiProps { export interface UserUpsertApiProps {
router: aws_apigateway.Resource; router: aws_apigateway.Resource;
@ -95,7 +96,6 @@ export class UserUpsertApi {
private upsertUserApi() { private upsertUserApi() {
const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, <PythonFunctionProps>{ const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, <PythonFunctionProps>{
functionName: `${this.baseId}-upsert`,
entry: `${this.src}/multi_users`, entry: `${this.src}/multi_users`,
architecture: Architecture.X86_64, architecture: Architecture.X86_64,
runtime: Runtime.PYTHON_3_9, runtime: Runtime.PYTHON_3_9,
@ -110,28 +110,72 @@ export class UserUpsertApi {
}, },
layers: [this.layer], layers: [this.layer],
}); });
const requestModel = new Model(this.scope, `${this.baseId}-model`,{
restApi: this.router.api,
modelName: this.baseId,
description: `${this.baseId} Request Model`,
schema: {
schema: JsonSchemaVersion.DRAFT4,
title: this.baseId,
type: JsonSchemaType.OBJECT,
properties: {
username: {
type: JsonSchemaType.STRING,
minLength: 1,
},
password: {
type: JsonSchemaType.STRING,
minLength: 1,
},
creator: {
type: JsonSchemaType.STRING,
minLength: 1,
},
initial: {
type: JsonSchemaType.BOOLEAN,
default: false,
},
roles: {
type: JsonSchemaType.ARRAY,
items: {
type: JsonSchemaType.STRING,
minLength: 1,
},
minItems: 1,
maxItems: 20,
},
},
required: [
'username',
'creator',
],
},
contentType: 'application/json',
});
const requestValidator = new RequestValidator(
this.scope,
`${this.baseId}-validator`,
{
restApi: this.router.api,
requestValidatorName: this.baseId,
validateRequestBody: true,
});
const upsertUserIntegration = new apigw.LambdaIntegration( const upsertUserIntegration = new apigw.LambdaIntegration(
lambdaFunction, lambdaFunction,
{ {
proxy: false, proxy: true,
requestTemplates: {
'application/json': '{\n' +
' "body": $input.json("$"),' +
' "x-auth": {\n' +
' "username": "$context.authorizer.username",\n' +
' "role": "$context.authorizer.role"\n' +
' }\n' +
'}',
},
integrationResponses: [{ statusCode: '200' }],
}, },
); );
this.router.addMethod(this.httpMethod, upsertUserIntegration, <MethodOptions>{ this.router.addMethod(this.httpMethod, upsertUserIntegration, <MethodOptions>{
apiKeyRequired: true, apiKeyRequired: true,
authorizer: this.authorizer, authorizer: this.authorizer,
methodResponses: [{ requestValidator,
statusCode: '200', requestModels: {
}], 'application/json': requestModel,
}
}); });
} }
} }

View File

@ -1,3 +1,4 @@
import json
import logging import logging
import os import os
from dataclasses import dataclass from dataclasses import dataclass
@ -6,7 +7,7 @@ from typing import Any, List
from common.ddb_service.client import DynamoDbUtilsService from common.ddb_service.client import DynamoDbUtilsService
from _types import DatasetItem, DatasetInfo, DatasetStatus, DataStatus from _types import DatasetItem, DatasetInfo, DatasetStatus, DataStatus
from common.response import ok, bad_request from common.response import ok, bad_request, internal_server_error, not_found, forbidden
from common.util import get_s3_presign_urls, generate_presign_url from common.util import get_s3_presign_urls, generate_presign_url
from multi_users.utils import get_permissions_by_username, get_user_roles, check_user_permissions from multi_users.utils import get_permissions_by_username, get_user_roles, check_user_permissions
@ -47,16 +48,13 @@ class DatasetCreateEvent:
# POST /dataset # POST /dataset
def create_dataset_api(raw_event, context): def create_dataset_api(raw_event, context):
event = DatasetCreateEvent(**raw_event) event = DatasetCreateEvent(**json.loads(raw_event['body']))
try: try:
creator_permissions = get_permissions_by_username(ddb_service, user_table, event.creator) creator_permissions = get_permissions_by_username(ddb_service, user_table, event.creator)
if 'train' not in creator_permissions \ if 'train' not in creator_permissions \
or ('all' not in creator_permissions['train'] and 'create' not in creator_permissions['train']): or ('all' not in creator_permissions['train'] and 'create' not in creator_permissions['train']):
return { return bad_request(message=f'user {event.creator} has not permission to create a train job')
'statusCode': 400,
'errMsg': f'user {event.creator} has not permission to create a train job'
}
user_roles = get_user_roles(ddb_service, user_table, event.creator) user_roles = get_user_roles(ddb_service, user_table, event.creator)
timestamp = datetime.now().timestamp() timestamp = datetime.now().timestamp()
@ -94,17 +92,16 @@ def create_dataset_api(raw_event, context):
dataset_item_table: dataset, dataset_item_table: dataset,
dataset_info_table: [new_dataset_info.__dict__] dataset_info_table: [new_dataset_info.__dict__]
}) })
return {
'statusCode': 200, data = {
'datasetName': new_dataset_info.dataset_name, 'datasetName': new_dataset_info.dataset_name,
's3PresignUrl': presign_url_map 's3PresignUrl': presign_url_map
} }
return ok(data=data)
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return { return internal_server_error(message=str(e))
'statusCode': 500,
'error': str(e)
}
# GET /datasets # GET /datasets
@ -158,49 +155,27 @@ def list_datasets_api(event, context):
# GET /dataset/{dataset_name}/data # GET /dataset/{dataset_name}/data
def list_data_by_dataset(event, context): def list_data_by_dataset(event, context):
_filter = {} _filter = {}
if 'pathStringParameters' not in event:
return {
'statusCode': 500,
'error': 'path parameter /dataset/{dataset_name}/ are needed'
}
dataset_name = event['pathStringParameters']['dataset_name'] dataset_name = event['pathParameters']['id']
if not dataset_name or len(dataset_name) == 0:
return {
'statusCode': 500,
'error': 'path parameter /dataset/{dataset_name}/ are needed'
}
dataset_info_rows = ddb_service.get_item(table=dataset_info_table, key_values={ dataset_info_rows = ddb_service.get_item(table=dataset_info_table, key_values={
'dataset_name': dataset_name 'dataset_name': dataset_name
}) })
if not dataset_info_rows or len(dataset_info_rows) == 0: if not dataset_info_rows or len(dataset_info_rows) == 0:
return { return not_found(message=f'dataset {dataset_name} is not found')
'statusCode': 500,
'error': 'path parameter /dataset/{dataset_name}/ are not found'
}
dataset_info = DatasetInfo(**dataset_info_rows) dataset_info = DatasetInfo(**dataset_info_rows)
if 'x-auth' not in event or not event['x-auth']['username']: requester_name = event['requestContext']['authorizer']['username']
return { requestor_permissions = get_permissions_by_username(ddb_service, user_table, requester_name)
'statusCode': 400, requestor_roles = get_user_roles(ddb_service=ddb_service, user_table_name=user_table, username=requester_name)
'error': 'no auth user provided'
}
requestor_name = event['x-auth']['username']
requestor_permissions = get_permissions_by_username(ddb_service, user_table, requestor_name)
requestor_roles = get_user_roles(ddb_service=ddb_service, user_table_name=user_table, username=requestor_name)
if not ( if not (
(dataset_info.allowed_roles_or_users and check_user_permissions(dataset_info.allowed_roles_or_users, requestor_roles, requestor_name)) or # permission in dataset (dataset_info.allowed_roles_or_users and check_user_permissions(dataset_info.allowed_roles_or_users, requestor_roles, requester_name)) or # permission in dataset
(not dataset_info.allowed_roles_or_users and 'user' in requestor_permissions and 'all' in requestor_permissions['user']) # legacy data for super admin (not dataset_info.allowed_roles_or_users and 'user' in requestor_permissions and 'all' in requestor_permissions['user']) # legacy data for super admin
): ):
return { return forbidden(message='no permission to view dataset')
'statusCode': 400,
'error': 'no permission to view dataset'
}
rows = ddb_service.query_items(table=dataset_item_table, key_values={ rows = ddb_service.query_items(table=dataset_item_table, key_values={
'dataset_name': dataset_name 'dataset_name': dataset_name
@ -218,8 +193,7 @@ def list_data_by_dataset(event, context):
**item.params **item.params
}) })
return { return ok(data={
'statusCode': 200,
'dataset_name': dataset_name, 'dataset_name': dataset_name,
'datasetName': dataset_info.dataset_name, 'datasetName': dataset_info.dataset_name,
's3': f's3://{bucket_name}/{dataset_info.get_s3_key()}', 's3': f's3://{bucket_name}/{dataset_info.get_s3_key()}',
@ -227,27 +201,24 @@ def list_data_by_dataset(event, context):
'timestamp': dataset_info.timestamp, 'timestamp': dataset_info.timestamp,
'data': resp, 'data': resp,
**dataset_info.params **dataset_info.params
} }, decimal=True)
@dataclass @dataclass
class UpdateDatasetStatusEvent: class UpdateDatasetStatusEvent:
dataset_name: str
status: str status: str
# PUT /dataset # PUT /dataset
def update_dataset_status(raw_event, context): def update_dataset_status(raw_event, context):
event = UpdateDatasetStatusEvent(**raw_event) event = UpdateDatasetStatusEvent(**json.loads(raw_event['body']))
dataset_id = raw_event['pathParameters']['id']
try: try:
raw_dataset_info = ddb_service.get_item(table=dataset_info_table, key_values={ raw_dataset_info = ddb_service.get_item(table=dataset_info_table, key_values={
'dataset_name': event.dataset_name 'dataset_name': dataset_id
}) })
if not raw_dataset_info or len(raw_dataset_info) == 0: if not raw_dataset_info or len(raw_dataset_info) == 0:
return { return not_found(message=f'dataset {dataset_id} is not found')
'statusCode': 404,
'errorMsg': f'dataset {event.dataset_name} is not found'
}
dataset_info = DatasetInfo(**raw_dataset_info) dataset_info = DatasetInfo(**raw_dataset_info)
new_status = DatasetStatus[event.status] new_status = DatasetStatus[event.status]
@ -269,15 +240,11 @@ def update_dataset_status(raw_event, context):
ddb_service.batch_put_items(table_items={ ddb_service.batch_put_items(table_items={
dataset_item_table: updates_items dataset_item_table: updates_items
}) })
return { return ok(data={
'statusCode': 200,
'datasetName': dataset_info.dataset_name, 'datasetName': dataset_info.dataset_name,
'status': dataset_info.dataset_status.value, 'status': dataset_info.dataset_status.value,
} })
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return { return internal_server_error(message=str(e))
'statusCode': 500,
'error': str(e)
}

View File

@ -147,7 +147,7 @@ class CreateEndpointEvent:
def sagemaker_endpoint_create_api(raw_event, ctx): def sagemaker_endpoint_create_api(raw_event, ctx):
logger.info(f"Received event: {raw_event}") logger.info(f"Received event: {raw_event}")
logger.info(f"Received ctx: {ctx}") logger.info(f"Received ctx: {ctx}")
event = CreateEndpointEvent(**raw_event) event = CreateEndpointEvent(**json.loads(raw_event['body']))
try: try:
endpoint_deployment_id = str(uuid.uuid4()) endpoint_deployment_id = str(uuid.uuid4())
@ -172,10 +172,7 @@ def sagemaker_endpoint_create_api(raw_event, ctx):
creator_permissions = get_permissions_by_username(ddb_service, user_table, event.creator) creator_permissions = get_permissions_by_username(ddb_service, user_table, event.creator)
if 'sagemaker_endpoint' not in creator_permissions or \ if 'sagemaker_endpoint' not in creator_permissions or \
('all' not in creator_permissions['sagemaker_endpoint'] and 'create' not in creator_permissions['sagemaker_endpoint']): ('all' not in creator_permissions['sagemaker_endpoint'] and 'create' not in creator_permissions['sagemaker_endpoint']):
return { return bad_request(message=f"Creator {event.creator} has no permission to create Sagemaker")
'statusCode': 400,
'message': f"Creator {event.creator} has no permission to create Sagemaker",
}
endpoint_rows = ddb_service.scan(sagemaker_endpoint_table, filters=None) endpoint_rows = ddb_service.scan(sagemaker_endpoint_table, filters=None)
for endpoint_row in endpoint_rows: for endpoint_row in endpoint_rows:
@ -184,10 +181,8 @@ def sagemaker_endpoint_create_api(raw_event, ctx):
if endpoint.endpoint_status != EndpointStatus.DELETED.value and endpoint.status != 'deleted': if endpoint.endpoint_status != EndpointStatus.DELETED.value and endpoint.status != 'deleted':
for role in event.assign_to_roles: for role in event.assign_to_roles:
if role in endpoint.owner_group_or_role: if role in endpoint.owner_group_or_role:
return { return bad_request(
'statusCode': 400, message=f"role [{role}] has a valid endpoint already, not allow to have another one")
'message': f"role [{role}] has a valid endpoint already, not allow to have another one",
}
_create_sagemaker_model(sagemaker_model_name, image_url, model_data_url) _create_sagemaker_model(sagemaker_model_name, image_url, model_data_url)
@ -216,17 +211,13 @@ def sagemaker_endpoint_create_api(raw_event, ctx):
ddb_service.put_items(table=sagemaker_endpoint_table, entries=raw.__dict__) ddb_service.put_items(table=sagemaker_endpoint_table, entries=raw.__dict__)
logger.info(f"Successfully created endpoint deployment: {raw.__dict__}") logger.info(f"Successfully created endpoint deployment: {raw.__dict__}")
return { return ok(
'statusCode': 200, message=f"Endpoint deployment started: {sagemaker_endpoint_name}",
'message': f"Endpoint deployment started: {sagemaker_endpoint_name}", data=raw.__dict__
'data': raw.__dict__ )
}
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return { return ok(message=str(e))
'statusCode': 200,
'message': str(e),
}
# lambda: handle sagemaker events # lambda: handle sagemaker events

View File

@ -10,7 +10,7 @@ import json
from _types import CheckPoint, CheckPointStatus, MultipartFileReq from _types import CheckPoint, CheckPointStatus, MultipartFileReq
from common.ddb_service.client import DynamoDbUtilsService from common.ddb_service.client import DynamoDbUtilsService
from common.response import ok from common.response import ok, bad_request, internal_server_error
from common_tools import get_base_checkpoint_s3_key, \ from common_tools import get_base_checkpoint_s3_key, \
batch_get_s3_multipart_signed_urls, complete_multipart_upload, multipart_upload_from_url batch_get_s3_multipart_signed_urls, complete_multipart_upload, multipart_upload_from_url
from multi_users._types import PARTITION_KEYS, Role from multi_users._types import PARTITION_KEYS, Role
@ -226,7 +226,7 @@ class CreateCheckPointEvent:
# POST /checkpoint # POST /checkpoint
def create_checkpoint_api(raw_event, context): def create_checkpoint_api(raw_event, context):
request_id = context.aws_request_id request_id = context.aws_request_id
event = CreateCheckPointEvent(**raw_event) event = CreateCheckPointEvent(**json.loads(raw_event['body']))
_type = event.checkpoint_type _type = event.checkpoint_type
headers = { headers = {
'Access-Control-Allow-Headers': 'Content-Type', 'Access-Control-Allow-Headers': 'Content-Type',
@ -262,11 +262,7 @@ def create_checkpoint_api(raw_event, context):
filenames_only.append(file.filename) filenames_only.append(file.filename)
if len(filenames_only) == 0: if len(filenames_only) == 0:
return { return bad_request(message='no checkpoint name (file names) detected', headers=headers)
'statusCode': 400,
'headers': headers,
'errorMsg': 'no checkpoint name (file names) detected'
}
user_roles = ['*'] user_roles = ['*']
creator_permissions = {} creator_permissions = {}
@ -276,11 +272,7 @@ def create_checkpoint_api(raw_event, context):
if 'checkpoint' not in creator_permissions or \ if 'checkpoint' not in creator_permissions or \
('all' not in creator_permissions['checkpoint'] and 'create' not in creator_permissions['checkpoint']): ('all' not in creator_permissions['checkpoint'] and 'create' not in creator_permissions['checkpoint']):
return { return bad_request(message='user has no permissions to create a model', headers=headers)
'statusCode': 400,
'headers': headers,
'error': f"user has no permissions to create a model"
}
checkpoint = CheckPoint( checkpoint = CheckPoint(
id=request_id, id=request_id,
@ -293,9 +285,7 @@ def create_checkpoint_api(raw_event, context):
allowed_roles_or_users=user_roles allowed_roles_or_users=user_roles
) )
ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__) ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)
return { data = {
'statusCode': 200,
'headers': headers,
'checkpoint': { 'checkpoint': {
'id': request_id, 'id': request_id,
'type': _type, 'type': _type,
@ -305,25 +295,22 @@ def create_checkpoint_api(raw_event, context):
}, },
's3PresignUrl': multiparts_resp 's3PresignUrl': multiparts_resp
} }
return ok(data=data, headers=headers)
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return { return internal_server_error(headers=headers, message=str(e))
'statusCode': 500,
'headers': headers,
'error': str(e)
}
@dataclass @dataclass
class UpdateCheckPointEvent: class UpdateCheckPointEvent:
checkpoint_id: str
status: str status: str
multi_parts_tags: Dict[str, Any] multi_parts_tags: Dict[str, Any]
# PUT /checkpoint # PUT /checkpoint
def update_checkpoint_api(raw_event, context): def update_checkpoint_api(raw_event, context):
event = UpdateCheckPointEvent(**raw_event) event = UpdateCheckPointEvent(**json.loads(raw_event['body']))
checkpoint_id = raw_event['pathParameters']['id']
headers = { headers = {
'Access-Control-Allow-Headers': 'Content-Type', 'Access-Control-Allow-Headers': 'Content-Type',
'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Origin': '*',
@ -331,14 +318,13 @@ def update_checkpoint_api(raw_event, context):
} }
try: try:
raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={ raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={
'id': event.checkpoint_id 'id': checkpoint_id
}) })
if raw_checkpoint is None or len(raw_checkpoint) == 0: if raw_checkpoint is None or len(raw_checkpoint) == 0:
return { return bad_request(
'statusCode': 500, message=f'checkpoint not found with id {checkpoint_id}',
'headers': headers, headers=headers
'error': f'checkpoint not found with id {event.checkpoint_id}' )
}
checkpoint = CheckPoint(**raw_checkpoint) checkpoint = CheckPoint(**raw_checkpoint)
new_status = CheckPointStatus[event.status] new_status = CheckPointStatus[event.status]
@ -352,9 +338,7 @@ def update_checkpoint_api(raw_event, context):
field_name='checkpoint_status', field_name='checkpoint_status',
value=new_status value=new_status
) )
return { data = {
'statusCode': 200,
'headers': headers,
'checkpoint': { 'checkpoint': {
'id': checkpoint.id, 'id': checkpoint.id,
'type': checkpoint.checkpoint_type, 'type': checkpoint.checkpoint_type,
@ -363,10 +347,7 @@ def update_checkpoint_api(raw_event, context):
'params': checkpoint.params 'params': checkpoint.params
} }
} }
return ok(data=data, headers=headers)
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
return { return internal_server_error(headers=headers, message=str(e))
'statusCode': 500,
'headers': headers,
'msg': str(e)
}

View File

@ -6,7 +6,7 @@ from typing import List, Optional
from common.ddb_service.client import DynamoDbUtilsService from common.ddb_service.client import DynamoDbUtilsService
from _types import User, PARTITION_KEYS, Role, Default_Role from _types import User, PARTITION_KEYS, Role, Default_Role
from common.response import ok from common.response import ok, bad_request
from roles_api import upsert_role from roles_api import upsert_role
from utils import KeyEncryptService, check_user_existence, get_permissions_by_username, get_user_by_username from utils import KeyEncryptService, check_user_existence, get_permissions_by_username, get_user_by_username
@ -31,8 +31,8 @@ class UpsertUserEvent:
# POST /user # POST /user
def upsert_user(raw_event, ctx): def upsert_user(raw_event, ctx):
print(raw_event) logger.info(json.dumps(raw_event))
event = UpsertUserEvent(**raw_event['body']) event = UpsertUserEvent(**json.loads(raw_event['body']))
if event.initial: if event.initial:
rolenames = [Default_Role] rolenames = [Default_Role]
@ -63,13 +63,16 @@ def upsert_user(raw_event, ctx):
aws_request_id: str aws_request_id: str
from_sd_local: bool from_sd_local: bool
resp = upsert_role(role_event, MockContext(aws_request_id='', from_sd_local=True)) # todo will be remove, not use api
create_role_event = {
'body': json.dumps(role_event)
}
resp = upsert_role(create_role_event, MockContext(aws_request_id='', from_sd_local=True))
if resp['statusCode'] != 200: if resp['statusCode'] != 200:
return resp return resp
return { data = {
'statusCode': 200,
'user': { 'user': {
'username': event.username, 'username': event.username,
'roles': [rolenames[0]] 'roles': [rolenames[0]]
@ -77,6 +80,8 @@ def upsert_user(raw_event, ctx):
'all_roles': rolenames, 'all_roles': rolenames,
} }
return ok(data=data)
check_permission_resp = _check_action_permission(event.creator, event.username) check_permission_resp = _check_action_permission(event.creator, event.username)
if check_permission_resp: if check_permission_resp:
return check_permission_resp return check_permission_resp
@ -98,19 +103,13 @@ def upsert_user(raw_event, ctx):
resource = permission_parts[0] resource = permission_parts[0]
action = permission_parts[1] action = permission_parts[1]
if 'all' not in creator_permissions[resource] and action not in creator_permissions[resource]: if 'all' not in creator_permissions[resource] and action not in creator_permissions[resource]:
return { return bad_request(message=f'creator has no permission to assign permission [{permission}] to others')
'statusCode': 400,
'errMsg': f'creator has no permission to assign permission [{permission}] to others'
}
roles_pool.append(role.sort_key) roles_pool.append(role.sort_key)
for role in event.roles: for role in event.roles:
if role not in roles_pool: if role not in roles_pool:
return { return bad_request(message=f'user roles "{role}" not exist')
'statusCode': 400,
'errMsg': f'user roles "{role}" not exist'
}
ddb_service.put_items(user_table, User( ddb_service.put_items(user_table, User(
kind=PARTITION_KEYS.user, kind=PARTITION_KEYS.user,
@ -120,8 +119,7 @@ def upsert_user(raw_event, ctx):
creator=event.creator, creator=event.creator,
).__dict__) ).__dict__)
return { data = {
'statusCode': 200,
'user': { 'user': {
'username': event.username, 'username': event.username,
'roles': event.roles, 'roles': event.roles,
@ -129,6 +127,8 @@ def upsert_user(raw_event, ctx):
} }
} }
return ok(data=data)
# DELETE /user/{username} # DELETE /user/{username}
def delete_user(event, ctx): def delete_user(event, ctx):
@ -155,10 +155,7 @@ def delete_user(event, ctx):
def _check_action_permission(creator_username, target_username): def _check_action_permission(creator_username, target_username):
# check if creator exist # check if creator exist
if check_user_existence(ddb_service=ddb_service, user_table=user_table, username=creator_username): if check_user_existence(ddb_service=ddb_service, user_table=user_table, username=creator_username):
return { return bad_request(message=f'creator {creator_username} not exist')
'statusCode': 400,
'errMsg': f'creator {creator_username} not exist'
}
target_user = get_user_by_username(ddb_service, user_table, target_username) target_user = get_user_by_username(ddb_service, user_table, target_username)
@ -166,28 +163,20 @@ def _check_action_permission(creator_username, target_username):
if 'user' not in creator_permissions or \ if 'user' not in creator_permissions or \
('all' not in creator_permissions['user'] and 'create' not in creator_permissions['user']): ('all' not in creator_permissions['user'] and 'create' not in creator_permissions['user']):
return { return bad_request(message=f'creator {creator_username} does not have permission to manage the user')
'statusCode': 400,
'errMsg': f'creator {creator_username} does not have permission to manage the user'
}
# if the creator have no permission (not created by creator), # if the creator have no permission (not created by creator),
# make sure the creator doesn't change the existed user (created by others) # make sure the creator doesn't change the existed user (created by others)
# and only user with 'user:all' can do update any users # and only user with 'user:all' can do update any users
if target_user and target_user.creator != creator_username and 'all' not in creator_permissions['user']: if target_user and target_user.creator != creator_username and 'all' not in creator_permissions['user']:
return { return bad_request(message=f'username {target_user.sort_key} has already exists, '
'statusCode': 400, f'creator {creator_username} does not have permissions to change it')
'errMsg': f'username {target_user.sort_key} has already exists, '
f'creator {creator_username} does not have permissions to change it'
}
if target_user and target_user.creator == creator_username and 'create' not in creator_permissions['user'] and 'all' \ if target_user and target_user.creator == creator_username and 'create' not in creator_permissions['user'] and 'all' \
not in creator_permissions['user']: not in creator_permissions['user']:
return { return bad_request(
'statusCode': 400, message=f'username {target_user.sort_key} has already exists, '
'errMsg': f'username {target_user.sort_key} has already exists, ' f'creator {creator_username} does not have permissions to change it')
f'creator {creator_username} does not have permissions to change it'
}
return None return None