import datetime import json import logging import os import base64 import time from dataclasses import dataclass import boto3 import tarfile from typing import Any, List, Optional import sagemaker from common.ddb_service.client import DynamoDbUtilsService from common.stepfunction_service.client import StepFunctionUtilsService from common.util import load_json_from_s3, publish_msg, save_json_to_file from common_tools import split_s3_path, DecimalEncoder from common.util import get_s3_presign_urls from _types import TrainJob, TrainJobStatus, Model, CreateModelStatus, CheckPoint, CheckPointStatus bucket_name = os.environ.get('S3_BUCKET') train_table = os.environ.get('TRAIN_TABLE') model_table = os.environ.get('MODEL_TABLE') checkpoint_table = os.environ.get('CHECKPOINT_TABLE') instance_type = os.environ.get('INSTANCE_TYPE') sagemaker_role_arn = os.environ.get('TRAIN_JOB_ROLE') image_uri = os.environ.get('TRAIN_ECR_URL') # e.g. "648149843064.dkr.ecr.us-east-1.amazonaws.com/dreambooth-training-repo" training_stepfunction_arn = os.environ.get('TRAINING_SAGEMAKER_ARN') user_topic_arn = os.environ.get('USER_EMAIL_TOPIC_ARN') region_name = os.environ['AWS_REGION'] logger = logging.getLogger('boto3') ddb_service = DynamoDbUtilsService(logger=logger) s3 = boto3.client('s3', region_name=region_name) @dataclass class Event: train_type: str model_id: str params: dict[str, Any] filenames: Optional[List[str]] = None # POST /train def create_train_job_api(raw_event, context): request_id = context.aws_request_id event = Event(**raw_event) _type = event.train_type try: model_raw = ddb_service.get_item(table=model_table, key_values={ 'id': event.model_id }) if model_raw is None: return { 'statusCode': 500, 'error': f'model with id {event.model_id} is not found' } model = Model(**model_raw) if model.job_status != CreateModelStatus.Complete: return { 'statusCode': 500, 'error': f'model {model.id} is in {model.job_status.value} state, not valid to be used for train' } base_key = f'{_type}/train/{model.name}/{request_id}' input_location = f'{base_key}/input' presign_url_map = None if event.filenames is None: # Invoked from api, no config file is defined in the parameters json_file_name = 'db_config_cloud.json' tar_file_name = 'db_config.tar' tar_file_content = f'/tmp/models/sagemaker_dreambooth/{model.name}' tar_file_path = f'/tmp/{tar_file_name}' db_config_json = load_json_from_s3(bucket_name, 'template/' + json_file_name) # Merge user parameter, if no config_params is defined, use the default value in S3 bucket if "config_params" in event.params: db_config_json.update(event.params["config_params"]) # Add model parameters into train params event.params["training_params"]["model_name"] = model.name event.params["training_params"]["model_type"] = model.model_type event.params["training_params"]["s3_model_path"] = model.output_s3_location # Upload the merged JSON string to the S3 bucket as a tar file try: if not os.path.exists(tar_file_content): os.makedirs(tar_file_content) saved_path = save_json_to_file(db_config_json, tar_file_content, json_file_name) print(f'file saved to {saved_path}') with tarfile.open('/tmp/' + tar_file_name, 'w') as tar: # Add the contents of 'models' directory to the tar file without including the /tmp itself tar.add(tar_file_content, arcname=f'models/sagemaker_dreambooth/{model.name}') s3.upload_file(tar_file_path, bucket_name, os.path.join(input_location, tar_file_name)) logger.info(f"Tar file '{tar_file_name}' uploaded to '{bucket_name}' successfully.") except Exception as e: raise RuntimeError(f"Error uploading JSON file to S3: {e}") else: presign_url_map = get_s3_presign_urls(bucket_name=bucket_name, base_key=input_location, filenames=event.filenames) checkpoint = CheckPoint( id=request_id, checkpoint_type=event.train_type, s3_location=f's3://{bucket_name}/{base_key}/output', checkpoint_status=CheckPointStatus.Initial, timestamp=datetime.datetime.now().timestamp() ) ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__) train_input_s3_location = f's3://{bucket_name}/{input_location}' train_job = TrainJob( id=request_id, model_id=event.model_id, job_status=TrainJobStatus.Initial, params=event.params, train_type=event.train_type, input_s3_location=train_input_s3_location, checkpoint_id=checkpoint.id, timestamp=datetime.datetime.now().timestamp() ) ddb_service.put_items(table=train_table, entries=train_job.__dict__) return { 'statusCode': 200, 'job': { 'id': train_job.id, 'status': train_job.job_status.value, 'trainType': train_job.train_type, 'params': train_job.params, 'input_location': train_input_s3_location, }, 's3PresignUrl': presign_url_map } except Exception as e: logger.error(e) return { 'statusCode': 200, 'error': str(e) } # GET /trains def list_all_train_jobs_api(event, context): _filter = {} if 'queryStringParameters' not in event: return { 'statusCode': '500', 'error': 'query parameter status and types are needed' } parameters = event['queryStringParameters'] if 'types' in parameters and len(parameters['types']) > 0: _filter['train_type'] = parameters['types'] if 'status' in parameters and len(parameters['status']) > 0: _filter['job_status'] = parameters['status'] resp = ddb_service.scan(table=train_table, filters=_filter) if resp is None or len(resp) == 0: return { 'statusCode': 200, 'trainJobs': [] } train_jobs = [] for tr in resp: train_job = TrainJob(**(ddb_service.deserialize(tr))) model_name = 'not_applied' if 'training_params' in train_job.params and 'model_name' in train_job.params['training_params']: model_name = train_job.params['training_params']['model_name'] train_jobs.append({ 'id': train_job.id, 'modelName': model_name, 'status': train_job.job_status.value, 'trainType': train_job.train_type, 'created': train_job.timestamp, 'sagemakerTrainName': train_job.sagemaker_train_name, }) return { 'statusCode': 200, 'trainJobs': train_jobs } # PUT /train used to kickoff a train job step function def update_train_job_api(event, context): if 'status' in event and 'train_job_id' in event and event['status'] == TrainJobStatus.Training.value: return _start_train_job(event['train_job_id']) return { 'statusCode': 200, 'msg': f'not implemented for train job status {event["status"]}' } def _start_train_job(train_job_id: str): raw_train_job = ddb_service.get_item(table=train_table, key_values={ 'id': train_job_id }) if raw_train_job is None or len(raw_train_job) == 0: return { 'statusCode': 500, 'error': f'no such train job with id({train_job_id})' } train_job = TrainJob(**raw_train_job) model_raw = ddb_service.get_item(table=model_table, key_values={ 'id': train_job.model_id }) if model_raw is None: return { 'statusCode': 500, 'error': f'model with id {train_job.model_id} is not found' } model = Model(**model_raw) raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={ 'id': train_job.checkpoint_id }) if raw_checkpoint is None: return { 'statusCode': 500, 'error': f'checkpoint with id {train_job.checkpoint_id} is not found' } checkpoint = CheckPoint(**raw_checkpoint) try: # JSON encode hyperparameters def json_encode_hyperparameters(hyperparameters): new_params = {} for k, v in hyperparameters.items(): json_v = json.dumps(v, cls=DecimalEncoder) v_bytes = json_v.encode('ascii') base64_bytes = base64.b64encode(v_bytes) base64_v = base64_bytes.decode('ascii') new_params[k] = base64_v return new_params hyperparameters = json_encode_hyperparameters({ "sagemaker_program": "extensions/sd-webui-sagemaker/sagemaker_entrypoint_json.py", "params": train_job.params, "s3-input-path": train_job.input_s3_location, "s3-output-path": checkpoint.s3_location, }) final_instance_type = instance_type if 'training_params' in train_job.params \ and 'training_instance_type' in train_job.params['training_params'] and \ train_job.params['training_params']['training_instance_type']: final_instance_type = train_job.params['training_params']['training_instance_type'] est = sagemaker.estimator.Estimator( image_uri, sagemaker_role_arn, instance_count=1, instance_type=final_instance_type, volume_size=125, base_job_name=f'{model.name}', hyperparameters=hyperparameters, job_id=train_job.id, ) est.fit(wait=False) while not est._current_job_name: time.sleep(1) train_job.sagemaker_train_name = est._current_job_name # trigger stepfunction stepfunctions_client = StepFunctionUtilsService(logger=logger) sfn_input = { 'train_job_id': train_job.id, 'train_job_name': train_job.sagemaker_train_name } sfn_arn = stepfunctions_client.invoke_step_function(training_stepfunction_arn, sfn_input) # todo: use batch update, this is ugly!!! search_key = {'id': train_job.id} ddb_service.update_item( table=train_table, key=search_key, field_name='sagemaker_train_name', value=est._current_job_name ) train_job.job_status = TrainJobStatus.Training ddb_service.update_item( table=train_table, key=search_key, field_name='job_status', value=TrainJobStatus.Training.value ) train_job.sagemaker_sfn_arn = sfn_arn ddb_service.update_item( table=train_table, key=search_key, field_name='sagemaker_sfn_arn', value=sfn_arn ) return { 'statusCode': 200, 'job': { 'id': train_job.id, 'status': train_job.job_status.value, 'created': train_job.timestamp, 'trainType': train_job.train_type, 'params': train_job.params, 'input_location': train_job.input_s3_location }, } except Exception as e: print(e) return { 'statusCode': 500, 'error': str(e) } # sfn def check_train_job_status(event, context): import boto3 boto3_sagemaker = boto3.client('sagemaker') train_job_name = event['train_job_name'] train_job_id = event['train_job_id'] resp = boto3_sagemaker.describe_training_job( TrainingJobName=train_job_name ) training_job_status = resp['TrainingJobStatus'] event['status'] = training_job_status raw_train_job = ddb_service.get_item(table=train_table, key_values={ 'id': train_job_id, }) if raw_train_job is None or len(raw_train_job) == 0: event['status'] = 'Failed' return { 'statusCode': 500, 'msg': f'no such training job find in ddb id[{train_job_id}]' } training_job = TrainJob(**raw_train_job) if training_job_status == 'InProgress' or training_job_status == 'Stopping': return event if training_job_status == 'Failed' or training_job_status == 'Stopped': training_job.job_status = TrainJobStatus.Fail if 'FailureReason' in resp: err_msg = resp['FailureReason'] training_job.params['resp'] = { 'status': 'Failed', 'error_msg': err_msg, 'raw_resp': resp } if training_job_status == 'Completed': training_job.job_status = TrainJobStatus.Complete # todo: update checkpoints raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={ 'id': training_job.checkpoint_id }) if raw_checkpoint is None or len(raw_checkpoint) == 0: # todo: or create new one return 'failed because no checkpoint, not normal' checkpoint = CheckPoint(**raw_checkpoint) checkpoint.checkpoint_status = CheckPointStatus.Active s3 = boto3.client('s3') bucket, key = split_s3_path(checkpoint.s3_location) s3_resp = s3.list_objects( Bucket=bucket, Prefix=key, ) checkpoint.checkpoint_names = [] if 'Contents' in s3_resp and len(s3_resp['Contents']) > 0: for obj in s3_resp['Contents']: checkpoint_name = obj['Key'].replace(f'{key}/', "") checkpoint.checkpoint_names.append(checkpoint_name) else: training_job.job_status = TrainJobStatus.Fail checkpoint.checkpoint_status = CheckPointStatus.Initial ddb_service.update_item( table=checkpoint_table, key={ 'id': checkpoint.id }, field_name='checkpoint_status', value=checkpoint.checkpoint_status.value ) ddb_service.update_item( table=checkpoint_table, key={ 'id': checkpoint.id }, field_name='checkpoint_names', value=checkpoint.checkpoint_names ) training_job.params['resp'] = { 'raw_resp': resp } # fixme: this is ugly ddb_service.update_item( table=train_table, key={'id': training_job.id}, field_name='job_status', value=training_job.job_status.value ) ddb_service.update_item( table=train_table, key={'id': training_job.id}, field_name='params', value=training_job.params ) return event # sfn def process_train_job_result(event, context): train_job_id = event['train_job_id'] raw_train_job = ddb_service.get_item(table=train_table, key_values={ 'id': train_job_id, }) if raw_train_job is None or len(raw_train_job) == 0: return { 'statusCode': 500, 'msg': f'no such training job find in ddb id[{train_job_id}]' } train_job = TrainJob(**raw_train_job) publish_msg( topic_arn=user_topic_arn, subject=f'Create Model Job {train_job.sagemaker_train_name} {train_job.job_status}', msg=f'to be done with resp: \n {train_job.job_status}' ) # todo: find out msg return 'job completed'