286 lines
8.5 KiB
TypeScript
286 lines
8.5 KiB
TypeScript
import { PythonFunction, PythonFunctionProps } from '@aws-cdk/aws-lambda-python-alpha';
|
|
import { Aws, Duration } from 'aws-cdk-lib';
|
|
import {
|
|
JsonSchemaType,
|
|
JsonSchemaVersion,
|
|
LambdaIntegration,
|
|
Model,
|
|
Resource,
|
|
IAuthorizer,
|
|
RequestValidator,
|
|
} from 'aws-cdk-lib/aws-apigateway';
|
|
import { MethodOptions } from 'aws-cdk-lib/aws-apigateway/lib/method';
|
|
import { Table } from 'aws-cdk-lib/aws-dynamodb';
|
|
import { Effect, PolicyStatement, Role } from 'aws-cdk-lib/aws-iam';
|
|
import { Architecture, LayerVersion, Runtime } from 'aws-cdk-lib/aws-lambda';
|
|
import { Bucket } from 'aws-cdk-lib/aws-s3';
|
|
import { Topic } from 'aws-cdk-lib/aws-sns';
|
|
import { Construct } from 'constructs';
|
|
import { LAMBDA_START_DEPLOY_ROLE_NAME } from '../shared/deploy-role';
|
|
|
|
|
|
/**
 * Construction properties for `CreateSagemakerEndpointsApi`.
 *
 * Fields are grouped by the part of the stack they feed: API Gateway wiring,
 * DynamoDB tables, Lambda packaging, and storage/notification resources.
 */
export interface CreateSagemakerEndpointsApiProps {
  // ---- API Gateway wiring -------------------------------------------------

  /** API resource on which the create-endpoint method is registered. */
  router: Resource;

  /** HTTP verb handed to `router.addMethod` (e.g. 'POST'). */
  httpMethod: string;

  /** Authorizer attached to the registered method. */
  authorizer: IAuthorizer;

  // ---- DynamoDB tables (the deploy role is granted read/write on each) ----

  /** Endpoint deployment table; its name is exposed to the handler. */
  endpointDeploymentTable: Table;

  /** Multi-user table; its name is exposed to the handler. */
  multiUserTable: Table;

  /** Inference job table; only an IAM grant is created for it here. */
  inferenceJobTable: Table;

  // ---- Lambda packaging ---------------------------------------------------

  /** Source root; the handler is bundled from `${srcRoot}/inference_v2`. */
  srcRoot: string;

  /** Shared Lambda layer attached to the handler. */
  commonLayer: LayerVersion;

  /** Inference container image URL, exposed to the handler as an env var. */
  inferenceECRUrl: string;

  // ---- Storage and notifications -----------------------------------------

  /** Bucket the deploy role may read/write; its name is exposed to the handler. */
  s3Bucket: Bucket;

  /** User notification topic; the deploy role is granted publish on it. */
  userNotifySNS: Topic;

  /** Topic whose ARN is exposed to the handler as the inference success target. */
  inferenceResultTopic: Topic;

  /** Topic whose ARN is exposed to the handler as the inference error target. */
  inferenceResultErrorTopic: Topic;
}
|
|
|
|
/**
 * Wires up the "create SageMaker endpoint" REST method.
 *
 * Construction:
 *  1. looks up the pre-existing role named `LAMBDA_START_DEPLOY_ROLE_NAME`
 *     and attaches the IAM statements the handler needs to it;
 *  2. creates a Python Lambda (`inference_v2/sagemaker_endpoint_api.py`,
 *     handler `sagemaker_endpoint_create_api`) that runs as that role;
 *  3. registers a JSON request model plus a body validator on the REST API;
 *  4. adds `httpMethod` on `router`, requiring an API key and `authorizer`.
 */
export class CreateSagemakerEndpointsApi {
  private readonly src: string;
  private readonly router: Resource;
  private readonly httpMethod: string;
  private readonly scope: Construct;
  private readonly endpointDeploymentTable: Table;
  private readonly multiUserTable: Table;
  private readonly inferenceJobTable: Table;
  private readonly layer: LayerVersion;
  private readonly baseId: string;
  private readonly inferenceECRUrl: string;
  private readonly authorizer: IAuthorizer;
  private readonly s3Bucket: Bucket;
  private readonly userNotifySNS: Topic;
  private readonly inferenceResultTopic: Topic;
  private readonly inferenceResultErrorTopic: Topic;

  /**
   * @param scope construct under which all child constructs are created
   * @param id    prefix for child construct ids; also used as model, validator
   *              and schema title name
   * @param props resources the endpoint-creation API depends on
   */
  constructor(scope: Construct, id: string, props: CreateSagemakerEndpointsApiProps) {
    this.scope = scope;
    this.baseId = id;
    this.router = props.router;
    this.httpMethod = props.httpMethod;
    this.endpointDeploymentTable = props.endpointDeploymentTable;
    this.inferenceJobTable = props.inferenceJobTable;
    this.multiUserTable = props.multiUserTable;
    this.authorizer = props.authorizer;
    this.src = props.srcRoot;
    this.layer = props.commonLayer;
    this.s3Bucket = props.s3Bucket;
    this.inferenceECRUrl = props.inferenceECRUrl;
    this.userNotifySNS = props.userNotifySNS;
    this.inferenceResultTopic = props.inferenceResultTopic;
    this.inferenceResultErrorTopic = props.inferenceResultErrorTopic;

    this.createEndpointsApi();
  }

  /**
   * Imports the shared deploy role by name and attaches every policy statement
   * the create-endpoint handler needs.
   *
   * @returns the imported role (typed as `IRole`, which is what
   *   `Role.fromRoleName` really returns); it is used both as the Lambda's
   *   execution role and as the value of `EXECUTION_ROLE_ARN`.
   */
  private iamRole() {
    const lambdaStartDeployRole = Role.fromRoleName(
      this.scope,
      'createSagemakerEpRole',
      LAMBDA_START_DEPLOY_ROLE_NAME,
    );

    const snsStatement = new PolicyStatement({
      actions: [
        'sns:Publish',
        'sns:ListSubscriptionsByTopic',
        'sns:ListTopics',
      ],
      resources: [
        this.userNotifySNS.topicArn,
        this.inferenceResultTopic.topicArn,
        this.inferenceResultErrorTopic.topicArn,
      ],
    });

    const s3Statement = new PolicyStatement({
      // 's3:Get*' already covers 's3:GetObject'.
      actions: [
        's3:Get*',
        's3:List*',
        's3:PutObject',
      ],
      resources: [
        this.s3Bucket.bucketArn,
        `${this.s3Bucket.bucketArn}/*`,
        // Use the stack partition (as the log statement below does) instead of
        // hard-coding 'aws', so the policy is valid in aws-cn / aws-us-gov too.
        `arn:${Aws.PARTITION}:s3:::*sagemaker*`,
      ],
    });

    // NOTE(review): this statement grants ECR *push* actions and
    // iam:PassRole / sts:AssumeRole on '*'. That looks broader than a
    // create-endpoint handler should need — confirm and scope it down.
    const endpointStatement = new PolicyStatement({
      actions: [
        'sagemaker:InvokeEndpoint',
        'sagemaker:CreateModel',
        'sagemaker:CreateEndpoint',
        'sagemaker:CreateEndpointConfig',
        'sagemaker:DescribeEndpoint',
        'sagemaker:InvokeEndpointAsync',
        'ecr:GetAuthorizationToken',
        'ecr:BatchCheckLayerAvailability',
        'ecr:GetDownloadUrlForLayer',
        'ecr:GetRepositoryPolicy',
        'ecr:DescribeRepositories',
        'ecr:ListImages',
        'ecr:DescribeImages',
        'ecr:BatchGetImage',
        'ecr:InitiateLayerUpload',
        'ecr:UploadLayerPart',
        'ecr:CompleteLayerUpload',
        'ecr:PutImage',
        'cloudwatch:PutMetricAlarm',
        'cloudwatch:PutMetricData',
        'sagemaker:DescribeEndpointConfig',
        'cloudwatch:DeleteAlarms',
        'cloudwatch:DescribeAlarms',
        'sagemaker:UpdateEndpointWeightsAndCapacities',
        'iam:CreateServiceLinkedRole',
        'iam:PassRole',
        'sts:AssumeRole',
      ],
      resources: ['*'],
    });

    // NOTE(review): only the table ARNs are granted. If the handler queries a
    // GSI/LSI, `${tableArn}/index/*` must be added as well.
    const ddbStatement = new PolicyStatement({
      actions: [
        'dynamodb:Query',
        'dynamodb:GetItem',
        'dynamodb:PutItem',
        'dynamodb:DeleteItem',
        'dynamodb:UpdateItem',
        'dynamodb:Describe*',
        'dynamodb:List*',
        'dynamodb:Scan',
      ],
      resources: [
        this.endpointDeploymentTable.tableArn,
        this.multiUserTable.tableArn,
        this.inferenceJobTable.tableArn,
      ],
    });

    const logStatement = new PolicyStatement({
      effect: Effect.ALLOW,
      actions: [
        'logs:CreateLogGroup',
        'logs:CreateLogStream',
        'logs:PutLogEvents',
      ],
      resources: [`arn:${Aws.PARTITION}:logs:${Aws.REGION}:${Aws.ACCOUNT_ID}:log-group:*:*`],
    });

    // The handler receives this role's ARN as EXECUTION_ROLE_ARN and must be
    // able to pass it on. Derive the ARN from the looked-up role rather than a
    // hard-coded role name so it can never drift from LAMBDA_START_DEPLOY_ROLE_NAME.
    const passStartDeployRole = new PolicyStatement({
      actions: ['iam:PassRole'],
      resources: [lambdaStartDeployRole.roleArn],
    });

    for (const statement of [
      snsStatement,
      s3Statement,
      endpointStatement,
      ddbStatement,
      logStatement,
      passStartDeployRole,
    ]) {
      lambdaStartDeployRole.addToPrincipalPolicy(statement);
    }

    return lambdaStartDeployRole;
  }

  /**
   * Creates the handler Lambda, the request model and validator, and registers
   * the method on `router`.
   */
  private createEndpointsApi() {
    const role = this.iamRole();

    // Declared (not asserted) so the compiler checks the literal against
    // PythonFunctionProps, including excess/misspelled properties.
    const functionProps: PythonFunctionProps = {
      entry: `${this.src}/inference_v2`,
      architecture: Architecture.X86_64,
      runtime: Runtime.PYTHON_3_9,
      index: 'sagemaker_endpoint_api.py',
      handler: 'sagemaker_endpoint_create_api',
      // NOTE(review): API Gateway's default integration timeout is far below
      // 900 s, so synchronous callers will see a gateway timeout long before
      // the Lambda limit is reached — confirm this is intended.
      timeout: Duration.seconds(900),
      role: role,
      memorySize: 1024,
      environment: {
        DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME: this.endpointDeploymentTable.tableName,
        MULTI_USER_TABLE: this.multiUserTable.tableName,
        S3_BUCKET_NAME: this.s3Bucket.bucketName,
        INFERENCE_ECR_IMAGE_URL: this.inferenceECRUrl,
        SNS_INFERENCE_SUCCESS: this.inferenceResultTopic.topicArn,
        SNS_INFERENCE_ERROR: this.inferenceResultErrorTopic.topicArn,
        EXECUTION_ROLE_ARN: role.roleArn,
      },
      layers: [this.layer],
    };
    const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, functionProps);

    // Request body schema. `endpoint_name` is optional (not in `required`) and
    // may be empty (minLength 0); `initial_instance_count` is a string.
    const model = new Model(this.scope, `${this.baseId}-model`, {
      restApi: this.router.api,
      contentType: 'application/json',
      modelName: this.baseId,
      description: `${this.baseId} Request Model`,
      schema: {
        schema: JsonSchemaVersion.DRAFT4,
        title: this.baseId,
        type: JsonSchemaType.OBJECT,
        properties: {
          endpoint_name: {
            type: JsonSchemaType.STRING,
            minLength: 0,
            maxLength: 20,
          },
          instance_type: {
            type: JsonSchemaType.STRING,
          },
          initial_instance_count: {
            type: JsonSchemaType.STRING,
          },
          autoscaling_enabled: {
            type: JsonSchemaType.BOOLEAN,
          },
          assign_to_roles: {
            type: JsonSchemaType.ARRAY,
            items: {
              type: JsonSchemaType.STRING,
            },
            minItems: 1,
            maxItems: 10,
          },
          creator: {
            type: JsonSchemaType.STRING,
          },
        },
        required: [
          'instance_type',
          'initial_instance_count',
          'autoscaling_enabled',
          'assign_to_roles',
          'creator',
        ],
      },
    });

    const integration = new LambdaIntegration(lambdaFunction, { proxy: true });

    // Only the body is validated (against `model`); parameters are not.
    const requestValidator = new RequestValidator(this.scope, `${this.baseId}-validator`, {
      restApi: this.router.api,
      requestValidatorName: this.baseId,
      validateRequestBody: true,
      validateRequestParameters: false,
    });

    const methodOptions: MethodOptions = {
      apiKeyRequired: true,
      authorizer: this.authorizer,
      requestValidator,
      requestModels: {
        'application/json': model,
      },
    };
    this.router.addMethod(this.httpMethod, integration, methodOptions);
  }
}
|