# stable-diffusion-aws-extension/middleware_api/lambda/trainings/create_training_job.py
import datetime
import json
import logging
import os
import tarfile
from dataclasses import dataclass
from typing import Any, List, Optional
import boto3
from common.ddb_service.client import DynamoDbUtilsService
from common.response import bad_request, not_found, forbidden, internal_server_error, created
from common.util import get_s3_presign_urls
from common.util import load_json_from_s3, save_json_to_file
from libs.data_types import TrainJob, TrainJobStatus, Model, CreateModelStatus, CheckPoint, CheckPointStatus
from libs.utils import get_permissions_by_username, get_user_roles
# S3 bucket holding config templates, training inputs and training outputs.
bucket_name = os.environ.get('S3_BUCKET')
# DynamoDB table names, injected by the deployment environment.
train_table = os.environ.get('TRAIN_TABLE')
model_table = os.environ.get('MODEL_TABLE')
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')
user_table = os.environ.get('MULTI_USER_TABLE')
logger = logging.getLogger(__name__)
# Falls back to ERROR when LOG_LEVEL is unset or empty.
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
# Shared DynamoDB helper used for all table reads/writes in this module.
ddb_service = DynamoDbUtilsService(logger=logger)
@dataclass
class Event:
    """Request payload for creating a training job, parsed from the API body."""
    # Training type; used to build the S3 key prefix '<type>/train/...'.
    train_type: str
    # Id of the Model record in MODEL_TABLE this job trains against.
    model_id: str
    # Free-form job parameters; may contain 'config_params' and 'training_params'.
    params: dict[str, Any]
    # Username of the caller; checked for the 'train'/'create' permission.
    creator: str
    # When provided, presigned S3 upload URLs are returned for these files
    # instead of building the default db_config tar server-side.
    filenames: Optional[List[str]] = None
def handler(raw_event, context):
    """Create a training job.

    Validates the caller's 'train' permission, looks up the target model,
    prepares the training input in S3 (either by building a db_config tar
    from the S3 template, or by returning presigned upload URLs for the
    caller's files), then records a CheckPoint and a TrainJob in DynamoDB.

    Args:
        raw_event: API Gateway proxy event; 'body' is a JSON string matching Event.
        context: Lambda context; aws_request_id is used as the job/checkpoint id.

    Returns:
        201 response with the job summary and optional presign URLs,
        or 4xx/500 on validation/processing failure.
    """
    request_id = context.aws_request_id
    # Parse the body once; reuse for both logging and Event construction.
    body = json.loads(raw_event['body'])
    logger.info(json.dumps(body))
    event = Event(**body)
    _type = event.train_type
    try:
        creator_permissions = get_permissions_by_username(ddb_service, user_table, event.creator)
        train_permissions = creator_permissions.get('train', [])
        if 'all' not in train_permissions and 'create' not in train_permissions:
            return forbidden(message=f'user {event.creator} has not permission to create a train job')

        model_raw = ddb_service.get_item(table=model_table, key_values={
            'id': event.model_id
        })
        # if model is not found, model_raw is {}
        if model_raw == {}:
            return not_found(message=f'model with id {event.model_id} is not found')

        model = Model(**model_raw)
        if model.job_status != CreateModelStatus.Complete:
            return bad_request(
                message=f'model {model.id} is in {model.job_status.value} state, not valid to be used for train')

        base_key = f'{_type}/train/{model.name}/{request_id}'
        input_location = f'{base_key}/input'
        presign_url_map = None
        if event.filenames is None:
            # Invoked from api, no config file is defined in the parameters
            _build_and_upload_db_config(event, model, input_location)
        else:
            # Caller supplies its own input files: hand back presigned upload URLs.
            presign_url_map = get_s3_presign_urls(bucket_name=bucket_name, base_key=input_location,
                                                  filenames=event.filenames)

        user_roles = get_user_roles(ddb_service, user_table, event.creator)
        checkpoint = CheckPoint(
            id=request_id,
            checkpoint_type=event.train_type,
            s3_location=f's3://{bucket_name}/{base_key}/output',
            checkpoint_status=CheckPointStatus.Initial,
            timestamp=datetime.datetime.now().timestamp(),
            allowed_roles_or_users=user_roles
        )
        ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)

        train_input_s3_location = f's3://{bucket_name}/{input_location}'
        train_job = TrainJob(
            id=request_id,
            model_id=event.model_id,
            job_status=TrainJobStatus.Initial,
            params=event.params,
            train_type=event.train_type,
            input_s3_location=train_input_s3_location,
            checkpoint_id=checkpoint.id,
            timestamp=datetime.datetime.now().timestamp(),
            allowed_roles_or_users=[event.creator]
        )
        ddb_service.put_items(table=train_table, entries=train_job.__dict__)

        data = {
            'job': {
                'id': train_job.id,
                'status': train_job.job_status.value,
                'trainType': train_job.train_type,
                'params': train_job.params,
                'input_location': train_input_s3_location,
            },
            's3PresignUrl': presign_url_map
        }
        return created(data=data)
    except Exception as e:
        # logger.exception preserves the traceback; plain logger.error(e) loses it.
        logger.exception(e)
        return internal_server_error(message=str(e))


def _build_and_upload_db_config(event: Event, model: Model, input_location: str) -> None:
    """Merge the S3 db_config template with the request's config_params, inject
    model metadata into training_params, then tar the result and upload it to
    s3://<bucket>/<input_location>/db_config.tar.

    Raises:
        RuntimeError: if building or uploading the tar fails.
    """
    json_file_name = 'db_config_cloud.json'
    tar_file_name = 'db_config.tar'
    tar_file_content = f'/tmp/models/sagemaker_dreambooth/{model.name}'
    tar_file_path = f'/tmp/{tar_file_name}'
    db_config_json = load_json_from_s3(bucket_name, 'template/' + json_file_name)
    # Merge user parameter, if no config_params is defined, use the default value in S3 bucket
    if "config_params" in event.params:
        db_config_json.update(event.params["config_params"])
    # Add model parameters into train params
    event.params["training_params"]["model_name"] = model.name
    event.params["training_params"]["model_type"] = model.model_type
    event.params["training_params"]["s3_model_path"] = model.output_s3_location
    # Upload the merged JSON string to the S3 bucket as a tar file
    try:
        # exist_ok avoids the check-then-create race and survives warm Lambda reuse.
        os.makedirs(tar_file_content, exist_ok=True)
        saved_path = save_json_to_file(db_config_json, tar_file_content, json_file_name)
        logger.info(f'file saved to {saved_path}')
        with tarfile.open(tar_file_path, 'w') as tar:
            # Add the contents of 'models' directory to the tar file without including the /tmp itself
            tar.add(tar_file_content, arcname=f'models/sagemaker_dreambooth/{model.name}')
        s3 = boto3.client('s3')
        s3.upload_file(tar_file_path, bucket_name, os.path.join(input_location, tar_file_name))
        logger.info(f"Tar file '{tar_file_name}' uploaded to '{bucket_name}' successfully.")
    except Exception as e:
        raise RuntimeError(f"Error uploading JSON file to S3: {e}") from e