stable-diffusion-aws-extension/middleware_api/lambda/model_and_train/checkpoint_api.py

354 lines
13 KiB
Python

import datetime
import logging
import os
from dataclasses import dataclass
from typing import Any, Dict
import urllib.parse
import concurrent.futures
import requests
import json
from _types import CheckPoint, CheckPointStatus, MultipartFileReq
from common.ddb_service.client import DynamoDbUtilsService
from common.response import ok, bad_request, internal_server_error
from common_tools import get_base_checkpoint_s3_key, \
batch_get_s3_multipart_signed_urls, complete_multipart_upload, multipart_upload_from_url
from multi_users._types import PARTITION_KEYS, Role
from multi_users.utils import get_user_roles, check_user_permissions, get_permissions_by_username
# Table name handed to ddb_service for checkpoint records (from the environment).
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')
# S3 bucket that checkpoint files are uploaded into (from the environment).
bucket_name = os.environ.get('S3_BUCKET')
# Checkpoint types accepted by upload_checkpoint_api; the membership test is case-sensitive.
checkpoint_type = ["Stable-diffusion", "embeddings", "Lora", "hypernetworks", "ControlNet", "VAE"]
# Table name handed to the multi_users helpers for role / permission lookups.
user_table = os.environ.get('MULTI_USER_TABLE')
# File extensions that download_and_upload_models accepts for URL-based uploads.
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".yaml"]
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Shared DynamoDB helper used by the handlers below.
ddb_service = DynamoDbUtilsService(logger=logger)
# Thread-pool size used by concurrent_upload.
MAX_WORKERS = 10
# GET /checkpoints?username=USER_NAME&types=value&status=value
def list_all_checkpoints_api(event, context):
    """List the checkpoints visible to a user.

    Query-string parameters (all optional):
      * ``types``    -> filters the scan on ``checkpoint_type``
      * ``status``   -> filters the scan on ``checkpoint_status``
      * ``username`` -> whose roles are used for the visibility check

    Roles created by the requestor (``requestContext.authorizer.username``)
    are added to the role list. A requestor whose permissions contain
    ``'all'`` under ``'user'`` sees every checkpoint.

    Returns the result of ``ok(...)`` carrying ``{'checkpoints': [...]}``.
    """
    logger.info(json.dumps(event))

    query = event['queryStringParameters']
    scan_filter = {}
    username = None
    user_roles = ['*']

    if query:
        # Map each supported query parameter onto the column it filters.
        for param, column in (('types', 'checkpoint_type'), ('status', 'checkpoint_status')):
            if param in query and len(query[param]) > 0:
                scan_filter[column] = query[param]

        # todo: support multi user fetch later
        if 'username' in query and query['username']:
            username = query['username']
            user_roles = get_user_roles(ddb_service=ddb_service, user_table_name=user_table, username=username)

    requestor_name = event['requestContext']['authorizer']['username']
    requestor_permissions = get_permissions_by_username(ddb_service, user_table, requestor_name)

    # Roles the requestor created also count towards visibility.
    created_role_rows = ddb_service.scan(
        table=user_table,
        filters={
            'kind': PARTITION_KEYS.role,
            'creator': requestor_name
        },
    )
    user_roles.extend(
        Role(**ddb_service.deserialize(row)).sort_key for row in created_role_rows
    )

    raw_ckpts = ddb_service.scan(table=checkpoint_table, filters=scan_filter)
    if raw_ckpts is None or len(raw_ckpts) == 0:
        return ok(data={'checkpoints': []})

    def _requestor_can_see(ckpt):
        # Evaluated lazily per checkpoint, exactly like the original `or` chain.
        if check_user_permissions(ckpt.allowed_roles_or_users, user_roles, username):
            return True
        return 'user' in requestor_permissions and 'all' in requestor_permissions['user']

    visible = []
    for row in raw_ckpts:
        ckpt = CheckPoint(**(ddb_service.deserialize(row)))
        if not _requestor_can_see(ckpt):
            continue
        visible.append({
            'id': ckpt.id,
            's3Location': ckpt.s3_location,
            'type': ckpt.checkpoint_type,
            'status': ckpt.checkpoint_status.value,
            'name': ckpt.checkpoint_names,
            'created': ckpt.timestamp,
            'allowed_roles_or_users': ckpt.allowed_roles_or_users
        })

    return ok(data={'checkpoints': visible}, decimal=True)
@dataclass
class UploadCheckPointEvent:
    """Input payload of upload_checkpoint_api, built from the raw event's keys."""
    # Checkpoint category; upload_checkpoint_api rejects values not in `checkpoint_type`.
    checkpointType: str
    # Source URLs the model files are fetched from.
    modelUrl: list[str]
    # Extra parameters stored with the checkpoint; a 'creator' entry, when present,
    # is used for the permission check.
    params: dict[str, Any]
def download_and_upload_models(url: str, base_key: str, file_names: list, multipart_upload: dict,
                               cannot_download: list, *, timeout: float = 30.0):
    """Resolve one model URL and upload the file it points at to S3.

    The URL is probed without following redirects; on a 307 response the
    ``Location`` target becomes the effective URL, so the real file name can
    be taken from its path.

    Args:
        url: Source URL of the model file.
        base_key: S3 key prefix; the object key is ``{base_key}/{filename}``.
        file_names: Output list; the accepted file name is appended.
        multipart_upload: Output dict; ``filename`` maps to whatever
            ``multipart_upload_from_url`` returns.
        cannot_download: Output list; the URL is appended when its file
            extension is not in ``CN_MODEL_EXTS``.
        timeout: Seconds to wait for the probe request before giving up.

    Returns:
        None. Results are reported through the three output containers.
    """
    logger.info(f"download_and_upload_models: {url}, {base_key}, {file_names}")

    # Only the status line and headers are needed here. With stream=True the
    # connection stays checked out until the response is closed, so release it
    # explicitly instead of leaking one connection per URL.
    response = requests.get(url, allow_redirects=False, stream=True, timeout=timeout)
    try:
        if response and response.status_code == 307:
            location = response.headers.get('Location') if response.headers else None
            if location:
                # urljoin leaves an absolute Location untouched and resolves a
                # relative one against the URL that was requested.
                url = urllib.parse.urljoin(url, location)
    finally:
        response.close()

    parsed_url = urllib.parse.urlparse(url)
    filename = os.path.basename(parsed_url.path)
    if os.path.splitext(filename)[1] not in CN_MODEL_EXTS:
        logger.info(f"download_and_upload_models file error url:{url}, filename:{filename}")
        cannot_download.append(url)
        return

    logger.info(f"file name is :{filename}")
    file_names.append(filename)
    s3_key = f'{base_key}/{filename}'
    # Log the key itself (the original logged the bare file name here).
    logger.info(f"upload s3 key is :{s3_key}")
    multipart_upload[filename] = multipart_upload_from_url(url, bucket_name, s3_key)
# Upload files concurrently
def concurrent_upload(file_urls, base_key, file_names, multipart_upload):
    """Run download_and_upload_models for every URL on a thread pool.

    ``file_names`` and ``multipart_upload`` are filled in by the workers.
    Any exception raised by a worker is re-raised here.

    Returns:
        The list of URLs the workers rejected, or None when there were none.
    """
    rejected_urls = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        pending = [
            pool.submit(download_and_upload_models, file_url, base_key, file_names, multipart_upload,
                        rejected_urls)
            for file_url in file_urls
        ]
        # result() surfaces the first worker exception, if any.
        for finished in concurrent.futures.as_completed(pending):
            finished.result()
    return rejected_urls if rejected_urls else None
# POST /upload_checkpoint
def upload_checkpoint_api(raw_event, context):
    """Create a checkpoint by fetching its files from URLs and uploading them to S3.

    Args:
        raw_event: Mapping with ``checkpointType``, ``modelUrl`` and ``params``
            (unpacked into UploadCheckPointEvent).
        context: Invocation context; ``aws_request_id`` becomes the checkpoint id.

    Returns:
        A dict with ``statusCode`` and ``headers`` plus either ``checkpoint``
        (200) or ``error`` (400 when the creator may not create checkpoints,
        500 for an unknown type, rejected URLs, or any unexpected exception).
    """
    request_id = context.aws_request_id
    event = UploadCheckPointEvent(**raw_event)
    _type = event.checkpointType
    headers = {
        'Access-Control-Allow-Headers': 'Content-Type',
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'OPTIONS,POST,GET'
    }

    if _type not in checkpoint_type:
        logger.info(f"type error:{_type}")
        # Build the hint from the list that is actually checked, so the message
        # cannot drift from it (it used to advertise 'Controlnet' while only
        # 'ControlNet' is accepted). Status code kept as-is for existing callers.
        allowed_types = ', '.join(f"'{name}'" for name in checkpoint_type)
        return {
            'statusCode': 500,
            'headers': headers,
            'error': f"please choose the right type from :{allowed_types}"
        }

    try:
        base_key = get_base_checkpoint_s3_key(_type, 'custom', request_id)
        urls = event.modelUrl
        file_names = []
        logger.info(f"start to upload models:{urls}")

        # Reuse the caller's params dict when it has content, else start fresh.
        checkpoint_params = {}
        if event.params is not None and len(event.params) > 0:
            checkpoint_params = event.params
        checkpoint_params['created'] = str(datetime.datetime.now())
        checkpoint_params['multipart_upload'] = {}

        user_roles = ['*']
        creator_permissions = {}
        # Look the creator up through checkpoint_params, which is always a dict;
        # going through event.params raised TypeError when params was None.
        if 'creator' in checkpoint_params and checkpoint_params['creator']:
            creator = checkpoint_params['creator']
            user_roles = get_user_roles(ddb_service, user_table, creator)
            creator_permissions = get_permissions_by_username(ddb_service, user_table, creator)

        if 'checkpoint' not in creator_permissions or \
                ('all' not in creator_permissions['checkpoint'] and 'create' not in creator_permissions['checkpoint']):
            return {
                'statusCode': 400,
                'headers': headers,
                'error': "user has no permissions to create a model"
            }

        cannot_download = concurrent_upload(urls, base_key, file_names, checkpoint_params['multipart_upload'])
        if cannot_download:
            return {
                'statusCode': 500,
                'headers': headers,
                'error': f"contains invalid urls:{cannot_download}"
            }

        logger.info("finished upload, prepare to insert item to ddb")
        checkpoint = CheckPoint(
            id=request_id,
            checkpoint_type=_type,
            s3_location=f's3://{bucket_name}/{base_key}',
            checkpoint_names=file_names,
            checkpoint_status=CheckPointStatus.Active,
            params=checkpoint_params,
            timestamp=datetime.datetime.now().timestamp(),
            allowed_roles_or_users=user_roles,
        )
        ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)
        logger.info("finished insert item to ddb")
        return {
            'statusCode': 200,
            'headers': headers,
            'checkpoint': {
                'id': request_id,
                'type': _type,
                's3_location': checkpoint.s3_location,
                'status': checkpoint.checkpoint_status.value,
                'params': checkpoint.params
            }
        }
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error(e) dropped.
        logger.exception(e)
        return {
            'statusCode': 500,
            'headers': headers,
            'error': str(e)
        }
@dataclass
class CreateCheckPointEvent:
    """Request body of create_checkpoint_api (decoded from JSON)."""
    # Checkpoint category, used to build the S3 key prefix.
    checkpoint_type: str
    # One mapping per file to upload; create_checkpoint_api unpacks each entry
    # into MultipartFileReq(**entry), so the elements are plain dicts, not
    # MultipartFileReq instances.
    filenames: list[dict[str, Any]]
    # Extra parameters stored with the checkpoint; a 'creator' entry, when present,
    # is used for the permission check.
    params: dict[str, Any]
# POST /checkpoint
def create_checkpoint_api(raw_event, context):
    """Register a new checkpoint and hand back presigned multipart-upload URLs.

    Args:
        raw_event: Event whose ``body`` is a JSON document with
            ``checkpoint_type``, ``filenames`` and ``params``.
        context: Invocation context; ``aws_request_id`` becomes the checkpoint id.

    Returns:
        ``ok(...)`` with the stored checkpoint and ``s3PresignUrl`` on success,
        ``bad_request(...)`` when no file names were given or the creator lacks
        the create permission, ``internal_server_error(...)`` on any exception
        raised inside the handler body.
    """
    request_id = context.aws_request_id
    event = CreateCheckPointEvent(**json.loads(raw_event['body']))
    _type = event.checkpoint_type
    headers = {
        'Access-Control-Allow-Headers': 'Content-Type',
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'OPTIONS,POST,GET'
    }

    try:
        # Reuse the caller's params dict when it has content, else start fresh.
        checkpoint_params = {}
        if event.params is not None and len(event.params) > 0:
            checkpoint_params = event.params
        checkpoint_params['created'] = str(datetime.datetime.now())
        checkpoint_params['multipart_upload'] = {}

        # Validate the request BEFORE asking S3 for upload ids, so a rejected
        # request does not leave multipart uploads behind.
        filenames_only = [MultipartFileReq(**f).filename for f in event.filenames]
        if not filenames_only:
            return bad_request(message='no checkpoint name (file names) detected', headers=headers)

        user_roles = ['*']
        creator_permissions = {}
        # Look the creator up through checkpoint_params, which is always a dict;
        # going through event.params raised TypeError when params was None.
        if 'creator' in checkpoint_params and checkpoint_params['creator']:
            creator = checkpoint_params['creator']
            user_roles = get_user_roles(ddb_service, user_table, creator)
            creator_permissions = get_permissions_by_username(ddb_service, user_table, creator)

        if 'checkpoint' not in creator_permissions or \
                ('all' not in creator_permissions['checkpoint'] and 'create' not in creator_permissions['checkpoint']):
            return bad_request(message='user has no permissions to create a model', headers=headers)

        base_key = get_base_checkpoint_s3_key(_type, 'custom', request_id)
        presign_url_map = batch_get_s3_multipart_signed_urls(
            bucket_name=bucket_name,
            base_key=base_key,
            filenames=event.filenames
        )

        # Upload bookkeeping is stored with the checkpoint; only the signed URLs
        # go back to the client.
        multiparts_resp = {}
        for key, val in presign_url_map.items():
            checkpoint_params['multipart_upload'][key] = {
                'upload_id': val['upload_id'],
                'bucket': val['bucket'],
                'key': val['key'],
            }
            multiparts_resp[key] = val['s3_signed_urls']

        checkpoint = CheckPoint(
            id=request_id,
            checkpoint_type=_type,
            s3_location=f's3://{bucket_name}/{base_key}',
            checkpoint_names=filenames_only,
            checkpoint_status=CheckPointStatus.Initial,
            params=checkpoint_params,
            timestamp=datetime.datetime.now().timestamp(),
            allowed_roles_or_users=user_roles
        )
        ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)
        data = {
            'checkpoint': {
                'id': request_id,
                'type': _type,
                's3_location': checkpoint.s3_location,
                'status': checkpoint.checkpoint_status.value,
                'params': checkpoint.params
            },
            's3PresignUrl': multiparts_resp
        }
        return ok(data=data, headers=headers)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error(e) dropped.
        logger.exception(e)
        return internal_server_error(headers=headers, message=str(e))
@dataclass
class UpdateCheckPointEvent:
    """Request body of update_checkpoint_api (decoded from JSON)."""
    # Name of the CheckPointStatus member to set (looked up as CheckPointStatus[status]).
    status: str
    # Passed through unchanged to complete_multipart_upload.
    multi_parts_tags: Dict[str, Any]
# PUT /checkpoint (the checkpoint id is read from pathParameters['id'])
def update_checkpoint_api(raw_event, context):
    """Complete a checkpoint's multipart upload and update its status.

    Args:
        raw_event: Event with ``pathParameters.id`` and a JSON ``body``
            holding ``status`` and ``multi_parts_tags``.
        context: Invocation context (unused).

    Returns:
        ``ok(...)`` with the checkpoint, including its new status, on success;
        ``bad_request(...)`` when the checkpoint does not exist or ``status``
        is not a CheckPointStatus member name; ``internal_server_error(...)``
        on any other exception raised inside the handler body.
    """
    event = UpdateCheckPointEvent(**json.loads(raw_event['body']))
    checkpoint_id = raw_event['pathParameters']['id']
    headers = {
        'Access-Control-Allow-Headers': 'Content-Type',
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'OPTIONS,POST,GET'
    }

    try:
        raw_checkpoint = ddb_service.get_item(table=checkpoint_table, key_values={
            'id': checkpoint_id
        })
        if raw_checkpoint is None or len(raw_checkpoint) == 0:
            return bad_request(
                message=f'checkpoint not found with id {checkpoint_id}',
                headers=headers
            )

        checkpoint = CheckPoint(**raw_checkpoint)

        # An unknown status name is a client error, not a server fault.
        try:
            new_status = CheckPointStatus[event.status]
        except KeyError:
            return bad_request(
                message=f'invalid checkpoint status {event.status}',
                headers=headers
            )

        complete_multipart_upload(checkpoint, event.multi_parts_tags)
        # if complete part failed, then no update
        ddb_service.update_item(
            table=checkpoint_table,
            key={
                'id': checkpoint.id,
            },
            field_name='checkpoint_status',
            value=new_status
        )
        data = {
            'checkpoint': {
                'id': checkpoint.id,
                'type': checkpoint.checkpoint_type,
                's3_location': checkpoint.s3_location,
                # Report the status that was just written, not the value loaded
                # before the update.
                'status': new_status.value,
                'params': checkpoint.params
            }
        }
        return ok(data=data, headers=headers)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error(e) dropped.
        logger.exception(e)
        return internal_server_error(headers=headers, message=str(e))