import datetime
import json
import logging
import os
import uuid

import boto3
from aws_lambda_powertools import Tracer

from common import const
from common.ddb_service.client import DynamoDbUtilsService
from common.response import ok, not_found
from common.util import publish_msg, generate_presigned_url_for_key, record_seconds_metrics
from inferences.inference_libs import update_table_by_pk
from libs.data_types import TrainJob, TrainJobStatus, CheckPoint, CheckPointStatus
from libs.enums import ServiceType

tracer = Tracer()
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)

train_table = os.environ.get('TRAINING_JOB_TABLE')
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')
user_topic_arn = os.environ.get('USER_EMAIL_TOPIC_ARN')
bucket_name = os.environ.get("S3_BUCKET_NAME")

sagemaker = boto3.client('sagemaker')
s3 = boto3.client('s3')

ddb_service = DynamoDbUtilsService(logger=logger)


@tracer.capture_lambda_handler
def handler(event, ctx):
    """Entry point for SageMaker training-job state-change events.

    Finds the internal training-job record whose ``sagemaker_train_name``
    matches the event's TrainingJobName and reconciles its status.

    Returns a not-found response when no matching record exists, an ok
    response otherwise.
    """
    logger.info(json.dumps(event))
    train_job_name = event['detail']['TrainingJobName']

    rows = ddb_service.scan(train_table, filters={
        'sagemaker_train_name': train_job_name,
    })
    logger.info(rows)
    if not rows:
        return not_found(message=f'training job {train_job_name} is not found')

    training_job = TrainJob(**(ddb_service.deserialize(rows[0])))
    logger.info(training_job)

    check_status(training_job)

    return ok()


# sfn
def check_status(training_job: TrainJob):
    """Sync a training-job DDB record with SageMaker's current status.

    On ``Completed``: records a latency metric, notifies the user
    (best effort), registers any produced checkpoints found under the
    job's S3 prefix, and stores the training-log listing. On any other
    status: marks the job failed and stores the failure details. The raw
    SageMaker response is always persisted into the job's params.
    """
    resp = sagemaker.describe_training_job(
        TrainingJobName=training_job.sagemaker_train_name
    )
    logger.info(resp)

    training_job_status = resp['TrainingJobStatus']
    secondary_status = resp['SecondaryStatus']

    # Mirror SageMaker's fine-grained secondary status into the job record.
    update_table_by_pk(table_name=train_table, pk='id', id=training_job.id,
                       key='job_status', value=secondary_status)

    if training_job_status == 'Completed':
        # Emit end-to-end training latency measured from the job's own timestamp.
        float_timestamp = float(training_job.timestamp)
        timestamp = datetime.datetime.fromtimestamp(float_timestamp).isoformat()
        record_seconds_metrics(start_time=timestamp,
                               metric_name='TrainingLatency',
                               service=ServiceType.SD.value)

        # Notification is best-effort: a publish failure must not abort
        # the checkpoint/log bookkeeping below.
        try:
            notify_user(training_job)
        except Exception as e:
            logger.error(e, exc_info=True)

        # Register every checkpoint object the job produced under its prefix.
        prefix = f"Stable-diffusion/checkpoint/custom/{training_job.id}"
        s3_resp = s3.list_objects(
            Bucket=bucket_name,
            Prefix=prefix,
        )
        if 'Contents' in s3_resp and len(s3_resp['Contents']) > 0:
            for obj in s3_resp['Contents']:
                checkpoint_name = obj['Key'].replace(f'{prefix}/', "")
                logger.info(f'checkpoint_name: {checkpoint_name}')
                insert_ckpt(checkpoint_name, training_job)

        logs = get_logs(training_job.id)
        update_table_by_pk(table_name=train_table, pk='id', id=training_job.id,
                           key='logs', value=logs)
    else:
        # Any non-completed status reported here is treated as a failure.
        update_table_by_pk(table_name=train_table, pk='id', id=training_job.id,
                           key='job_status', value=TrainJobStatus.Fail)
        # Build the failure payload once; previously the FailureReason
        # enrichment was overwritten by an unconditional re-assignment,
        # losing the error message.
        failure_payload = {
            'raw_resp': resp
        }
        if training_job_status in ('Failed', 'Stopped') and 'FailureReason' in resp:
            failure_payload['status'] = 'Failed'
            failure_payload['error_msg'] = resp['FailureReason']
        training_job.params['resp'] = failure_payload

    ddb_service.update_item(
        table=train_table,
        key={'id': training_job.id},
        field_name='params',
        value=training_job.params
    )
    return


def get_logs(job_id: str):
    """Return the training-log object names (relative to the log prefix) for a job."""
    prefix = f"kohya/train/{job_id}/logs/"
    s3_resp = s3.list_objects(
        Bucket=bucket_name,
        Prefix=prefix,
    )
    if 'Contents' not in s3_resp:
        return []
    return [obj['Key'].replace(prefix, '') for obj in s3_resp['Contents']]


def get_logs_presign(job_id, logs):
    """Attach a presigned download URL to each log filename.

    Returns a list of ``{'filename', 'url'}`` dicts; empty input yields [].
    """
    return [
        {
            'filename': filename,
            # BUG FIX: the key must reference the actual filename; it
            # previously interpolated a literal placeholder, so every URL
            # pointed at a nonexistent object.
            'url': generate_presigned_url_for_key(f"kohya/train/{job_id}/logs/{filename}")
        }
        for filename in logs
    ]


def insert_ckpt(output_name, job: TrainJob):
    """Register a produced checkpoint unless one with the same name already exists.

    NOTE(review): this scans the whole checkpoint table per call — acceptable
    for small tables, worth a query/index if the table grows.
    """
    for raw in ddb_service.scan(checkpoint_table):
        ckpt = CheckPoint(**(ddb_service.deserialize(raw)))
        if output_name in ckpt.checkpoint_names:
            return

    checkpoint = CheckPoint(
        id=str(uuid.uuid4()),
        checkpoint_type=const.CheckPointType.LORA,
        checkpoint_names=[output_name],
        s3_location=f"s3://{bucket_name}/Stable-diffusion/checkpoint/custom/{job.id}",
        checkpoint_status=CheckPointStatus.Active,
        timestamp=datetime.datetime.now().timestamp(),
        allowed_roles_or_users=job.allowed_roles_or_users,
    )
    ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)


def notify_user(train_job: TrainJob):
    """Publish an SNS message to the user topic about the job's status."""
    publish_msg(
        topic_arn=user_topic_arn,
        subject=f'Create Model Job {train_job.sagemaker_train_name} {train_job.job_status}',
        msg=f'to be done with resp: \n {train_job.job_status}'
    )  # todo: find out msg

    return 'job completed'