import base64
import datetime
import json
import logging
import os
import tarfile
import time
from dataclasses import dataclass
from typing import Any, List, Optional

import boto3
import sagemaker

from _types import TrainJob, TrainJobStatus, Model, CreateModelStatus, CheckPoint, CheckPointStatus
from common.ddb_service.client import DynamoDbUtilsService
from common.response import ok, bad_request
from common.stepfunction_service.client import StepFunctionUtilsService
from common.util import get_s3_presign_urls, load_json_from_s3, publish_msg, save_json_to_file
from common_tools import split_s3_path, DecimalEncoder
from multi_users.utils import get_permissions_by_username, get_user_roles, check_user_permissions

bucket_name = os.environ.get('S3_BUCKET')
train_table = os.environ.get('TRAIN_TABLE')
model_table = os.environ.get('MODEL_TABLE')
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')
user_table = os.environ.get('MULTI_USER_TABLE')
instance_type = os.environ.get('INSTANCE_TYPE')
sagemaker_role_arn = os.environ.get('TRAIN_JOB_ROLE')
image_uri = os.environ.get('TRAIN_ECR_URL')  # e.g. "648149843064.dkr.ecr.us-east-1.amazonaws.com/dreambooth-training-repo"
training_stepfunction_arn = os.environ.get('TRAINING_SAGEMAKER_ARN')
user_topic_arn = os.environ.get('USER_EMAIL_TOPIC_ARN')

logger = logging.getLogger('boto3')
ddb_service = DynamoDbUtilsService(logger=logger)


@dataclass
class Event:
    train_type: str
    model_id: str
    params: dict[str, Any]
    creator: str
    filenames: Optional[List[str]] = None
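
# Illustrative request body for POST /train (values are hypothetical, not taken from project docs):
# {
#     "train_type": "dreambooth",
#     "model_id": "<model uuid>",
#     "creator": "<username>",
#     "filenames": ["data.tar"],          # optional; when omitted, a db_config tar is built from the S3 template
#     "params": {
#         "training_params": {"training_instance_type": "ml.g5.2xlarge"},
#         "config_params": {}             # optional overrides merged into template/db_config_cloud.json
#     }
# }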


# POST /train
def create_train_job_api(raw_event, context):
    """Create a train job record and its checkpoint, and prepare the S3 input location for training."""
    request_id = context.aws_request_id
    event = Event(**raw_event)
    _type = event.train_type

    try:
        creator_permissions = get_permissions_by_username(ddb_service, user_table, event.creator)
        if 'train' not in creator_permissions \
                or ('all' not in creator_permissions['train'] and 'create' not in creator_permissions['train']):
            return {
                'statusCode': 400,
                'errMsg': f'user {event.creator} has no permission to create a train job'
            }

        model_raw = ddb_service.get_item(table=model_table, key_values={
            'id': event.model_id
        })
        # if model is not found, model_raw is {}
        if model_raw == {}:
            return {
                'statusCode': 500,
                'error': f'model with id {event.model_id} is not found'
            }

        model = Model(**model_raw)
        if model.job_status != CreateModelStatus.Complete:
            return {
                'statusCode': 500,
                'error': f'model {model.id} is in {model.job_status.value} state, not valid to be used for train'
            }

        base_key = f'{_type}/train/{model.name}/{request_id}'
        input_location = f'{base_key}/input'
        presign_url_map = None
        if event.filenames is None:
            # Invoked from the API, no config file is defined in the parameters
            json_file_name = 'db_config_cloud.json'
            tar_file_name = 'db_config.tar'
            tar_file_content = f'/tmp/models/sagemaker_dreambooth/{model.name}'
            tar_file_path = f'/tmp/{tar_file_name}'

            db_config_json = load_json_from_s3(bucket_name, 'template/' + json_file_name)
            # Merge user parameters; if no config_params is defined, use the default values from the S3 template
            if "config_params" in event.params:
                db_config_json.update(event.params["config_params"])

            # Add model parameters into the train params
            event.params["training_params"]["model_name"] = model.name
            event.params["training_params"]["model_type"] = model.model_type
            event.params["training_params"]["s3_model_path"] = model.output_s3_location

            # Upload the merged JSON to the S3 bucket as a tar file
            try:
                if not os.path.exists(tar_file_content):
                    os.makedirs(tar_file_content)
                saved_path = save_json_to_file(db_config_json, tar_file_content, json_file_name)
                print(f'file saved to {saved_path}')
                with tarfile.open(tar_file_path, 'w') as tar:
                    # Add the contents of the 'models' directory to the tar file without including /tmp itself
                    tar.add(tar_file_content, arcname=f'models/sagemaker_dreambooth/{model.name}')

                s3 = boto3.client('s3')
                s3.upload_file(tar_file_path, bucket_name, os.path.join(input_location, tar_file_name))
                logger.info(f"Tar file '{tar_file_name}' uploaded to '{bucket_name}' successfully.")
            except Exception as e:
                raise RuntimeError(f"Error uploading JSON file to S3: {e}")
        else:
            presign_url_map = get_s3_presign_urls(bucket_name=bucket_name, base_key=input_location, filenames=event.filenames)

        user_roles = get_user_roles(ddb_service, user_table, event.creator)
        checkpoint = CheckPoint(
            id=request_id,
            checkpoint_type=event.train_type,
            s3_location=f's3://{bucket_name}/{base_key}/output',
            checkpoint_status=CheckPointStatus.Initial,
            timestamp=datetime.datetime.now().timestamp(),
            allowed_roles_or_users=user_roles
        )
        ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)
        train_input_s3_location = f's3://{bucket_name}/{input_location}'

        train_job = TrainJob(
            id=request_id,
            model_id=event.model_id,
            job_status=TrainJobStatus.Initial,
            params=event.params,
            train_type=event.train_type,
            input_s3_location=train_input_s3_location,
            checkpoint_id=checkpoint.id,
            timestamp=datetime.datetime.now().timestamp(),
            allowed_roles_or_users=[event.creator]
        )
        ddb_service.put_items(table=train_table, entries=train_job.__dict__)

        return {
            'statusCode': 200,
            'job': {
                'id': train_job.id,
                'status': train_job.job_status.value,
                'trainType': train_job.train_type,
                'params': train_job.params,
                'input_location': train_input_s3_location,
            },
            's3PresignUrl': presign_url_map
        }
    except Exception as e:
        logger.error(e)
        return {
            'statusCode': 500,
            'error': str(e)
        }
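
# Resulting S3 layout for a created job (request_id is the Lambda request id):
#   s3://{S3_BUCKET}/{train_type}/train/{model_name}/{request_id}/input   <- db_config.tar or presigned uploads
#   s3://{S3_BUCKET}/{train_type}/train/{model_name}/{request_id}/output  <- checkpoint.s3_location filled by training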


# GET /trains
def list_all_train_jobs_api(event, context):
    """List train jobs visible to the requestor, optionally filtered by train type and job status."""
    _filter = {}

    parameters = event['queryStringParameters']
    if parameters:
        if 'types' in parameters and len(parameters['types']) > 0:
            _filter['train_type'] = parameters['types']

        if 'status' in parameters and len(parameters['status']) > 0:
            _filter['job_status'] = parameters['status']

    resp = ddb_service.scan(table=train_table, filters=_filter)
    if resp is None or len(resp) == 0:
        return ok(data={'trainJobs': []})

    requestor_name = event['requestContext']['authorizer']['username']
    try:
        requestor_permissions = get_permissions_by_username(ddb_service, user_table, requestor_name)
        requestor_roles = get_user_roles(ddb_service=ddb_service, user_table_name=user_table, username=requestor_name)
        if 'train' not in requestor_permissions or \
                ('all' not in requestor_permissions['train'] and 'list' not in requestor_permissions['train']):
            return bad_request(message='user has no permission to train')

        train_jobs = []
        for tr in resp:
            train_job = TrainJob(**(ddb_service.deserialize(tr)))
            model_name = 'not_applied'
            if 'training_params' in train_job.params and 'model_name' in train_job.params['training_params']:
                model_name = train_job.params['training_params']['model_name']

            train_job_dto = {
                'id': train_job.id,
                'modelName': model_name,
                'status': train_job.job_status.value,
                'trainType': train_job.train_type,
                'created': train_job.timestamp,
                'sagemakerTrainName': train_job.sagemaker_train_name,
            }
            if train_job.allowed_roles_or_users and check_user_permissions(train_job.allowed_roles_or_users, requestor_roles, requestor_name):
                train_jobs.append(train_job_dto)
            elif not train_job.allowed_roles_or_users and \
                    'user' in requestor_permissions and \
                    'all' in requestor_permissions['user']:
                # superuser can view the legacy data
                train_jobs.append(train_job_dto)

        return ok(data={'trainJobs': train_jobs}, decimal=True)
    except Exception as e:
        logger.error(e)
        return bad_request(message=str(e))
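
# Example (hypothetical request): GET /trains?types=dreambooth&status=Training arrives as
# event['queryStringParameters'] = {'types': 'dreambooth', 'status': 'Training'} and becomes
# the DynamoDB scan filter {'train_type': 'dreambooth', 'job_status': 'Training'}.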


# PUT /train, used to kick off a train job step function
def update_train_job_api(event, context):
    """Start training for a job when the requested status is Training."""
    if 'status' in event and 'train_job_id' in event and event['status'] == TrainJobStatus.Training.value:
        return _start_train_job(event['train_job_id'])

    return {
        'statusCode': 200,
        'msg': f'not implemented for train job status {event.get("status")}'
    }
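
# Example (hypothetical id): invoking PUT /train with {'train_job_id': 'abc-123', 'status': 'Training'}
# launches the SageMaker training job and the monitoring step function via _start_train_job.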


def _start_train_job(train_job_id: str):
    """Start the SageMaker training job for the given train job and trigger the monitoring step function."""
    raw_train_job = ddb_service.get_item(table=train_table, key_values={
        'id': train_job_id
    })
    if raw_train_job is None or len(raw_train_job) == 0:
        return {
            'statusCode': 500,
            'error': f'no such train job with id({train_job_id})'
        }

    train_job = TrainJob(**raw_train_job)

    model_raw = ddb_service.get_item(table=model_table, key_values={
        'id': train_job.model_id
    })
    if model_raw is None or len(model_raw) == 0:
        return {
            'statusCode': 500,
            'error': f'model with id {train_job.model_id} is not found'
        }

    model = Model(**model_raw)

    raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={
        'id': train_job.checkpoint_id
    })
    if raw_checkpoint is None or len(raw_checkpoint) == 0:
        return {
            'statusCode': 500,
            'error': f'checkpoint with id {train_job.checkpoint_id} is not found'
        }

    checkpoint = CheckPoint(**raw_checkpoint)

    try:
        # JSON encode hyperparameters: each value is JSON serialized, then base64 encoded
        def json_encode_hyperparameters(hyperparameters):
            new_params = {}
            for k, v in hyperparameters.items():
                json_v = json.dumps(v, cls=DecimalEncoder)
                v_bytes = json_v.encode('ascii')
                base64_bytes = base64.b64encode(v_bytes)
                base64_v = base64_bytes.decode('ascii')
                new_params[k] = base64_v
            return new_params

        hyperparameters = json_encode_hyperparameters({
            "sagemaker_program": "extensions/sd-webui-sagemaker/sagemaker_entrypoint_json.py",
            "params": train_job.params,
            "s3-input-path": train_job.input_s3_location,
            "s3-output-path": checkpoint.s3_location,
        })

        final_instance_type = instance_type
        if 'training_params' in train_job.params \
                and 'training_instance_type' in train_job.params['training_params'] and \
                train_job.params['training_params']['training_instance_type']:
            final_instance_type = train_job.params['training_params']['training_instance_type']

        est = sagemaker.estimator.Estimator(
            image_uri,
            sagemaker_role_arn,
            instance_count=1,
            instance_type=final_instance_type,
            volume_size=125,
            base_job_name=f'{model.name}',
            hyperparameters=hyperparameters,
            job_id=train_job.id,
        )
        est.fit(wait=False)

        while not est._current_job_name:
            time.sleep(1)

        train_job.sagemaker_train_name = est._current_job_name
        # trigger the step function
        stepfunctions_client = StepFunctionUtilsService(logger=logger)
        sfn_input = {
            'train_job_id': train_job.id,
            'train_job_name': train_job.sagemaker_train_name
        }
        sfn_arn = stepfunctions_client.invoke_step_function(training_stepfunction_arn, sfn_input)
        # todo: use batch update, this is ugly!!!
        search_key = {'id': train_job.id}
        ddb_service.update_item(
            table=train_table,
            key=search_key,
            field_name='sagemaker_train_name',
            value=est._current_job_name
        )
        train_job.job_status = TrainJobStatus.Training
        ddb_service.update_item(
            table=train_table,
            key=search_key,
            field_name='job_status',
            value=TrainJobStatus.Training.value
        )
        train_job.sagemaker_sfn_arn = sfn_arn
        ddb_service.update_item(
            table=train_table,
            key=search_key,
            field_name='sagemaker_sfn_arn',
            value=sfn_arn
        )

        return {
            'statusCode': 200,
            'job': {
                'id': train_job.id,
                'status': train_job.job_status.value,
                'created': train_job.timestamp,
                'trainType': train_job.train_type,
                'params': train_job.params,
                'input_location': train_job.input_s3_location
            },
        }
    except Exception as e:
        logger.error(e)
        return {
            'statusCode': 500,
            'error': str(e)
        }
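
# The training container is expected to reverse json_encode_hyperparameters; a minimal sketch,
# illustrative only (the actual entrypoint lives outside this module):
#
#     decoded = {k: json.loads(base64.b64decode(v)) for k, v in hyperparameters.items()}
#
# where `hyperparameters` is the dict SageMaker hands to the training job.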


# sfn
def check_train_job_status(event, context):
    """Step function task: poll the SageMaker training job status and update the train job and checkpoint records."""
    boto3_sagemaker = boto3.client('sagemaker')
    train_job_name = event['train_job_name']
    train_job_id = event['train_job_id']

    resp = boto3_sagemaker.describe_training_job(
        TrainingJobName=train_job_name
    )

    training_job_status = resp['TrainingJobStatus']
    event['status'] = training_job_status

    raw_train_job = ddb_service.get_item(table=train_table, key_values={
        'id': train_job_id,
    })

    if raw_train_job is None or len(raw_train_job) == 0:
        event['status'] = 'Failed'
        return {
            'statusCode': 500,
            'msg': f'no such training job found in ddb id[{train_job_id}]'
        }

    training_job = TrainJob(**raw_train_job)
    if training_job_status == 'InProgress' or training_job_status == 'Stopping':
        return event

    if training_job_status == 'Failed' or training_job_status == 'Stopped':
        training_job.job_status = TrainJobStatus.Fail
        if 'FailureReason' in resp:
            err_msg = resp['FailureReason']
            training_job.params['resp'] = {
                'status': 'Failed',
                'error_msg': err_msg,
                'raw_resp': resp
            }

    if training_job_status == 'Completed':
        training_job.job_status = TrainJobStatus.Complete
        # todo: update checkpoints
        raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={
            'id': training_job.checkpoint_id
        })
        if raw_checkpoint is None or len(raw_checkpoint) == 0:
            # todo: or create a new one
            return 'failed because no checkpoint, not normal'

        checkpoint = CheckPoint(**raw_checkpoint)
        checkpoint.checkpoint_status = CheckPointStatus.Active
        s3 = boto3.client('s3')
        bucket, key = split_s3_path(checkpoint.s3_location)
        s3_resp = s3.list_objects(
            Bucket=bucket,
            Prefix=key,
        )
        checkpoint.checkpoint_names = []
        if 'Contents' in s3_resp and len(s3_resp['Contents']) > 0:
            for obj in s3_resp['Contents']:
                checkpoint_name = obj['Key'].replace(f'{key}/', "")
                checkpoint.checkpoint_names.append(checkpoint_name)
        else:
            training_job.job_status = TrainJobStatus.Fail
            checkpoint.checkpoint_status = CheckPointStatus.Initial

        ddb_service.update_item(
            table=checkpoint_table,
            key={
                'id': checkpoint.id
            },
            field_name='checkpoint_status',
            value=checkpoint.checkpoint_status.value
        )
        ddb_service.update_item(
            table=checkpoint_table,
            key={
                'id': checkpoint.id
            },
            field_name='checkpoint_names',
            value=checkpoint.checkpoint_names
        )

        training_job.params['resp'] = {
            'raw_resp': resp
        }

    # fixme: this is ugly
    ddb_service.update_item(
        table=train_table,
        key={'id': training_job.id},
        field_name='job_status',
        value=training_job.job_status.value
    )

    ddb_service.update_item(
        table=train_table,
        key={'id': training_job.id},
        field_name='params',
        value=training_job.params
    )

    return event
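
# The event flowing through this state carries at least
#   {'train_job_id': ..., 'train_job_name': ..., 'status': <SageMaker TrainingJobStatus>};
# 'status' is refreshed here so the state machine (defined outside this module) can decide whether
# to keep polling (InProgress/Stopping) or proceed to process_train_job_result.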


# sfn
def process_train_job_result(event, context):
    """Step function task: publish the final train job status to the user SNS topic."""
    train_job_id = event['train_job_id']

    raw_train_job = ddb_service.get_item(table=train_table, key_values={
        'id': train_job_id,
    })

    if raw_train_job is None or len(raw_train_job) == 0:
        return {
            'statusCode': 500,
            'msg': f'no such training job found in ddb id[{train_job_id}]'
        }

    train_job = TrainJob(**raw_train_job)

    publish_msg(
        topic_arn=user_topic_arn,
        subject=f'Create Model Job {train_job.sagemaker_train_name} {train_job.job_status}',
        msg=f'to be done with resp: \n {train_job.job_status}'
    )  # todo: find out msg

    return 'job completed'