// Source: stable-diffusion-aws-extension/infrastructure/src/sd-inference/sagemaker-endpoints-create.ts
// (file-viewer metadata: 286 lines, 8.5 KiB, TypeScript)

import { PythonFunction, PythonFunctionProps } from '@aws-cdk/aws-lambda-python-alpha';
import { Aws, Duration } from 'aws-cdk-lib';
import {
JsonSchemaType,
JsonSchemaVersion,
LambdaIntegration,
Model,
Resource,
IAuthorizer,
RequestValidator,
} from 'aws-cdk-lib/aws-apigateway';
import { MethodOptions } from 'aws-cdk-lib/aws-apigateway/lib/method';
import { Table } from 'aws-cdk-lib/aws-dynamodb';
import { Effect, PolicyStatement, Role } from 'aws-cdk-lib/aws-iam';
import { Architecture, LayerVersion, Runtime } from 'aws-cdk-lib/aws-lambda';
import { Bucket } from 'aws-cdk-lib/aws-s3';
import { Topic } from 'aws-cdk-lib/aws-sns';
import { Construct } from 'constructs';
import { LAMBDA_START_DEPLOY_ROLE_NAME } from '../shared/deploy-role';
/**
 * Inputs required to wire up the "create SageMaker endpoint" API method.
 */
export interface CreateSagemakerEndpointsApiProps {
  // --- API Gateway wiring ---------------------------------------------------

  /** API Gateway resource the method is attached to. */
  router: Resource;
  /** HTTP verb registered on `router` for this API. */
  httpMethod: string;
  /** Authorizer attached to the method. */
  authorizer: IAuthorizer;

  // --- Lambda packaging -----------------------------------------------------

  /** Root directory of the Lambda sources; the handler is loaded from `<srcRoot>/inference_v2`. */
  srcRoot: string;
  /** Layer attached to the handler Lambda. */
  commonLayer: LayerVersion;
  /** Inference image URL, exposed to the Lambda as `INFERENCE_ECR_IMAGE_URL`. */
  inferenceECRUrl: string;

  // --- Data stores ----------------------------------------------------------

  /** Table whose name is exposed as `DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME`; read/write access is granted. */
  endpointDeploymentTable: Table;
  /** Table whose name is exposed as `MULTI_USER_TABLE`; read/write access is granted. */
  multiUserTable: Table;
  /** Inference job table; read/write access is granted to the Lambda role. */
  inferenceJobTable: Table;
  /** Bucket whose name is exposed as `S3_BUCKET_NAME`; read/put access is granted. */
  s3Bucket: Bucket;

  // --- Notification topics --------------------------------------------------

  /** Topic the Lambda role is allowed to publish to. */
  userNotifySNS: Topic;
  /** Topic whose ARN is exposed as `SNS_INFERENCE_SUCCESS`. */
  inferenceResultTopic: Topic;
  /** Topic whose ARN is exposed as `SNS_INFERENCE_ERROR`. */
  inferenceResultErrorTopic: Topic;
}
/**
 * Registers the "create SageMaker endpoint" method on an existing API Gateway
 * resource.
 *
 * Instantiating this class creates, inside the supplied `scope`:
 *  - a Python Lambda (`sagemaker_endpoint_api.sagemaker_endpoint_create_api`),
 *  - a request body model and a request validator,
 *  - the API method itself, protected by an API key and the given authorizer.
 *
 * The class is not itself a CDK construct; `id` is only used as a prefix for
 * the ids/names of the resources it creates.
 */
export class CreateSagemakerEndpointsApi {
  private readonly src: string;
  private readonly router: Resource;
  private readonly httpMethod: string;
  private readonly scope: Construct;
  private readonly endpointDeploymentTable: Table;
  private readonly multiUserTable: Table;
  private readonly inferenceJobTable: Table;
  private readonly layer: LayerVersion;
  private readonly baseId: string;
  private readonly inferenceECRUrl: string;
  private readonly authorizer: IAuthorizer;
  private readonly s3Bucket: Bucket;
  private readonly userNotifySNS: Topic;
  private readonly inferenceResultTopic: Topic;
  private readonly inferenceResultErrorTopic: Topic;

  /**
   * @param scope construct under which all resources are created
   * @param id prefix for the ids and names of the created resources
   * @param props tables, topics, bucket and API wiring used by the endpoint
   */
  constructor(scope: Construct, id: string, props: CreateSagemakerEndpointsApiProps) {
    this.scope = scope;
    this.baseId = id;
    this.router = props.router;
    this.httpMethod = props.httpMethod;
    this.endpointDeploymentTable = props.endpointDeploymentTable;
    this.inferenceJobTable = props.inferenceJobTable;
    this.multiUserTable = props.multiUserTable;
    this.authorizer = props.authorizer;
    this.src = props.srcRoot;
    this.layer = props.commonLayer;
    this.s3Bucket = props.s3Bucket;
    this.inferenceECRUrl = props.inferenceECRUrl;
    this.userNotifySNS = props.userNotifySNS;
    this.inferenceResultTopic = props.inferenceResultTopic;
    this.inferenceResultErrorTopic = props.inferenceResultErrorTopic;

    this.createEndpointsApi();
  }

  /**
   * Imports the pre-existing deploy role by name and attaches the policy
   * statements the handler needs (SNS, S3, SageMaker/ECR/CloudWatch, DynamoDB,
   * CloudWatch Logs and PassRole on itself).
   *
   * @returns the imported role, with the statements added
   */
  private iamRole(): Role {
    const snsStatement = new PolicyStatement({
      actions: [
        'sns:Publish',
        'sns:ListSubscriptionsByTopic',
        'sns:ListTopics',
      ],
      resources: [
        this.userNotifySNS.topicArn,
        this.inferenceResultTopic.topicArn,
        this.inferenceResultErrorTopic.topicArn,
      ],
    });

    // NOTE(review): 's3:GetObject' is already covered by 's3:Get*'; kept to
    // leave the synthesized policy document unchanged.
    const s3Statement = new PolicyStatement({
      actions: [
        's3:Get*',
        's3:List*',
        's3:PutObject',
        's3:GetObject',
      ],
      resources: [
        this.s3Bucket.bucketArn,
        `${this.s3Bucket.bucketArn}/*`,
        'arn:aws:s3:::*sagemaker*',
      ],
    });

    // NOTE(review): this statement applies to every resource ('*'), including
    // 'iam:PassRole' and 'sts:AssumeRole'. Consider scoping it down.
    const endpointStatement = new PolicyStatement({
      actions: [
        'sagemaker:InvokeEndpoint',
        'sagemaker:CreateModel',
        'sagemaker:CreateEndpoint',
        'sagemaker:CreateEndpointConfig',
        'sagemaker:DescribeEndpoint',
        'sagemaker:InvokeEndpointAsync',
        'ecr:GetAuthorizationToken',
        'ecr:BatchCheckLayerAvailability',
        'ecr:GetDownloadUrlForLayer',
        'ecr:GetRepositoryPolicy',
        'ecr:DescribeRepositories',
        'ecr:ListImages',
        'ecr:DescribeImages',
        'ecr:BatchGetImage',
        'ecr:InitiateLayerUpload',
        'ecr:UploadLayerPart',
        'ecr:CompleteLayerUpload',
        'ecr:PutImage',
        'cloudwatch:PutMetricAlarm',
        'cloudwatch:PutMetricData',
        'sagemaker:DescribeEndpointConfig',
        'cloudwatch:DeleteAlarms',
        'cloudwatch:DescribeAlarms',
        'sagemaker:UpdateEndpointWeightsAndCapacities',
        'iam:CreateServiceLinkedRole',
        'iam:PassRole',
        'sts:AssumeRole',
      ],
      resources: ['*'],
    });

    const ddbStatement = new PolicyStatement({
      actions: [
        'dynamodb:Query',
        'dynamodb:GetItem',
        'dynamodb:PutItem',
        'dynamodb:DeleteItem',
        'dynamodb:UpdateItem',
        'dynamodb:Describe*',
        'dynamodb:List*',
        'dynamodb:Scan',
      ],
      resources: [
        this.endpointDeploymentTable.tableArn,
        this.multiUserTable.tableArn,
        this.inferenceJobTable.tableArn,
      ],
    });

    // The role is looked up by name rather than created here; the cast exposes
    // the `Role` API on the imported reference.
    const lambdaStartDeployRole = <Role>Role.fromRoleName(
      this.scope,
      'createSagemakerEpRole',
      LAMBDA_START_DEPLOY_ROLE_NAME,
    );

    const logStatement = new PolicyStatement({
      effect: Effect.ALLOW,
      actions: [
        'logs:CreateLogGroup',
        'logs:CreateLogStream',
        'logs:PutLogEvents',
      ],
      resources: [`arn:${Aws.PARTITION}:logs:${Aws.REGION}:${Aws.ACCOUNT_ID}:log-group:*:*`],
    });

    // Allow the role to pass itself. The ARN is taken from the imported role
    // instead of a hard-coded role name so that it always matches
    // LAMBDA_START_DEPLOY_ROLE_NAME and the EXECUTION_ROLE_ARN given to the Lambda.
    const passStartDeployRole = new PolicyStatement({
      actions: [
        'iam:PassRole',
      ],
      resources: [lambdaStartDeployRole.roleArn],
    });

    lambdaStartDeployRole.addToPolicy(snsStatement);
    lambdaStartDeployRole.addToPolicy(s3Statement);
    lambdaStartDeployRole.addToPolicy(endpointStatement);
    lambdaStartDeployRole.addToPolicy(ddbStatement);
    lambdaStartDeployRole.addToPolicy(logStatement);
    lambdaStartDeployRole.addToPolicy(passStartDeployRole);

    return lambdaStartDeployRole;
  }

  /**
   * Creates the handler Lambda, the request model and validator, and adds the
   * method to the router.
   */
  private createEndpointsApi() {
    const role = this.iamRole();

    const lambdaFunction = new PythonFunction(this.scope, `${this.baseId}-lambda`, <PythonFunctionProps>{
      entry: `${this.src}/inference_v2`,
      architecture: Architecture.X86_64,
      runtime: Runtime.PYTHON_3_9,
      index: 'sagemaker_endpoint_api.py',
      handler: 'sagemaker_endpoint_create_api',
      timeout: Duration.seconds(900),
      role: role,
      memorySize: 1024,
      environment: {
        DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME: this.endpointDeploymentTable.tableName,
        MULTI_USER_TABLE: this.multiUserTable.tableName,
        S3_BUCKET_NAME: this.s3Bucket.bucketName,
        INFERENCE_ECR_IMAGE_URL: this.inferenceECRUrl,
        SNS_INFERENCE_SUCCESS: this.inferenceResultTopic.topicArn,
        SNS_INFERENCE_ERROR: this.inferenceResultErrorTopic.topicArn,
        // The Lambda's own role is also handed to the handler as the execution role.
        EXECUTION_ROLE_ARN: role.roleArn,
      },
      layers: [this.layer],
    });

    // JSON schema of the request body. `endpoint_name` is optional; all other
    // listed properties are required.
    const model = new Model(this.scope, `${this.baseId}-model`, {
      restApi: this.router.api,
      contentType: 'application/json',
      modelName: this.baseId,
      description: `${this.baseId} Request Model`,
      schema: {
        schema: JsonSchemaVersion.DRAFT4,
        title: this.baseId,
        type: JsonSchemaType.OBJECT,
        properties: {
          endpoint_name: {
            type: JsonSchemaType.STRING,
            minLength: 0,
            maxLength: 20,
          },
          instance_type: {
            type: JsonSchemaType.STRING,
          },
          initial_instance_count: {
            type: JsonSchemaType.STRING,
          },
          autoscaling_enabled: {
            type: JsonSchemaType.BOOLEAN,
          },
          assign_to_roles: {
            type: JsonSchemaType.ARRAY,
            items: {
              type: JsonSchemaType.STRING,
            },
            minItems: 1,
            maxItems: 10,
          },
          creator: {
            type: JsonSchemaType.STRING,
          },
        },
        required: [
          'instance_type',
          'initial_instance_count',
          'autoscaling_enabled',
          'assign_to_roles',
          'creator',
        ],
      },
    });

    const integration = new LambdaIntegration(
      lambdaFunction,
      {
        proxy: true,
      },
    );

    // Only the body is validated (against `model`); request parameters are not.
    const requestValidator = new RequestValidator(this.scope, `${this.baseId}-validator`, {
      restApi: this.router.api,
      requestValidatorName: this.baseId,
      validateRequestBody: true,
      validateRequestParameters: false,
    });

    this.router.addMethod(this.httpMethod, integration, <MethodOptions>{
      apiKeyRequired: true,
      authorizer: this.authorizer,
      requestValidator,
      requestModels: {
        'application/json': model,
      },
    });
  }
}