stable-diffusion-aws-extension/middleware_api/lambda/model_and_train/train_api.py

418 lines
14 KiB
Python

import datetime
import json
import logging
import os
import base64
import time
from dataclasses import dataclass
from typing import Any
import sagemaker
from common.ddb_service.client import DynamoDbUtilsService
from common.stepfunction_service.client import StepFunctionUtilsService
from common.util import publish_msg
from common_tools import split_s3_path, DecimalEncoder
from common.util import get_s3_presign_urls
from _types import TrainJob, TrainJobStatus, Model, CreateModelStatus, CheckPoint, CheckPointStatus
bucket_name = os.environ.get('S3_BUCKET')
train_table = os.environ.get('TRAIN_TABLE')
model_table = os.environ.get('MODEL_TABLE')
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')
instance_type = os.environ.get('INSTANCE_TYPE')
sagemaker_role_arn = os.environ.get('TRAIN_JOB_ROLE')
image_uri = os.environ.get('TRAIN_ECR_URL') # e.g. "648149843064.dkr.ecr.us-east-1.amazonaws.com/dreambooth-training-repo"
training_stepfunction_arn = os.environ.get('TRAINING_SAGEMAKER_ARN')
user_topic_arn = os.environ.get('USER_EMAIL_TOPIC_ARN')
logger = logging.getLogger('boto3')
ddb_service = DynamoDbUtilsService(logger=logger)
@dataclass
class Event:
train_type: str
model_id: str
params: dict[str, Any]
filenames: [str]
# POST /train
def create_train_job_api(raw_event, context):
request_id = context.aws_request_id
event = Event(**raw_event)
_type = event.train_type
try:
model_raw = ddb_service.get_item(table=model_table, key_values={
'id': event.model_id
})
if model_raw is None:
return {
'statusCode': 500,
'error': f'model with id {event.model_id} is not found'
}
model = Model(**model_raw)
if model.job_status != CreateModelStatus.Complete:
return {
'statusCode': 500,
'error': f'model {model.id} is in {model.job_status.value} state, not valid to be used for train'
}
base_key = f'{_type}/train/{model.name}/{request_id}'
input_location = f'{base_key}/input'
presign_url_map = get_s3_presign_urls(bucket_name=bucket_name, base_key=input_location, filenames=event.filenames)
checkpoint = CheckPoint(
id=request_id,
checkpoint_type=event.train_type,
s3_location=f's3://{bucket_name}/{base_key}/output',
checkpoint_status=CheckPointStatus.Initial,
timestamp=datetime.datetime.now().timestamp()
)
ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)
train_input_s3_location = f's3://{bucket_name}/{input_location}'
train_job = TrainJob(
id=request_id,
model_id=event.model_id,
job_status=TrainJobStatus.Initial,
params=event.params,
train_type=event.train_type,
input_s3_location=train_input_s3_location,
checkpoint_id=checkpoint.id,
timestamp=datetime.datetime.now().timestamp()
)
ddb_service.put_items(table=train_table, entries=train_job.__dict__)
return {
'statusCode': 200,
'job': {
'id': train_job.id,
'status': train_job.job_status.value,
'trainType': train_job.train_type,
'params': train_job.params,
'input_location': train_input_s3_location,
},
's3PresignUrl': presign_url_map
}
except Exception as e:
logger.error(e)
return {
'statusCode': 200,
'error': str(e)
}
# GET /trains
def list_all_train_jobs_api(event, context):
_filter = {}
if 'queryStringParameters' not in event:
return {
'statusCode': '500',
'error': 'query parameter status and types are needed'
}
parameters = event['queryStringParameters']
if 'types' in parameters and len(parameters['types']) > 0:
_filter['train_type'] = parameters['types']
if 'status' in parameters and len(parameters['status']) > 0:
_filter['job_status'] = parameters['status']
resp = ddb_service.scan(table=train_table, filters=_filter)
if resp is None or len(resp) == 0:
return {
'statusCode': 200,
'trainJobs': []
}
train_jobs = []
for tr in resp:
train_job = TrainJob(**(ddb_service.deserialize(tr)))
model_name = 'not_applied'
if 'training_params' in train_job.params and 'model_name' in train_job.params['training_params']:
model_name = train_job.params['training_params']['model_name']
train_jobs.append({
'id': train_job.id,
'modelName': model_name,
'status': train_job.job_status.value,
'trainType': train_job.train_type,
'created': train_job.timestamp,
'sagemakerTrainName': train_job.sagemaker_train_name,
})
return {
'statusCode': 200,
'trainJobs': train_jobs
}
# PUT /train used to kickoff a train job step function
def update_train_job_api(event, context):
if 'status' in event and 'train_job_id' in event and event['status'] == TrainJobStatus.Training.value:
return _start_train_job(event['train_job_id'])
return {
'statusCode': 200,
'msg': f'not implemented for train job status {event["status"]}'
}
def _start_train_job(train_job_id: str):
raw_train_job = ddb_service.get_item(table=train_table, key_values={
'id': train_job_id
})
if raw_train_job is None or len(raw_train_job) == 0:
return {
'statusCode': 500,
'error': f'no such train job with id({train_job_id})'
}
train_job = TrainJob(**raw_train_job)
model_raw = ddb_service.get_item(table=model_table, key_values={
'id': train_job.model_id
})
if model_raw is None:
return {
'statusCode': 500,
'error': f'model with id {train_job.model_id} is not found'
}
model = Model(**model_raw)
raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={
'id': train_job.checkpoint_id
})
if raw_checkpoint is None:
return {
'statusCode': 500,
'error': f'checkpoint with id {train_job.checkpoint_id} is not found'
}
checkpoint = CheckPoint(**raw_checkpoint)
try:
# JSON encode hyperparameters
def json_encode_hyperparameters(hyperparameters):
new_params = {}
for k, v in hyperparameters.items():
json_v = json.dumps(v, cls=DecimalEncoder)
v_bytes = json_v.encode('ascii')
base64_bytes = base64.b64encode(v_bytes)
base64_v = base64_bytes.decode('ascii')
new_params[k] = base64_v
return new_params
hyperparameters = json_encode_hyperparameters({
"sagemaker_program": "extensions/sd-webui-sagemaker/sagemaker_entrypoint_json.py",
"params": train_job.params,
"s3-input-path": train_job.input_s3_location,
"s3-output-path": checkpoint.s3_location,
})
final_instance_type = instance_type
if 'training_params' in train_job.params \
and 'training_instance_type' in train_job.params['training_params'] and \
train_job.params['training_params']['training_instance_type']:
final_instance_type = train_job.params['training_params']['training_instance_type']
est = sagemaker.estimator.Estimator(
image_uri,
sagemaker_role_arn,
instance_count=1,
instance_type=final_instance_type,
volume_size=125,
base_job_name=f'{model.name}',
hyperparameters=hyperparameters,
job_id=train_job.id,
)
est.fit(wait=False)
while not est._current_job_name:
time.sleep(1)
train_job.sagemaker_train_name = est._current_job_name
# trigger stepfunction
stepfunctions_client = StepFunctionUtilsService(logger=logger)
sfn_input = {
'train_job_id': train_job.id,
'train_job_name': train_job.sagemaker_train_name
}
sfn_arn = stepfunctions_client.invoke_step_function(training_stepfunction_arn, sfn_input)
# todo: use batch update, this is ugly!!!
search_key = {'id': train_job.id}
ddb_service.update_item(
table=train_table,
key=search_key,
field_name='sagemaker_train_name',
value=est._current_job_name
)
train_job.job_status = TrainJobStatus.Training
ddb_service.update_item(
table=train_table,
key=search_key,
field_name='job_status',
value=TrainJobStatus.Training.value
)
train_job.sagemaker_sfn_arn = sfn_arn
ddb_service.update_item(
table=train_table,
key=search_key,
field_name='sagemaker_sfn_arn',
value=sfn_arn
)
return {
'statusCode': 200,
'job': {
'id': train_job.id,
'status': train_job.job_status.value,
'created': train_job.timestamp,
'trainType': train_job.train_type,
'params': train_job.params,
'input_location': train_job.input_s3_location
},
}
except Exception as e:
print(e)
return {
'statusCode': 500,
'error': str(e)
}
# sfn
def check_train_job_status(event, context):
import boto3
boto3_sagemaker = boto3.client('sagemaker')
train_job_name = event['train_job_name']
train_job_id = event['train_job_id']
resp = boto3_sagemaker.describe_training_job(
TrainingJobName=train_job_name
)
training_job_status = resp['TrainingJobStatus']
event['status'] = training_job_status
raw_train_job = ddb_service.get_item(table=train_table, key_values={
'id': train_job_id,
})
if raw_train_job is None or len(raw_train_job) == 0:
event['status'] = 'Failed'
return {
'statusCode': 500,
'msg': f'no such training job find in ddb id[{train_job_id}]'
}
training_job = TrainJob(**raw_train_job)
if training_job_status == 'InProgress' or training_job_status == 'Stopping':
return event
if training_job_status == 'Failed' or training_job_status == 'Stopped':
training_job.job_status = TrainJobStatus.Fail
if 'FailureReason' in resp:
err_msg = resp['FailureReason']
training_job.params['resp'] = {
'status': 'Failed',
'error_msg': err_msg,
'raw_resp': resp
}
if training_job_status == 'Completed':
training_job.job_status = TrainJobStatus.Complete
# todo: update checkpoints
raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={
'id': training_job.checkpoint_id
})
if raw_checkpoint is None or len(raw_checkpoint) == 0:
# todo: or create new one
return 'failed because no checkpoint, not normal'
checkpoint = CheckPoint(**raw_checkpoint)
checkpoint.checkpoint_status = CheckPointStatus.Active
s3 = boto3.client('s3')
bucket, key = split_s3_path(checkpoint.s3_location)
s3_resp = s3.list_objects(
Bucket=bucket,
Prefix=key,
)
checkpoint.checkpoint_names = []
if 'Contents' in s3_resp and len(s3_resp['Contents']) > 0:
for obj in s3_resp['Contents']:
checkpoint_name = obj['Key'].replace(f'{key}/', "")
checkpoint.checkpoint_names.append(checkpoint_name)
else:
training_job.job_status = TrainJobStatus.Fail
checkpoint.checkpoint_status = CheckPointStatus.Initial
ddb_service.update_item(
table=checkpoint_table,
key={
'id': checkpoint.id
},
field_name='checkpoint_status',
value=checkpoint.checkpoint_status.value
)
ddb_service.update_item(
table=checkpoint_table,
key={
'id': checkpoint.id
},
field_name='checkpoint_names',
value=checkpoint.checkpoint_names
)
training_job.params['resp'] = {
'raw_resp': resp
}
# fixme: this is ugly
ddb_service.update_item(
table=train_table,
key={'id': training_job.id},
field_name='job_status',
value=training_job.job_status.value
)
ddb_service.update_item(
table=train_table,
key={'id': training_job.id},
field_name='params',
value=training_job.params
)
return event
# sfn
def process_train_job_result(event, context):
train_job_id = event['train_job_id']
raw_train_job = ddb_service.get_item(table=train_table, key_values={
'id': train_job_id,
})
if raw_train_job is None or len(raw_train_job) == 0:
return {
'statusCode': 500,
'msg': f'no such training job find in ddb id[{train_job_id}]'
}
train_job = TrainJob(**raw_train_job)
publish_msg(
topic_arn=user_topic_arn,
subject=f'Create Model Job {train_job.sagemaker_train_name} {train_job.job_status}',
msg=f'to be done with resp: \n {train_job.job_status}'
) # todo: find out msg
return 'job completed'