stable-diffusion-aws-extension/middleware_api/lambda/model_and_train/train_api.py


import datetime
import json
import logging
import os
import base64
import time
from dataclasses import dataclass
import boto3
import tarfile
from typing import Any, List, Optional
import sagemaker
from common.ddb_service.client import DynamoDbUtilsService
from common.response import ok, bad_request
from common.stepfunction_service.client import StepFunctionUtilsService
from common.util import load_json_from_s3, publish_msg, save_json_to_file
from common_tools import split_s3_path, DecimalEncoder
from common.util import get_s3_presign_urls
from _types import TrainJob, TrainJobStatus, Model, CreateModelStatus, CheckPoint, CheckPointStatus
from multi_users.utils import get_permissions_by_username, get_user_roles, check_user_permissions

bucket_name = os.environ.get('S3_BUCKET')
train_table = os.environ.get('TRAIN_TABLE')
model_table = os.environ.get('MODEL_TABLE')
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')
user_table = os.environ.get('MULTI_USER_TABLE')
instance_type = os.environ.get('INSTANCE_TYPE')
sagemaker_role_arn = os.environ.get('TRAIN_JOB_ROLE')
# e.g. "648149843064.dkr.ecr.us-east-1.amazonaws.com/dreambooth-training-repo"
image_uri = os.environ.get('TRAIN_ECR_URL')
training_stepfunction_arn = os.environ.get('TRAINING_SAGEMAKER_ARN')
user_topic_arn = os.environ.get('USER_EMAIL_TOPIC_ARN')

logger = logging.getLogger('boto3')
ddb_service = DynamoDbUtilsService(logger=logger)
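
# The handlers below assume the Lambda environment provides the variables read
# above. A minimal sketch of what a local test setup might export (all values
# here are hypothetical, not taken from the deployment):
#
#   export S3_BUCKET=my-sd-extension-bucket
#   export TRAIN_TABLE=TrainingJobTable
#   export INSTANCE_TYPE=ml.g4dn.2xlarge
#   export TRAIN_JOB_ROLE=arn:aws:iam::123456789012:role/SagemakerTrainRole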


@dataclass
class Event:
    train_type: str
    model_id: str
    params: dict[str, Any]
    creator: str
    filenames: Optional[List[str]] = None
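
# A minimal sketch of a POST /train request body that maps onto Event
# (field values are hypothetical, for illustration only):
#
#   {
#       "train_type": "dreambooth",
#       "model_id": "c7210d0f-...",
#       "creator": "admin",
#       "params": {
#           "training_params": {"training_instance_type": "ml.g4dn.2xlarge"},
#           "config_params": {"max_train_steps": 1000}
#       },
#       "filenames": ["images.tar"]
#   }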


# POST /train
def create_train_job_api(raw_event, context):
    request_id = context.aws_request_id
    event = Event(**raw_event)
    _type = event.train_type

    try:
        creator_permissions = get_permissions_by_username(ddb_service, user_table, event.creator)
        if 'train' not in creator_permissions \
                or ('all' not in creator_permissions['train'] and 'create' not in creator_permissions['train']):
            return {
                'statusCode': 400,
                'errMsg': f'user {event.creator} has no permission to create a train job'
            }

        model_raw = ddb_service.get_item(table=model_table, key_values={
            'id': event.model_id
        })
        # if the model is not found, get_item returns {}
        if model_raw == {}:
            return {
                'statusCode': 500,
                'error': f'model with id {event.model_id} is not found'
            }

        model = Model(**model_raw)
        if model.job_status != CreateModelStatus.Complete:
            return {
                'statusCode': 500,
                'error': f'model {model.id} is in {model.job_status.value} state, not valid to be used for train'
            }

        base_key = f'{_type}/train/{model.name}/{request_id}'
        input_location = f'{base_key}/input'
        presign_url_map = None
        if event.filenames is None:
            # Invoked from the API with no config file defined in the parameters:
            # merge the user's params into the default template and upload it as a tar.
            json_file_name = 'db_config_cloud.json'
            tar_file_name = 'db_config.tar'
            tar_file_content = f'/tmp/models/sagemaker_dreambooth/{model.name}'
            tar_file_path = f'/tmp/{tar_file_name}'
            db_config_json = load_json_from_s3(bucket_name, 'template/' + json_file_name)
            # Merge user parameters; if no config_params is defined, keep the defaults from S3
            if 'config_params' in event.params:
                db_config_json.update(event.params['config_params'])

            # Add model parameters into the train params
            event.params['training_params']['model_name'] = model.name
            event.params['training_params']['model_type'] = model.model_type
            event.params['training_params']['s3_model_path'] = model.output_s3_location

            # Upload the merged JSON to the S3 bucket as a tar file
            try:
                if not os.path.exists(tar_file_content):
                    os.makedirs(tar_file_content)

                saved_path = save_json_to_file(db_config_json, tar_file_content, json_file_name)
                logger.info(f'file saved to {saved_path}')
                with tarfile.open(tar_file_path, 'w') as tar:
                    # Add the config directory to the tar file without including /tmp itself
                    tar.add(tar_file_content, arcname=f'models/sagemaker_dreambooth/{model.name}')

                s3 = boto3.client('s3')
                s3.upload_file(tar_file_path, bucket_name, os.path.join(input_location, tar_file_name))
                logger.info(f"tar file '{tar_file_name}' uploaded to '{bucket_name}' successfully")
            except Exception as e:
                raise RuntimeError(f'Error uploading JSON file to S3: {e}')
        else:
            presign_url_map = get_s3_presign_urls(bucket_name=bucket_name,
                                                  base_key=input_location,
                                                  filenames=event.filenames)

        user_roles = get_user_roles(ddb_service, user_table, event.creator)
        checkpoint = CheckPoint(
            id=request_id,
            checkpoint_type=event.train_type,
            s3_location=f's3://{bucket_name}/{base_key}/output',
            checkpoint_status=CheckPointStatus.Initial,
            timestamp=datetime.datetime.now().timestamp(),
            allowed_roles_or_users=user_roles
        )
        ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)

        train_input_s3_location = f's3://{bucket_name}/{input_location}'
        train_job = TrainJob(
            id=request_id,
            model_id=event.model_id,
            job_status=TrainJobStatus.Initial,
            params=event.params,
            train_type=event.train_type,
            input_s3_location=train_input_s3_location,
            checkpoint_id=checkpoint.id,
            timestamp=datetime.datetime.now().timestamp(),
            allowed_roles_or_users=[event.creator]
        )
        ddb_service.put_items(table=train_table, entries=train_job.__dict__)

        return {
            'statusCode': 200,
            'job': {
                'id': train_job.id,
                'status': train_job.job_status.value,
                'trainType': train_job.train_type,
                'params': train_job.params,
                'input_location': train_input_s3_location,
            },
            's3PresignUrl': presign_url_map
        }
    except Exception as e:
        logger.error(e)
        # surface the failure instead of masking it with a 200
        return {
            'statusCode': 500,
            'error': str(e)
        }
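
# When filenames are supplied, the caller is expected to upload each file to its
# presigned URL before kicking off the job. A minimal client-side sketch, assuming
# get_s3_presign_urls returns a filename-to-URL map and the `requests` library is
# available (neither is guaranteed by this module):
#
#   import requests
#   for name, url in resp['s3PresignUrl'].items():
#       with open(name, 'rb') as f:
#           requests.put(url, data=f)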


# GET /trains
def list_all_train_jobs_api(event, context):
    _filter = {}
    parameters = event['queryStringParameters']
    if parameters:
        if 'types' in parameters and len(parameters['types']) > 0:
            _filter['train_type'] = parameters['types']

        if 'status' in parameters and len(parameters['status']) > 0:
            _filter['job_status'] = parameters['status']

    resp = ddb_service.scan(table=train_table, filters=_filter)
    if resp is None or len(resp) == 0:
        return ok(data={'trainJobs': []})

    requestor_name = event['requestContext']['authorizer']['username']
    try:
        requestor_permissions = get_permissions_by_username(ddb_service, user_table, requestor_name)
        requestor_roles = get_user_roles(ddb_service=ddb_service, user_table_name=user_table, username=requestor_name)
        if 'train' not in requestor_permissions or \
                ('all' not in requestor_permissions['train'] and 'list' not in requestor_permissions['train']):
            return bad_request(message='user has no permission to list train jobs')

        train_jobs = []
        for tr in resp:
            train_job = TrainJob(**(ddb_service.deserialize(tr)))
            model_name = 'not_applied'
            if 'training_params' in train_job.params and 'model_name' in train_job.params['training_params']:
                model_name = train_job.params['training_params']['model_name']

            train_job_dto = {
                'id': train_job.id,
                'modelName': model_name,
                'status': train_job.job_status.value,
                'trainType': train_job.train_type,
                'created': train_job.timestamp,
                'sagemakerTrainName': train_job.sagemaker_train_name,
            }
            if train_job.allowed_roles_or_users \
                    and check_user_permissions(train_job.allowed_roles_or_users, requestor_roles, requestor_name):
                train_jobs.append(train_job_dto)
            elif not train_job.allowed_roles_or_users and \
                    'user' in requestor_permissions and \
                    'all' in requestor_permissions['user']:
                # superuser can view the legacy data
                train_jobs.append(train_job_dto)

        return ok(data={'trainJobs': train_jobs}, decimal=True)
    except Exception as e:
        logger.error(e)
        return bad_request(message=str(e))


# PUT /train used to kick off a train job step function
def update_train_job_api(event, context):
    # guard with .get() so a payload without 'status' does not raise a KeyError
    if event.get('status') == TrainJobStatus.Training.value and 'train_job_id' in event:
        return _start_train_job(event['train_job_id'])

    return {
        'statusCode': 200,
        'msg': f'not implemented for train job status {event.get("status")}'
    }
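
# A minimal sketch of the PUT /train payload this handler accepts, assuming
# TrainJobStatus.Training.value == 'Training' (the id value is hypothetical):
#
#   {"train_job_id": "c7210d0f-...", "status": "Training"}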


def _start_train_job(train_job_id: str):
    raw_train_job = ddb_service.get_item(table=train_table, key_values={
        'id': train_job_id
    })
    if raw_train_job is None or len(raw_train_job) == 0:
        return {
            'statusCode': 500,
            'error': f'no such train job with id({train_job_id})'
        }

    train_job = TrainJob(**raw_train_job)

    model_raw = ddb_service.get_item(table=model_table, key_values={
        'id': train_job.model_id
    })
    if model_raw is None:
        return {
            'statusCode': 500,
            'error': f'model with id {train_job.model_id} is not found'
        }

    model = Model(**model_raw)

    raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={
        'id': train_job.checkpoint_id
    })
    if raw_checkpoint is None:
        return {
            'statusCode': 500,
            'error': f'checkpoint with id {train_job.checkpoint_id} is not found'
        }

    checkpoint = CheckPoint(**raw_checkpoint)

    try:
        # the training container expects each hyperparameter value as a
        # base64-encoded JSON string, so encode them before submitting the job
        def json_encode_hyperparameters(hyperparameters):
            new_params = {}
            for k, v in hyperparameters.items():
                json_v = json.dumps(v, cls=DecimalEncoder)
                v_bytes = json_v.encode('ascii')
                base64_bytes = base64.b64encode(v_bytes)
                base64_v = base64_bytes.decode('ascii')
                new_params[k] = base64_v
            return new_params

        hyperparameters = json_encode_hyperparameters({
            "sagemaker_program": "extensions/sd-webui-sagemaker/sagemaker_entrypoint_json.py",
            "params": train_job.params,
            "s3-input-path": train_job.input_s3_location,
            "s3-output-path": checkpoint.s3_location,
        })

        final_instance_type = instance_type
        if 'training_params' in train_job.params \
                and 'training_instance_type' in train_job.params['training_params'] and \
                train_job.params['training_params']['training_instance_type']:
            final_instance_type = train_job.params['training_params']['training_instance_type']

        est = sagemaker.estimator.Estimator(
            image_uri,
            sagemaker_role_arn,
            instance_count=1,
            instance_type=final_instance_type,
            volume_size=125,
            base_job_name=f'{model.name}',
            hyperparameters=hyperparameters,
            job_id=train_job.id,
        )
        est.fit(wait=False)

        # _current_job_name is populated once fit() has submitted the job
        while not est._current_job_name:
            time.sleep(1)

        train_job.sagemaker_train_name = est._current_job_name
        # trigger the step function that polls the training job
        stepfunctions_client = StepFunctionUtilsService(logger=logger)
        sfn_input = {
            'train_job_id': train_job.id,
            'train_job_name': train_job.sagemaker_train_name
        }
        sfn_arn = stepfunctions_client.invoke_step_function(training_stepfunction_arn, sfn_input)

        # todo: use batch update, this is ugly!!!
        search_key = {'id': train_job.id}
        ddb_service.update_item(
            table=train_table,
            key=search_key,
            field_name='sagemaker_train_name',
            value=est._current_job_name
        )

        train_job.job_status = TrainJobStatus.Training
        ddb_service.update_item(
            table=train_table,
            key=search_key,
            field_name='job_status',
            value=TrainJobStatus.Training.value
        )

        train_job.sagemaker_sfn_arn = sfn_arn
        ddb_service.update_item(
            table=train_table,
            key=search_key,
            field_name='sagemaker_sfn_arn',
            value=sfn_arn
        )
        return {
            'statusCode': 200,
            'job': {
                'id': train_job.id,
                'status': train_job.job_status.value,
                'created': train_job.timestamp,
                'trainType': train_job.train_type,
                'params': train_job.params,
                'input_location': train_job.input_s3_location
            },
        }
    except Exception as e:
        logger.error(e)
        return {
            'statusCode': 500,
            'error': str(e)
        }
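
# For illustration: the inverse of json_encode_hyperparameters, roughly as the
# training container would apply it. This helper is a sketch only and is not
# called by the handlers in this module; the real decoding lives in the image.
def _decode_hyperparameter(value: str) -> Any:
    # base64 -> JSON string -> Python object, mirroring the encoding above
    return json.loads(base64.b64decode(value.encode('ascii')).decode('ascii'))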


# sfn
def check_train_job_status(event, context):
    boto3_sagemaker = boto3.client('sagemaker')
    train_job_name = event['train_job_name']
    train_job_id = event['train_job_id']

    resp = boto3_sagemaker.describe_training_job(
        TrainingJobName=train_job_name
    )
    training_job_status = resp['TrainingJobStatus']
    event['status'] = training_job_status

    raw_train_job = ddb_service.get_item(table=train_table, key_values={
        'id': train_job_id,
    })
    if raw_train_job is None or len(raw_train_job) == 0:
        event['status'] = 'Failed'
        return {
            'statusCode': 500,
            'msg': f'no such training job found in ddb with id[{train_job_id}]'
        }

    training_job = TrainJob(**raw_train_job)

    if training_job_status == 'InProgress' or training_job_status == 'Stopping':
        return event

    if training_job_status == 'Failed' or training_job_status == 'Stopped':
        training_job.job_status = TrainJobStatus.Fail
        if 'FailureReason' in resp:
            err_msg = resp['FailureReason']
            training_job.params['resp'] = {
                'status': 'Failed',
                'error_msg': err_msg,
                'raw_resp': resp
            }

    if training_job_status == 'Completed':
        training_job.job_status = TrainJobStatus.Complete
        # todo: update checkpoints
        raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={
            'id': training_job.checkpoint_id
        })
        if raw_checkpoint is None or len(raw_checkpoint) == 0:
            # todo: or create a new one
            return 'failed because no checkpoint was found, which is not normal'

        checkpoint = CheckPoint(**raw_checkpoint)
        checkpoint.checkpoint_status = CheckPointStatus.Active
        s3 = boto3.client('s3')
        bucket, key = split_s3_path(checkpoint.s3_location)
        s3_resp = s3.list_objects(
            Bucket=bucket,
            Prefix=key,
        )
        checkpoint.checkpoint_names = []
        if 'Contents' in s3_resp and len(s3_resp['Contents']) > 0:
            for obj in s3_resp['Contents']:
                checkpoint_name = obj['Key'].replace(f'{key}/', "")
                checkpoint.checkpoint_names.append(checkpoint_name)
        else:
            # the job finished but produced no artifacts: mark it failed
            training_job.job_status = TrainJobStatus.Fail
            checkpoint.checkpoint_status = CheckPointStatus.Initial

        ddb_service.update_item(
            table=checkpoint_table,
            key={
                'id': checkpoint.id
            },
            field_name='checkpoint_status',
            value=checkpoint.checkpoint_status.value
        )
        ddb_service.update_item(
            table=checkpoint_table,
            key={
                'id': checkpoint.id
            },
            field_name='checkpoint_names',
            value=checkpoint.checkpoint_names
        )
        training_job.params['resp'] = {
            'raw_resp': resp
        }

    # fixme: this is ugly
    ddb_service.update_item(
        table=train_table,
        key={'id': training_job.id},
        field_name='job_status',
        value=training_job.job_status.value
    )
    ddb_service.update_item(
        table=train_table,
        key={'id': training_job.id},
        field_name='params',
        value=training_job.params
    )
    return event
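
# The fields consumed from describe_training_job above are a small subset of the
# boto3 response; a trimmed sketch of its shape for reference:
#
#   {
#       "TrainingJobName": "...",
#       "TrainingJobStatus": "InProgress" | "Completed" | "Failed" | "Stopping" | "Stopped",
#       "FailureReason": "..."   # present only when the job failed
#   }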


# sfn
def process_train_job_result(event, context):
    train_job_id = event['train_job_id']

    raw_train_job = ddb_service.get_item(table=train_table, key_values={
        'id': train_job_id,
    })
    if raw_train_job is None or len(raw_train_job) == 0:
        return {
            'statusCode': 500,
            'msg': f'no such training job found in ddb with id[{train_job_id}]'
        }

    train_job = TrainJob(**raw_train_job)

    publish_msg(
        topic_arn=user_topic_arn,
        subject=f'Train Job {train_job.sagemaker_train_name} {train_job.job_status}',
        msg=f'to be done with resp: \n {train_job.job_status}'
    )  # todo: find out msg

    return 'job completed'