stable-diffusion-aws-extension/middleware_api/lambda/trainings/training_event.py

146 lines
4.2 KiB
Python

import json
import logging
import os
import boto3
from common.ddb_service.client import DynamoDbUtilsService
from common.response import ok, not_found
from common.util import publish_msg
from libs.common_tools import split_s3_path
from libs.data_types import TrainJob, TrainJobStatus, CheckPoint, CheckPointStatus
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
train_table = os.environ.get('TRAINING_JOB_TABLE')
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')
user_topic_arn = os.environ.get('USER_EMAIL_TOPIC_ARN')
sagemaker = boto3.client('sagemaker')
s3 = boto3.client('s3')
ddb_service = DynamoDbUtilsService(logger=logger)
def handler(event, ctx):
logger.info(json.dumps(event))
train_job_name = event['detail']['TrainingJobName']
rows = ddb_service.scan(train_table, filters={
'sagemaker_train_name': train_job_name,
})
logger.info(rows)
if not rows or len(rows) == 0:
return not_found(message=f'training job {train_job_name} is not found')
training_job = TrainJob(**(ddb_service.deserialize(rows[0])))
logger.info(training_job)
check_status(training_job)
return ok()
# sfn
def check_status(training_job: TrainJob):
resp = sagemaker.describe_training_job(
TrainingJobName=training_job.sagemaker_train_name
)
logger.info(resp)
training_job_status = resp['TrainingJobStatus']
secondary_status = resp['SecondaryStatus']
ddb_service.update_item(
table=train_table,
key={'id': training_job.id},
field_name='job_status',
value=secondary_status
)
if training_job_status == 'Failed' or training_job_status == 'Stopped':
if 'FailureReason' in resp:
err_msg = resp['FailureReason']
training_job.params['resp'] = {
'status': 'Failed',
'error_msg': err_msg,
'raw_resp': resp
}
if training_job_status == 'Completed':
notify_user(training_job)
# todo: update checkpoints
raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={
'id': training_job.checkpoint_id
})
if raw_checkpoint is None or len(raw_checkpoint) == 0:
# todo: or create new one
return 'failed because no checkpoint, not normal'
checkpoint = CheckPoint(**raw_checkpoint)
checkpoint.checkpoint_status = CheckPointStatus.Active
bucket, key = split_s3_path(checkpoint.s3_location)
s3_resp = s3.list_objects(
Bucket=bucket,
Prefix=key,
)
checkpoint.checkpoint_names = []
if 'Contents' in s3_resp and len(s3_resp['Contents']) > 0:
for obj in s3_resp['Contents']:
checkpoint_name = obj['Key'].replace(f'{key}/', "")
checkpoint.checkpoint_names.append(checkpoint_name)
else:
checkpoint.checkpoint_status = CheckPointStatus.Initial
ddb_service.update_item(
table=train_table,
key={'id': training_job.id},
field_name='job_status',
value=TrainJobStatus.Fail
)
ddb_service.update_item(
table=checkpoint_table,
key={
'id': checkpoint.id
},
field_name='checkpoint_status',
value=checkpoint.checkpoint_status.value
)
ddb_service.update_item(
table=checkpoint_table,
key={
'id': checkpoint.id
},
field_name='checkpoint_names',
value=checkpoint.checkpoint_names
)
training_job.params['resp'] = {
'raw_resp': resp
}
ddb_service.update_item(
table=train_table,
key={'id': training_job.id},
field_name='params',
value=training_job.params
)
return
def notify_user(train_job: TrainJob):
publish_msg(
topic_arn=user_topic_arn,
subject=f'Create Model Job {train_job.sagemaker_train_name} {train_job.job_status}',
msg=f'to be done with resp: \n {train_job.job_status}'
) # todo: find out msg
return 'job completed'