177 lines
5.3 KiB
Python
177 lines
5.3 KiB
Python
import datetime
|
|
import json
|
|
import logging
|
|
import os
|
|
import uuid
|
|
|
|
import boto3
|
|
from aws_lambda_powertools import Tracer
|
|
|
|
from common import const
|
|
from common.ddb_service.client import DynamoDbUtilsService
|
|
from common.response import ok, not_found
|
|
from common.util import publish_msg, generate_presigned_url_for_key, record_seconds_metrics
|
|
from inferences.inference_libs import update_table_by_pk
|
|
from libs.data_types import TrainJob, TrainJobStatus, CheckPoint, CheckPointStatus
|
|
from libs.enums import ServiceType
|
|
|
|
# Lambda Powertools tracer for X-Ray instrumentation of the handler.
tracer = Tracer()

logger = logging.getLogger(__name__)
# Defaults to ERROR when LOG_LEVEL is unset or empty.
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)

# Resource names injected by the deployment stack via environment variables.
train_table = os.environ.get('TRAINING_JOB_TABLE')      # DynamoDB table of training jobs
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')   # DynamoDB table of model checkpoints
user_topic_arn = os.environ.get('USER_EMAIL_TOPIC_ARN')  # SNS topic for user notifications
bucket_name = os.environ.get("S3_BUCKET_NAME")          # bucket holding checkpoints and logs

# AWS clients created once at module load so they are reused across invocations.
sagemaker = boto3.client('sagemaker')
s3 = boto3.client('s3')

ddb_service = DynamoDbUtilsService(logger=logger)
|
|
|
|
|
|
@tracer.capture_lambda_handler
def handler(event, ctx):
    """Handle a SageMaker training-job state-change event.

    Resolves the owning training-job record via its SageMaker job name,
    then refreshes that record's status from the live SageMaker state.
    Returns 404 when no matching job record exists.
    """
    logger.info(json.dumps(event))

    sagemaker_job_name = event['detail']['TrainingJobName']

    # The table is keyed by internal id, so locate the row by a filtered scan
    # on the SageMaker-side job name.
    matches = ddb_service.scan(train_table, filters={
        'sagemaker_train_name': sagemaker_job_name,
    })

    logger.info(matches)

    if not matches:
        return not_found(message=f'training job {sagemaker_job_name} is not found')

    job = TrainJob(**(ddb_service.deserialize(matches[0])))
    logger.info(job)

    check_status(job)

    return ok()
|
|
|
|
|
|
# sfn  -- NOTE(review): presumably invoked as part of the Step Functions /
# event-driven training flow; confirm against the state machine definition.
def check_status(training_job: TrainJob):
    """Synchronize the DynamoDB job record with SageMaker's live job state.

    Always mirrors SageMaker's secondary status onto the record. On failure
    or stop, captures the failure reason into the job params. On completion,
    emits a latency metric, notifies the user (best effort), registers any
    produced checkpoints, stores the log file list, and persists the raw
    SageMaker response.
    """
    resp = sagemaker.describe_training_job(
        TrainingJobName=training_job.sagemaker_train_name
    )
    logger.info(resp)

    job_status = resp['TrainingJobStatus']
    secondary = resp['SecondaryStatus']

    # The secondary status is finer-grained, so it is what the record tracks.
    update_table_by_pk(table_name=train_table, pk='id', id=training_job.id,
                       key='job_status', value=secondary)

    if job_status in ('Failed', 'Stopped'):
        if 'FailureReason' in resp:
            training_job.params['resp'] = {
                'status': 'Failed',
                'error_msg': resp['FailureReason'],
                'raw_resp': resp
            }

    if job_status == 'Completed':
        # Record end-to-end training latency from the job's creation timestamp.
        started_at = datetime.datetime.fromtimestamp(float(training_job.timestamp)).isoformat()
        record_seconds_metrics(start_time=started_at, metric_name='TrainingLatency',
                               service=ServiceType.SD.value)

        # Notification is best effort; never fail the status sync over it.
        try:
            notify_user(training_job)
        except Exception as e:
            logger.error(e, exc_info=True)

        prefix = f"Stable-diffusion/checkpoint/custom/{training_job.id}"
        s3_resp = s3.list_objects(
            Bucket=bucket_name,
            Prefix=prefix,
        )

        if 'Contents' in s3_resp and len(s3_resp['Contents']) > 0:
            for obj in s3_resp['Contents']:
                checkpoint_name = obj['Key'].replace(f'{prefix}/', "")
                logger.info(f'checkpoint_name: {checkpoint_name}')
                insert_ckpt(checkpoint_name, training_job)

            update_table_by_pk(table_name=train_table, pk='id', id=training_job.id,
                               key='logs', value=get_logs(training_job.id))
        else:
            # Completed per SageMaker but no artifacts were produced: mark failed.
            update_table_by_pk(table_name=train_table, pk='id', id=training_job.id,
                               key='job_status', value=TrainJobStatus.Fail)

        training_job.params['resp'] = {
            'raw_resp': resp
        }
        ddb_service.update_item(
            table=train_table,
            key={'id': training_job.id},
            field_name='params',
            value=training_job.params
        )

    return
|
|
|
|
|
|
def get_logs(job_id: str):
    """List log file names under the job's kohya log prefix in S3.

    Returns the object keys relative to the prefix, or an empty list when
    no log objects exist.
    """
    prefix = f"kohya/train/{job_id}/logs/"
    listing = s3.list_objects(
        Bucket=bucket_name,
        Prefix=prefix,
    )

    if 'Contents' not in listing or len(listing['Contents']) == 0:
        return []

    return [entry['Key'].replace(prefix, '') for entry in listing['Contents']]
|
|
|
|
|
|
def get_logs_presign(job_id, logs):
    """Build presigned download URLs for a training job's log files.

    Args:
        job_id: Training job id; its log prefix forms the S3 key.
        logs: Log file names relative to the job's log prefix, as returned
            by ``get_logs``.

    Returns:
        A list of ``{'filename': ..., 'url': ...}`` dicts, empty when
        *logs* is empty.
    """
    if len(logs) == 0:
        return []

    presign_logs = []
    for filename in logs:
        presign_logs.append({
            'filename': filename,
            # Bug fix: the key must include the actual file name — previously a
            # garbled placeholder was embedded, so every URL pointed at a
            # non-existent object.
            'url': generate_presigned_url_for_key(f"kohya/train/{job_id}/logs/{filename}")
        })

    return presign_logs
|
|
|
|
|
|
def insert_ckpt(output_name, job: TrainJob):
    """Register a checkpoint produced by *job*, skipping already-known names.

    Scans the checkpoint table; if any existing checkpoint already lists
    ``output_name`` the call is a no-op. Otherwise a new Active LoRA
    checkpoint record pointing at the job's custom checkpoint prefix is
    written.
    """
    existing = ddb_service.scan(checkpoint_table)
    already_registered = any(
        output_name in CheckPoint(**(ddb_service.deserialize(item))).checkpoint_names
        for item in existing
    )
    if already_registered:
        return

    record = CheckPoint(
        id=str(uuid.uuid4()),
        checkpoint_type=const.CheckPointType.LORA,
        checkpoint_names=[output_name],
        s3_location=f"s3://{bucket_name}/Stable-diffusion/checkpoint/custom/{job.id}",
        checkpoint_status=CheckPointStatus.Active,
        timestamp=datetime.datetime.now().timestamp(),
        allowed_roles_or_users=job.allowed_roles_or_users,
    )

    ddb_service.put_items(table=checkpoint_table, entries=record.__dict__)
|
|
|
|
|
|
def notify_user(train_job: TrainJob):
    """Publish an SNS notification about the training job's outcome."""
    subject = f'Create Model Job {train_job.sagemaker_train_name} {train_job.job_status}'
    body = f'to be done with resp: \n {train_job.job_status}'

    # todo: find out msg
    publish_msg(
        topic_arn=user_topic_arn,
        subject=subject,
        msg=body
    )

    return 'job completed'
|