diff --git a/infrastructure/src/events/trainings-event.ts b/infrastructure/src/events/trainings-event.ts index 007c5311..df4f3569 100644 --- a/infrastructure/src/events/trainings-event.ts +++ b/infrastructure/src/events/trainings-event.ts @@ -99,14 +99,6 @@ export class SagemakerTrainingEvents { resources: [`arn:${Aws.PARTITION}:sagemaker:${Aws.REGION}:${Aws.ACCOUNT_ID}:training-job/*`], })); - newRole.addToPolicy(new aws_iam.PolicyStatement({ - effect: Effect.ALLOW, - actions: [ - 'kms:*', - ], - resources: ['*'], - })); - newRole.addToPolicy(new aws_iam.PolicyStatement({ effect: Effect.ALLOW, actions: [ @@ -157,6 +149,11 @@ export class SagemakerTrainingEvents { layers: [this.layer], }); + lambdaFunction.addToRolePolicy(new PolicyStatement({ + actions: ['sns:Publish'], + resources: [this.userSnsTopic.topicArn], + })); + const rule = new Rule(this.scope, `${this.baseId}-rule`, { eventPattern: { source: ['aws.sagemaker'], diff --git a/infrastructure/src/main.ts b/infrastructure/src/main.ts index ace178e8..41c531e4 100644 --- a/infrastructure/src/main.ts +++ b/infrastructure/src/main.ts @@ -166,8 +166,6 @@ export class Middleware extends Stack { routers: restApi.routers, s3Bucket: s3Bucket, snsTopic: snsTopics.snsTopic, - createModelFailureTopic: snsTopics.createModelFailureTopic, - createModelSuccessTopic: snsTopics.createModelSuccessTopic, logLevel, resourceProvider, accountId, diff --git a/infrastructure/src/shared/resource-provider-on-event.ts b/infrastructure/src/shared/resource-provider-on-event.ts index d7374ec1..43880219 100644 --- a/infrastructure/src/shared/resource-provider-on-event.ts +++ b/infrastructure/src/shared/resource-provider-on-event.ts @@ -434,8 +434,6 @@ async function createTopics() { const list = [ 'ReceiveSageMakerInferenceSuccess', 'ReceiveSageMakerInferenceError', - 'successCreateModel', - 'failureCreateModel', 'StableDiffusionSnsUserTopic', ]; diff --git a/infrastructure/src/shared/sns-topics.ts b/infrastructure/src/shared/sns-topics.ts index 879ada2a..5357ed4e 100644 --- a/infrastructure/src/shared/sns-topics.ts +++ b/infrastructure/src/shared/sns-topics.ts @@ -7,8 +7,6 @@ import { Construct } from 'constructs'; export class SnsTopics { public readonly snsTopic: Topic; - public readonly createModelSuccessTopic: Topic; - public readonly createModelFailureTopic: Topic; public readonly inferenceResultTopic: Topic; public readonly inferenceResultErrorTopic: Topic; private readonly scope: Construct; @@ -25,8 +23,6 @@ export class SnsTopics { this.inferenceResultTopic = this.createOrImportTopic('ReceiveSageMakerInferenceSuccess'); this.inferenceResultErrorTopic = this.createOrImportTopic('ReceiveSageMakerInferenceError'); - this.createModelSuccessTopic = this.createOrImportTopic('successCreateModel'); - this.createModelFailureTopic = this.createOrImportTopic('failureCreateModel'); } private createOrImportTopic(topicName: string): Topic { diff --git a/infrastructure/src/shared/train-deploy.ts b/infrastructure/src/shared/train-deploy.ts index ecee29a7..c67a3a07 100644 --- a/infrastructure/src/shared/train-deploy.ts +++ b/infrastructure/src/shared/train-deploy.ts @@ -21,10 +21,7 @@ import { GetTrainingJobApi } from '../api/trainings/get-training-job'; import { ListTrainingJobsApi } from '../api/trainings/list-training-jobs'; import { SagemakerTrainingEvents } from '../events/trainings-event'; -// ckpt -> create_model -> model -> training -> ckpt -> inference export interface TrainDeployProps extends StackProps { - createModelSuccessTopic: aws_sns.Topic; - createModelFailureTopic: aws_sns.Topic; database: Database; routers: { [key: string]: Resource }; s3Bucket: aws_s3.Bucket; diff --git a/middleware_api/lambda/common/util.py b/middleware_api/lambda/common/util.py index 10deeea2..592d2ec1 100644 --- a/middleware_api/lambda/common/util.py +++ b/middleware_api/lambda/common/util.py @@ -9,6 +9,7 @@ s3 = boto3.client('s3') logger = logging.getLogger(__name__) logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR) +sns_client = boto3.client('sns') def get_multi_query_params(event, param_name: str, default=None): @@ -42,8 +43,7 @@ def query_data(data, paths): def publish_msg(topic_arn, msg, subject): - client = boto3.client('sns') - client.publish( + sns_client.publish( TopicArn=topic_arn, Message=str(msg), Subject=subject diff --git a/update_scripts/retained_sns b/update_scripts/retained_sns index c774b212..53d507fb 100644 --- a/update_scripts/retained_sns +++ b/update_scripts/retained_sns @@ -1,5 +1,3 @@ ReceiveSageMakerInferenceError ReceiveSageMakerInferenceSuccess StableDiffusionSnsUserTopic -failureCreateModel -successCreateModel \ No newline at end of file