add checkpoints api

pull/800/head
Xiujuan Li 2024-05-21 12:37:01 +08:00
parent 8ee2b019f9
commit ae73dbf9ea
6 changed files with 49 additions and 9 deletions

View File

@ -350,6 +350,12 @@ export class CreateCheckPointApi {
},
},
},
target_path: {
type: JsonSchemaType.STRING,
},
source_path: {
type: JsonSchemaType.STRING,
},
},
required: [
'checkpoint_type',

View File

@ -4,13 +4,13 @@ import logging
import os
import urllib.parse
from dataclasses import dataclass
from typing import Any
from typing import Any, Optional
import boto3
import requests
from aws_lambda_powertools import Tracer
from common.const import PERMISSION_CHECKPOINT_ALL, PERMISSION_CHECKPOINT_CREATE
from common.const import PERMISSION_CHECKPOINT_ALL, PERMISSION_CHECKPOINT_CREATE, COMFY_TYPE
from common.ddb_service.client import DynamoDbUtilsService
from common.response import bad_request, created, accepted
from libs.common_tools import get_base_checkpoint_s3_key, \
@ -39,6 +39,8 @@ class CreateCheckPointEvent:
params: dict[str, Any]
filenames: [MultipartFileReq] = None
urls: [str] = None
target_path: Optional[str] = None
source_path: Optional[str] = None
@tracer.capture_lambda_handler
@ -58,7 +60,14 @@ def handler(raw_event, context):
_type = event.checkpoint_type
base_key = get_base_checkpoint_s3_key(_type, 'custom', request_id)
if _type == COMFY_TYPE:
if not event.source_path or not event.target_path:
return bad_request(message='Please check your source_path or target_path of the checkpoints')
#such as source_path :"comfy/{comfy_endpoint}/{prepare_version}/models/" target_path:"models/checkpoints"
base_key = event.source_path
else:
base_key = get_base_checkpoint_s3_key(_type, 'custom', request_id)
presign_url_map = batch_get_s3_multipart_signed_urls(
bucket_name=bucket_name,
base_key=base_key,
@ -101,7 +110,9 @@ def handler(raw_event, context):
checkpoint_status=CheckPointStatus.Initial,
params=checkpoint_params,
timestamp=datetime.datetime.now().timestamp(),
allowed_roles_or_users=user_roles
allowed_roles_or_users=user_roles,
source_path=event.source_path,
target_path=event.target_path
)
ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)
data = {
@ -110,7 +121,9 @@ def handler(raw_event, context):
'type': _type,
's3_location': checkpoint.s3_location,
'status': checkpoint.checkpoint_status.value,
'params': checkpoint.params
'params': checkpoint.params,
'source_path': checkpoint.source_path,
'target_path': checkpoint.target_path,
},
's3PresignUrl': multiparts_resp
}
@ -130,6 +143,8 @@ def invoke_url_lambda(event: CreateCheckPointEvent):
'checkpoint_type': event.checkpoint_type,
'params': event.params,
'url': url,
'source_path': event.source_path,
'target_path': event.target_path,
})
)
logger.info(resp)

View File

@ -87,7 +87,9 @@ def handler(event, context):
'name': ckpt.checkpoint_names,
'created': ckpt.timestamp,
'params': ckpt.params,
'allowed_roles_or_users': ckpt.allowed_roles_or_users
'allowed_roles_or_users': ckpt.allowed_roles_or_users,
'source_path': ckpt.source_path,
'target_path': ckpt.target_path,
})
ckpts = sort_checkpoints(ckpts)

View File

@ -5,13 +5,14 @@ import logging
import os
import urllib.parse
from dataclasses import dataclass
from typing import Any
from typing import Any, Optional
import requests
from aws_lambda_powertools import Tracer
from common.ddb_service.client import DynamoDbUtilsService
from common.response import bad_request, forbidden
from const import COMFY_TYPE
from libs.common_tools import get_base_checkpoint_s3_key, multipart_upload_from_url
from libs.data_types import CheckPoint, CheckPointStatus
from libs.utils import get_user_roles, get_permissions_by_username
@ -69,6 +70,8 @@ class CreateCheckPointByUrlEvent:
checkpoint_type: str
params: dict[str, Any]
url: str
source_path: Optional[str] = None
target_path: Optional[str] = None
@tracer.capture_lambda_handler
@ -78,7 +81,13 @@ def handler(raw_event, context):
request_id = context.aws_request_id
event = CreateCheckPointByUrlEvent(**raw_event)
base_key = get_base_checkpoint_s3_key(event.checkpoint_type, 'custom', request_id)
if event.checkpoint_type == COMFY_TYPE:
if not event.source_path or not event.target_path:
return bad_request(message='Please check your source_path or target_path of the checkpoints')
# such as source_path :"comfy/{comfy_endpoint}/{prepare_version}/models/" target_path:"models/checkpoints"
base_key = event.source_path
else:
base_key = get_base_checkpoint_s3_key(event.checkpoint_type, 'custom', request_id)
file_names = []
logger.info(f"start to upload model:{event.url}")
checkpoint_params = {}
@ -111,6 +120,8 @@ def handler(raw_event, context):
params=checkpoint_params,
timestamp=datetime.datetime.now().timestamp(),
allowed_roles_or_users=user_roles,
source_path=event.source_path,
target_path=event.target_path,
)
ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)
logger.info("finished insert item to ddb")
@ -120,7 +131,9 @@ def handler(raw_event, context):
'type': event.checkpoint_type,
's3_location': checkpoint.s3_location,
'status': checkpoint.checkpoint_status.value,
'params': checkpoint.params
'params': checkpoint.params,
'source_path': checkpoint.source_path,
'target_path': checkpoint.target_path,
}
}
logger.info(data)

View File

@ -26,6 +26,8 @@ KOHYA_XL_TOML_FILE_NAME = 'kohya_config_cloud_xl.toml'
KOHYA_MODEL_ID = 'kohya'
TRAIN_TYPE = "Stable-diffusion"
COMFY_TYPE = 'Comfy'
PERMISSION_INFERENCE_ALL = "inference:all"
# todo will be remove, compatible with old data
PERMISSION_INFERENCE_LIST = "inference:list"

View File

@ -46,6 +46,8 @@ class CheckPoint:
version: str = 'v1.0' # todo: this is for the future
checkpoint_names: Optional[list[str]] = None # the actual checkpoint file names
params: Optional[dict[str, Any]] = None
source_path: Optional[str] = None
target_path: Optional[str] = None
def __post_init__(self):
if type(self.checkpoint_status) == str: