stable-diffusion-aws-extension/middleware_api/lambda/trainings/list_training_jobs.py

100 lines
3.4 KiB
Python

import json
import logging
import os
import boto3
from common.const import PERMISSION_TRAIN_ALL
from common.ddb_service.client import DynamoDbUtilsService
from common.response import ok
from common.util import get_query_param
from libs.data_types import TrainJob
from libs.utils import get_permissions_by_username, get_user_roles, check_user_permissions, permissions_check, \
response_error, decode_last_key, encode_last_key
# Names of the DynamoDB tables, read from the process environment.
train_table = os.environ.get('TRAIN_TABLE')
user_table = os.environ.get('MULTI_USER_TABLE')
# boto3 DynamoDB resource and a handle on the training table; both are created
# once at import time and reused by every call to handler().
ddb = boto3.resource('dynamodb')
table = ddb.Table(train_table)
# Module logger. The level comes from LOG_LEVEL and falls back to ERROR when
# that variable is unset or empty.
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
# Project DynamoDB helper; handler() passes it to the permission and role lookups.
ddb_service = DynamoDbUtilsService(logger=logger)
# GET /trains
def handler(event, context):
    """Return one page of training jobs that the requestor is allowed to see.

    Query parameters read from ``event``:
      * ``exclusive_start_key`` - encoded pagination key returned by a previous
        call (optional).
      * ``limit`` - ``Limit`` passed to the table scan (default 10).

    A job is visible when ``check_user_permissions`` accepts the requestor for
    the job's ``allowed_roles_or_users``. Jobs without ``allowed_roles_or_users``
    (legacy data) are visible only to requestors whose permissions contain
    ``'all'`` under ``'user'``.

    Visibility is applied after the scan, so a page may contain fewer jobs than
    ``limit``; callers should keep paging while ``last_evaluated_key`` is set.

    :param event: Lambda invocation event.
    :param context: Lambda context object (unused).
    :return: ``ok(...)`` whose data holds ``trainings`` (newest first) and
        ``last_evaluated_key``; any exception is converted with
        ``response_error``.
    """
    try:
        # Serialising the event is wasted work unless INFO logging is enabled.
        if logger.isEnabledFor(logging.INFO):
            logger.info(json.dumps(event))

        requestor_name = permissions_check(event, [PERMISSION_TRAIN_ALL])

        exclusive_start_key = get_query_param(event, 'exclusive_start_key')
        limit = int(get_query_param(event, 'limit', 10))

        scan_kwargs = {
            'Limit': limit,
        }
        if exclusive_start_key:
            scan_kwargs['ExclusiveStartKey'] = decode_last_key(exclusive_start_key)

        logger.info(scan_kwargs)
        response = table.scan(**scan_kwargs)
        # The scan response can be large; only dump it when it will be emitted.
        if logger.isEnabledFor(logging.INFO):
            logger.info(json.dumps(response, default=str))

        items = response.get('Items', [])
        last_evaluated_key = encode_last_key(response.get('LastEvaluatedKey'))

        if not items:
            # Keep the response shape identical to the non-empty case so that
            # clients can always read (and follow) last_evaluated_key.
            return ok(data={
                'trainings': [],
                'last_evaluated_key': last_evaluated_key
            })

        requestor_permissions = get_permissions_by_username(ddb_service, user_table, requestor_name)
        requestor_roles = get_user_roles(ddb_service=ddb_service, user_table_name=user_table, username=requestor_name)

        # Loop-invariant: superusers may view legacy jobs that carry no
        # allowed_roles_or_users.
        can_view_legacy = 'user' in requestor_permissions and \
                          'all' in requestor_permissions['user']

        train_jobs = []
        for row in items:
            train_job = TrainJob(**row)

            allowed = train_job.allowed_roles_or_users
            if allowed:
                visible = check_user_permissions(allowed, requestor_roles, requestor_name)
            else:
                visible = can_view_legacy

            # Build the DTO only for jobs that are actually returned.
            if visible:
                train_jobs.append(_train_job_to_dto(train_job))

        data = {
            'trainings': sort_jobs(train_jobs),
            'last_evaluated_key': last_evaluated_key
        }

        return ok(data=data, decimal=True)
    except Exception as e:
        return response_error(e)


def _train_job_to_dto(train_job):
    """Convert a ``TrainJob`` into the dict returned to API clients."""
    params = train_job.params

    # Fall back to the placeholder when params / training_params are missing or
    # not dicts, so a single malformed row cannot fail the whole listing.
    training_params = params.get('training_params') if isinstance(params, dict) else None
    model_name = 'not_applied'
    if isinstance(training_params, dict) and 'model_name' in training_params:
        model_name = training_params['model_name']

    return {
        'id': train_job.id,
        'modelName': model_name,
        'status': train_job.job_status.value,
        'trainType': train_job.train_type,
        'created': train_job.timestamp,
        'sagemakerTrainName': train_job.sagemaker_train_name,
        'params': params,
    }
def sort_jobs(train_jobs):
    """Order training-job dicts from newest to oldest.

    :param train_jobs: sequence of dicts, each carrying a ``'created'`` key.
    :return: a new list sorted by ``'created'`` in descending order; an empty
        input is handed back unchanged.
    """

    def _created(job):
        return job['created']

    if len(train_jobs) > 0:
        ordered = list(train_jobs)
        ordered.sort(key=_created, reverse=True)
        return ordered

    return train_jobs