add checkpoints api
parent
8ee2b019f9
commit
ae73dbf9ea
|
|
@ -350,6 +350,12 @@ export class CreateCheckPointApi {
|
|||
},
|
||||
},
|
||||
},
|
||||
target_path: {
|
||||
type: JsonSchemaType.STRING,
|
||||
},
|
||||
source_path: {
|
||||
type: JsonSchemaType.STRING,
|
||||
},
|
||||
},
|
||||
required: [
|
||||
'checkpoint_type',
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ import logging
|
|||
import os
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
from aws_lambda_powertools import Tracer
|
||||
|
||||
from common.const import PERMISSION_CHECKPOINT_ALL, PERMISSION_CHECKPOINT_CREATE
|
||||
from common.const import PERMISSION_CHECKPOINT_ALL, PERMISSION_CHECKPOINT_CREATE, COMFY_TYPE
|
||||
from common.ddb_service.client import DynamoDbUtilsService
|
||||
from common.response import bad_request, created, accepted
|
||||
from libs.common_tools import get_base_checkpoint_s3_key, \
|
||||
|
|
@ -39,6 +39,8 @@ class CreateCheckPointEvent:
|
|||
params: dict[str, Any]
|
||||
filenames: [MultipartFileReq] = None
|
||||
urls: [str] = None
|
||||
target_path: Optional[str] = None
|
||||
source_path: Optional[str] = None
|
||||
|
||||
|
||||
@tracer.capture_lambda_handler
|
||||
|
|
@ -58,7 +60,14 @@ def handler(raw_event, context):
|
|||
|
||||
_type = event.checkpoint_type
|
||||
|
||||
base_key = get_base_checkpoint_s3_key(_type, 'custom', request_id)
|
||||
if _type == COMFY_TYPE:
|
||||
if not event.source_path or not event.target_path:
|
||||
return bad_request(message='Please check your source_path or target_path of the checkpoints')
|
||||
#such as source_path :"comfy/{comfy_endpoint}/{prepare_version}/models/" target_path:"models/checkpoints"
|
||||
base_key = event.source_path
|
||||
else:
|
||||
base_key = get_base_checkpoint_s3_key(_type, 'custom', request_id)
|
||||
|
||||
presign_url_map = batch_get_s3_multipart_signed_urls(
|
||||
bucket_name=bucket_name,
|
||||
base_key=base_key,
|
||||
|
|
@ -101,7 +110,9 @@ def handler(raw_event, context):
|
|||
checkpoint_status=CheckPointStatus.Initial,
|
||||
params=checkpoint_params,
|
||||
timestamp=datetime.datetime.now().timestamp(),
|
||||
allowed_roles_or_users=user_roles
|
||||
allowed_roles_or_users=user_roles,
|
||||
source_path=event.source_path,
|
||||
target_path=event.target_path
|
||||
)
|
||||
ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)
|
||||
data = {
|
||||
|
|
@ -110,7 +121,9 @@ def handler(raw_event, context):
|
|||
'type': _type,
|
||||
's3_location': checkpoint.s3_location,
|
||||
'status': checkpoint.checkpoint_status.value,
|
||||
'params': checkpoint.params
|
||||
'params': checkpoint.params,
|
||||
'source_path': checkpoint.source_path,
|
||||
'target_path': checkpoint.target_path,
|
||||
},
|
||||
's3PresignUrl': multiparts_resp
|
||||
}
|
||||
|
|
@ -130,6 +143,8 @@ def invoke_url_lambda(event: CreateCheckPointEvent):
|
|||
'checkpoint_type': event.checkpoint_type,
|
||||
'params': event.params,
|
||||
'url': url,
|
||||
'source_path': event.source_path,
|
||||
'target_path': event.target_path,
|
||||
})
|
||||
)
|
||||
logger.info(resp)
|
||||
|
|
|
|||
|
|
@ -87,7 +87,9 @@ def handler(event, context):
|
|||
'name': ckpt.checkpoint_names,
|
||||
'created': ckpt.timestamp,
|
||||
'params': ckpt.params,
|
||||
'allowed_roles_or_users': ckpt.allowed_roles_or_users
|
||||
'allowed_roles_or_users': ckpt.allowed_roles_or_users,
|
||||
'source_path': ckpt.source_path,
|
||||
'target_path': ckpt.target_path,
|
||||
})
|
||||
|
||||
ckpts = sort_checkpoints(ckpts)
|
||||
|
|
|
|||
|
|
@ -5,13 +5,14 @@ import logging
|
|||
import os
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from aws_lambda_powertools import Tracer
|
||||
|
||||
from common.ddb_service.client import DynamoDbUtilsService
|
||||
from common.response import bad_request, forbidden
|
||||
from const import COMFY_TYPE
|
||||
from libs.common_tools import get_base_checkpoint_s3_key, multipart_upload_from_url
|
||||
from libs.data_types import CheckPoint, CheckPointStatus
|
||||
from libs.utils import get_user_roles, get_permissions_by_username
|
||||
|
|
@ -69,6 +70,8 @@ class CreateCheckPointByUrlEvent:
|
|||
checkpoint_type: str
|
||||
params: dict[str, Any]
|
||||
url: str
|
||||
source_path: Optional[str] = None
|
||||
target_path: Optional[str] = None
|
||||
|
||||
|
||||
@tracer.capture_lambda_handler
|
||||
|
|
@ -78,7 +81,13 @@ def handler(raw_event, context):
|
|||
request_id = context.aws_request_id
|
||||
event = CreateCheckPointByUrlEvent(**raw_event)
|
||||
|
||||
base_key = get_base_checkpoint_s3_key(event.checkpoint_type, 'custom', request_id)
|
||||
if event.checkpoint_type == COMFY_TYPE:
|
||||
if not event.source_path or not event.target_path:
|
||||
return bad_request(message='Please check your source_path or target_path of the checkpoints')
|
||||
# such as source_path :"comfy/{comfy_endpoint}/{prepare_version}/models/" target_path:"models/checkpoints"
|
||||
base_key = event.source_path
|
||||
else:
|
||||
base_key = get_base_checkpoint_s3_key(event.checkpoint_type, 'custom', request_id)
|
||||
file_names = []
|
||||
logger.info(f"start to upload model:{event.url}")
|
||||
checkpoint_params = {}
|
||||
|
|
@ -111,6 +120,8 @@ def handler(raw_event, context):
|
|||
params=checkpoint_params,
|
||||
timestamp=datetime.datetime.now().timestamp(),
|
||||
allowed_roles_or_users=user_roles,
|
||||
source_path=event.source_path,
|
||||
target_path=event.target_path,
|
||||
)
|
||||
ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)
|
||||
logger.info("finished insert item to ddb")
|
||||
|
|
@ -120,7 +131,9 @@ def handler(raw_event, context):
|
|||
'type': event.checkpoint_type,
|
||||
's3_location': checkpoint.s3_location,
|
||||
'status': checkpoint.checkpoint_status.value,
|
||||
'params': checkpoint.params
|
||||
'params': checkpoint.params,
|
||||
'source_path': checkpoint.source_path,
|
||||
'target_path': checkpoint.target_path,
|
||||
}
|
||||
}
|
||||
logger.info(data)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ KOHYA_XL_TOML_FILE_NAME = 'kohya_config_cloud_xl.toml'
|
|||
KOHYA_MODEL_ID = 'kohya'
|
||||
TRAIN_TYPE = "Stable-diffusion"
|
||||
|
||||
COMFY_TYPE = 'Comfy'
|
||||
|
||||
PERMISSION_INFERENCE_ALL = "inference:all"
|
||||
# todo will be remove, compatible with old data
|
||||
PERMISSION_INFERENCE_LIST = "inference:list"
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ class CheckPoint:
|
|||
version: str = 'v1.0' # todo: this is for the future
|
||||
checkpoint_names: Optional[list[str]] = None # the actual checkpoint file names
|
||||
params: Optional[dict[str, Any]] = None
|
||||
source_path: Optional[str] = None
|
||||
target_path: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if type(self.checkpoint_status) == str:
|
||||
|
|
|
|||
Loading…
Reference in New Issue