stable-diffusion-aws-extension/middleware_api/checkpoints/create_checkpoint.py

212 lines
6.9 KiB
Python

import datetime
import json
import logging
import os
import urllib.parse
from dataclasses import dataclass
from typing import Any, Optional
import boto3
import requests
from aws_lambda_powertools import Tracer
from common.const import PERMISSION_CHECKPOINT_ALL, PERMISSION_CHECKPOINT_CREATE, COMFY_TYPE
from common.ddb_service.client import DynamoDbUtilsService
from common.response import bad_request, created, accepted
from libs.common_tools import get_base_checkpoint_s3_key, \
batch_get_s3_multipart_signed_urls
from libs.data_types import CheckPoint, CheckPointStatus, MultipartFileReq
from libs.utils import get_user_roles, permissions_check, response_error
tracer = Tracer()
checkpoint_table = os.environ.get('CHECKPOINT_TABLE')
bucket_name = os.environ.get('S3_BUCKET_NAME')
checkpoint_type = ["Stable-diffusion", "embeddings", "Lora", "hypernetworks", "ControlNet", "VAE"]
user_table = os.environ.get('MULTI_USER_TABLE')
upload_by_url_lambda_name = os.environ.get('UPLOAD_BY_URL_LAMBDA_NAME')
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
ddb_service = DynamoDbUtilsService(logger=logger)
lambda_client = boto3.client('lambda')
@dataclass
class CreateCheckPointEvent:
    """Deserialized request body for the create-checkpoint API.

    Exactly one of ``filenames`` (multipart S3 upload flow) or ``urls``
    (async download-by-URL flow) is expected by the handler.
    """

    # Checkpoint category, e.g. "Stable-diffusion", "Lora", or COMFY_TYPE.
    checkpoint_type: str
    # Free-form caller parameters; persisted on the checkpoint record.
    params: dict[str, Any]
    # Files to upload via presigned multipart URLs.
    # Fixed: the original annotated these with list *literals* ([X]), which is
    # not a valid type expression; use Optional[list[...]] instead.
    filenames: Optional[list["MultipartFileReq"]] = None
    # Remote URLs to fetch asynchronously instead of uploading.
    urls: Optional[list[str]] = None
    # ComfyUI-only: destination path, e.g. "models/checkpoints".
    target_path: Optional[str] = None
    # ComfyUI-only: source prefix, e.g. "comfy/{endpoint}/{version}/models/".
    source_path: Optional[str] = None
@tracer.capture_lambda_handler
def handler(raw_event, context):
    """Create a checkpoint and return presigned multipart-upload URLs.

    Two request modes:
      * ``urls`` present  -> delegate to the upload-by-url Lambda
        (asynchronous) and return 202 Accepted immediately.
      * ``filenames`` present -> start S3 multipart uploads, persist an
        ``Initial`` CheckPoint item in DynamoDB, and return 201 Created
        with the signed part URLs.

    Any exception (bad body, permission failure, duplicate name, AWS error)
    is converted to an error response by ``response_error``.
    """
    try:
        logger.info(json.dumps(raw_event))
        # The Lambda request id doubles as the checkpoint id and the S3 key
        # suffix for non-Comfy uploads.
        request_id = context.aws_request_id
        event = CreateCheckPointEvent(**json.loads(raw_event['body']))
        # Raises if the caller lacks checkpoint create permissions; returns
        # the authenticated username (may be falsy for wildcard access).
        username = permissions_check(raw_event, [PERMISSION_CHECKPOINT_ALL, PERMISSION_CHECKPOINT_CREATE])
        # all urls or filenames must be passed check
        check_filenames_unique(event)
        if event.urls:
            # URL flow is fire-and-forget; the rest of this function is the
            # multipart-upload flow only.
            return invoke_url_lambda(event)
        _type = event.checkpoint_type
        if _type == COMFY_TYPE:
            if not event.source_path or not event.target_path:
                return bad_request(message='Please check your source_path or target_path of the checkpoints')
            # such as source_path: "comfy/{comfy_endpoint}/{prepare_version}/models/"
            # target_path: "models/checkpoints"
            base_key = event.source_path
        else:
            base_key = get_base_checkpoint_s3_key(_type, 'custom', request_id)
        # One multipart upload (upload_id + signed part URLs) per file name.
        # NOTE(review): if the body has neither urls nor filenames this call
        # receives filenames=None — presumably it raises and is surfaced via
        # response_error; confirm.
        presign_url_map = batch_get_s3_multipart_signed_urls(
            bucket_name=bucket_name,
            base_key=base_key,
            filenames=event.filenames
        )
        checkpoint_params = {}
        if event.params is not None and len(event.params) > 0:
            checkpoint_params = event.params
        checkpoint_params['created'] = str(datetime.datetime.now())
        checkpoint_params['multipart_upload'] = {}
        multiparts_resp = {}
        for key, val in presign_url_map.items():
            # Persist what is needed to later complete/abort the multipart
            # upload; the signed URLs themselves only go back to the caller.
            checkpoint_params['multipart_upload'][key] = {
                'upload_id': val['upload_id'],
                'bucket': val['bucket'],
                'key': val['key'],
            }
            multiparts_resp[key] = val['s3_signed_urls']
        filenames_only = []
        for f in event.filenames:
            file = MultipartFileReq(**f)
            filenames_only.append(file.filename)
        if len(filenames_only) == 0:
            return bad_request(message='no checkpoint name (file names) detected')
        # Anonymous callers get wildcard visibility; authenticated callers are
        # restricted to their own roles.
        user_roles = ['*']
        if username:
            checkpoint_params['creator'] = username
            user_roles = get_user_roles(ddb_service, user_table, username)
        checkpoint = CheckPoint(
            id=request_id,
            checkpoint_type=_type,
            s3_location=f's3://{bucket_name}/{base_key}',
            checkpoint_names=filenames_only,
            checkpoint_status=CheckPointStatus.Initial,
            params=checkpoint_params,
            timestamp=datetime.datetime.now().timestamp(),
            allowed_roles_or_users=user_roles,
            source_path=event.source_path,
            target_path=event.target_path
        )
        ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)
        data = {
            'checkpoint': {
                'id': request_id,
                'type': _type,
                's3_location': checkpoint.s3_location,
                'status': checkpoint.checkpoint_status.value,
                'params': checkpoint.params,
                'source_path': checkpoint.source_path,
                'target_path': checkpoint.target_path,
            },
            's3PresignUrl': multiparts_resp
        }
        return created(data=data)
    except Exception as e:
        return response_error(e)
def invoke_url_lambda(event: CreateCheckPointEvent):
    """Fan out one async (Event-type) invocation of the upload-by-url Lambda
    per distinct URL in the request, then acknowledge with 202 Accepted."""
    for download_url in set(event.urls):
        payload = json.dumps({
            'checkpoint_type': event.checkpoint_type,
            'params': event.params,
            'url': download_url,
            'source_path': event.source_path,
            'target_path': event.target_path,
        })
        response = lambda_client.invoke(
            FunctionName=upload_by_url_lambda_name,
            InvocationType='Event',
            Payload=payload,
        )
        logger.info(response)
    return accepted(message='Checkpoint creation in progress, please check later')
@tracer.capture_method
def check_filenames_unique(event: CreateCheckPointEvent):
    """Collect every checkpoint name the request would create (explicit
    filenames plus names derived from download URLs) and reject duplicates."""
    names = []
    if event.filenames:
        names.extend(entry['filename'] for entry in event.filenames)
    if event.urls:
        names.extend(
            get_download_file_name(get_real_url(candidate))
            for candidate in event.urls
        )
    logger.info(f"names: {names}")
    check_ckpt_name_unique(names)
def check_ckpt_name_unique(names: list[str]):
    """Raise if any of *names* matches a checkpoint name already stored.

    Fixed: the parameter annotation was the list literal ``[str]`` (not a
    valid type expression) — now ``list[str]``. Existing names are collected
    into a set so the final membership checks are O(1) instead of scanning a
    list per candidate.

    NOTE: this performs a full table scan of the checkpoint table on every
    create request.
    """
    if len(names) == 0:
        return
    ckpts = ddb_service.scan(table=checkpoint_table)
    exists_names = set()
    for ckpt in ckpts:
        if 'checkpoint_names' not in ckpt:
            continue
        # Raw DynamoDB attribute format: a List ('L') of Strings ('S').
        if 'L' not in ckpt['checkpoint_names']:
            continue
        for name in ckpt['checkpoint_names']['L']:
            exists_names.add(name['S'])
    logger.info(json.dumps(list(exists_names)))
    for name in names:
        if name.strip() in exists_names:
            raise Exception(f'{name} already exists, '
                            f'please use another or rename/delete exists')
@tracer.capture_method
def get_real_url(url: str):
    """Resolve *url* to its real download location.

    civitai API download links answer with a 307 redirect to the actual file
    host, so for those we issue a GET with redirects disabled and read the
    ``Location`` header (presumably HEAD is not honored by that API — TODO
    confirm). Everything else gets a HEAD with redirects followed. Returns
    the redirect target when a 307 with a Location header is seen, otherwise
    the (stripped) input URL.

    Fixed: the civitai GET had no timeout, so a stalled remote could hang the
    Lambda until the platform killed it; it now uses the same 10s timeout as
    the HEAD branch.
    """
    url = url.strip()
    if url.startswith('https://civitai.com/api/download/models/'):
        response = requests.get(url, allow_redirects=False, timeout=10)
    else:
        response = requests.head(url, allow_redirects=True, timeout=10)
    if response and response.status_code == 307:
        if response.headers and 'Location' in response.headers:
            return response.headers.get('Location')
    return url
def get_download_file_name(url: str):
    """Return the final path segment of *url* — the file name to store."""
    path = urllib.parse.urlparse(url).path
    return path.rsplit('/', 1)[-1]