701 lines
20 KiB
Python
701 lines
20 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from typing import List, Optional
|
|
|
|
import boto3
|
|
from aws_lambda_powertools import Tracer
|
|
from aws_lambda_powertools.utilities.typing import LambdaContext
|
|
|
|
from libs.common_tools import DecimalEncoder
|
|
from libs.utils import response_error
|
|
|
|
client = boto3.client('apigateway')
|
|
|
|
tracer = Tracer()
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
|
|
|
|
esd_version = os.environ.get("ESD_VERSION")
|
|
|
|
|
|
@dataclass
|
|
class Schema:
|
|
type: str
|
|
default: Optional[str] = None
|
|
description: Optional[str] = None
|
|
|
|
def to_dict(self):
|
|
data = {
|
|
"type": self.type,
|
|
}
|
|
|
|
if self.default:
|
|
data["default"] = self.default
|
|
|
|
if self.description:
|
|
data["description"] = self.description
|
|
|
|
return data
|
|
|
|
|
|
@dataclass
|
|
class ExternalDocs:
|
|
url: str
|
|
description: str
|
|
|
|
def to_dict(self):
|
|
return {"url": self.url, "description": self.description}
|
|
|
|
|
|
@dataclass
|
|
class Tag:
|
|
name: str
|
|
description: str
|
|
externalDocs: Optional[ExternalDocs] = None
|
|
|
|
def to_dict(self):
|
|
if self.externalDocs:
|
|
return {"name": self.name, "description": self.description, "externalDocs": self.externalDocs.to_dict()}
|
|
return {"name": self.name, "description": self.description}
|
|
|
|
|
|
@dataclass
|
|
class Parameter:
|
|
name: str
|
|
description: str
|
|
location: str
|
|
required: bool = False
|
|
schema: Optional[Schema] = None
|
|
|
|
def to_dict(self):
|
|
if self.schema:
|
|
return {
|
|
"name": self.name,
|
|
"description": self.description,
|
|
"in": self.location,
|
|
"required": self.required,
|
|
"schema": self.schema.to_dict(),
|
|
}
|
|
|
|
return {
|
|
"name": self.name,
|
|
"description": self.description,
|
|
"in": self.location,
|
|
"required": self.required,
|
|
}
|
|
|
|
|
|
@dataclass
|
|
class APISchema:
|
|
summary: str
|
|
tags: List[str]
|
|
parameters: Optional[List[Parameter]] = field(default_factory=list)
|
|
description: str = ""
|
|
|
|
|
|
header_user_name = Parameter(
|
|
name="username",
|
|
description="Username",
|
|
location="header",
|
|
required=True,
|
|
schema=Schema(type="string", default="api")
|
|
)
|
|
|
|
path_id = Parameter(name="id", description="ID", location="path", required=True)
|
|
path_name = Parameter(name="name", description="Name", location="path", required=True)
|
|
path_dataset_name = Parameter(name="id", description="Dataset Name", location="path", required=True)
|
|
|
|
query_limit = Parameter(name="limit", description="Limit Per Page", location="query")
|
|
query_page = Parameter(name="page", description="Page Index", location="query")
|
|
query_per_page = Parameter(name="per_page", description="Limit Per Page", location="query")
|
|
query_exclusive_start_key = Parameter(name="exclusive_start_key", description="Exclusive Start Key", location="query")
|
|
|
|
tags = [
|
|
Tag(name="Service", description="Service API").to_dict(),
|
|
Tag(name="Roles", description="Manage Roles").to_dict(),
|
|
Tag(name="Users", description="Manage Users").to_dict(),
|
|
Tag(name="Endpoints", description="Manage Endpoints").to_dict(),
|
|
Tag(
|
|
name="Checkpoints",
|
|
description="Manage Checkpoints",
|
|
externalDocs=ExternalDocs(
|
|
url="https://awslabs.github.io/stable-diffusion-aws-extension/en/developer-guide/api_upload_ckpt/",
|
|
description="Upload Checkpoint Process")
|
|
).to_dict(),
|
|
Tag(
|
|
name="Inferences",
|
|
description="Manage Inferences",
|
|
externalDocs=ExternalDocs(
|
|
url="https://awslabs.github.io/stable-diffusion-aws-extension/en/developer-guide/api_inference_process/",
|
|
description="Inference Process")
|
|
).to_dict(),
|
|
Tag(name="Executes", description="Manage Executes").to_dict(),
|
|
Tag(name="Datasets", description="Manage Datasets").to_dict(),
|
|
Tag(name="Trainings", description="Manage Trainings").to_dict(),
|
|
Tag(name="Prepare", description="Sync files to Endpoint").to_dict(),
|
|
Tag(name="Sync", description="Sync Message from Endpoint").to_dict(),
|
|
Tag(name="Workflows", description="Manage Workflows").to_dict(),
|
|
Tag(name="Schemas", description="Manage Schemas").to_dict(),
|
|
Tag(name="Others", description="Others API").to_dict(),
|
|
]
|
|
|
|
operations = {
|
|
"RootAPI": APISchema(
|
|
summary="Root API",
|
|
tags=["Service"],
|
|
description="The Root API of ESD"
|
|
),
|
|
"Ping": APISchema(
|
|
summary="Ping API",
|
|
tags=["Service"],
|
|
description="The Ping API for Health Check"
|
|
),
|
|
"ListRoles": APISchema(
|
|
summary="List Roles",
|
|
tags=["Roles"],
|
|
description="List all roles",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"GetInferenceJob": APISchema(
|
|
summary="Get Inference Job",
|
|
tags=["Inferences"],
|
|
description="Get Inference Job",
|
|
parameters=[
|
|
header_user_name,
|
|
path_id
|
|
]
|
|
),
|
|
"CreateRole": APISchema(
|
|
summary="Create Role",
|
|
tags=["Roles"],
|
|
description="Create a new role",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"DeleteRoles": APISchema(
|
|
summary="Delete Roles",
|
|
tags=["Roles"],
|
|
description="Delete specify Roles",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"GetTraining": APISchema(
|
|
summary="Get Training",
|
|
tags=["Trainings"],
|
|
description="Get Training List",
|
|
parameters=[
|
|
header_user_name,
|
|
path_id
|
|
]
|
|
),
|
|
"ListCheckpoints": APISchema(
|
|
summary="List Checkpoints",
|
|
tags=["Checkpoints"],
|
|
description="List Checkpoints with Parameters",
|
|
parameters=[
|
|
header_user_name,
|
|
query_page,
|
|
query_per_page,
|
|
Parameter(name="username", description="Filter by username", location="query"),
|
|
]
|
|
),
|
|
"CreateCheckpoint": APISchema(
|
|
summary="Create Checkpoint",
|
|
tags=["Checkpoints"],
|
|
description="Create a new Checkpoint",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"DeleteCheckpoints": APISchema(
|
|
summary="Delete Checkpoints",
|
|
tags=["Checkpoints"],
|
|
description="Delete specify Checkpoints",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"StartInferences": APISchema(
|
|
summary="Start Inference Job",
|
|
tags=["Inferences"],
|
|
description="Start specify Inference Job by ID",
|
|
parameters=[
|
|
header_user_name,
|
|
path_id
|
|
]
|
|
),
|
|
"ListExecutes": APISchema(
|
|
summary="List Executes",
|
|
tags=["Executes"],
|
|
description="List Executes with Parameters",
|
|
parameters=[
|
|
header_user_name,
|
|
query_limit,
|
|
query_exclusive_start_key
|
|
]
|
|
),
|
|
"CreateExecute": APISchema(
|
|
summary="Create Execute",
|
|
tags=["Executes"],
|
|
description="Create a new Execute for Comfy",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"DeleteExecutes": APISchema(
|
|
summary="Delete Executes",
|
|
tags=["Executes"],
|
|
description="Delete specify Executes",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"MergeExecute": APISchema(
|
|
summary="Merge Executes",
|
|
tags=["Executes"],
|
|
description="Merge specify Executes",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"GetApiOAS": APISchema(
|
|
summary="Get OAS",
|
|
description="Get OAS json schema",
|
|
tags=["Service"],
|
|
),
|
|
"ListUsers": APISchema(
|
|
summary="List Users",
|
|
tags=["Users"],
|
|
description="List all users",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"CreateUser": APISchema(
|
|
summary="Create User",
|
|
tags=["Users"],
|
|
description="Create a new user",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"DeleteUsers": APISchema(
|
|
summary="Delete Users",
|
|
tags=["Users"],
|
|
description="Delete specify Users",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"ListTrainings": APISchema(
|
|
summary="List Trainings",
|
|
tags=["Trainings"],
|
|
description="List Trainings with Parameters",
|
|
parameters=[
|
|
header_user_name,
|
|
query_limit,
|
|
query_exclusive_start_key
|
|
]
|
|
),
|
|
"CreateTraining": APISchema(
|
|
summary="Create Training",
|
|
tags=["Trainings"],
|
|
description="Create a new Training Job",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"DeleteTrainings": APISchema(
|
|
summary="Delete Trainings",
|
|
tags=["Trainings"],
|
|
description="Delete specify Trainings",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"GetExecute": APISchema(
|
|
summary="Get Execute",
|
|
tags=["Executes"],
|
|
description="Get Execute by ID",
|
|
parameters=[
|
|
header_user_name,
|
|
path_id
|
|
]
|
|
),
|
|
"GetExecuteLogs": APISchema(
|
|
summary="Get Execute Logs",
|
|
tags=["Executes"],
|
|
description="Get Execute Logs by ID",
|
|
parameters=[
|
|
header_user_name,
|
|
path_id
|
|
]
|
|
),
|
|
"ListDatasets": APISchema(
|
|
summary="List Datasets",
|
|
tags=["Datasets"],
|
|
description="List Datasets with Parameters",
|
|
parameters=[
|
|
header_user_name,
|
|
query_limit,
|
|
query_exclusive_start_key
|
|
]
|
|
),
|
|
"CropDataset": APISchema(
|
|
summary="Create new Crop Dataset",
|
|
tags=["Datasets"],
|
|
description="Create new Crop Dataset",
|
|
parameters=[
|
|
header_user_name,
|
|
path_dataset_name
|
|
]
|
|
),
|
|
"GetDataset": APISchema(
|
|
summary="Get Dataset",
|
|
tags=["Datasets"],
|
|
description="Get Dataset by ID",
|
|
parameters=[
|
|
header_user_name,
|
|
path_dataset_name
|
|
]
|
|
),
|
|
"UpdateCheckpoint": APISchema(
|
|
summary="Update Checkpoint",
|
|
tags=["Checkpoints"],
|
|
description="Update Checkpoint by ID",
|
|
parameters=[
|
|
header_user_name,
|
|
path_dataset_name
|
|
]
|
|
),
|
|
"CreateDataset": APISchema(
|
|
summary="Create Dataset",
|
|
tags=["Datasets"],
|
|
description="Create a new Dataset",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"DeleteDatasets": APISchema(
|
|
summary="Delete Datasets",
|
|
tags=["Datasets"],
|
|
description="Delete specify Datasets",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"UpdateDataset": APISchema(
|
|
summary="Update Dataset",
|
|
tags=["Datasets"],
|
|
description="Update Dataset by ID",
|
|
parameters=[
|
|
header_user_name,
|
|
path_dataset_name
|
|
]
|
|
),
|
|
"ListInferences": APISchema(
|
|
summary="List Inferences",
|
|
tags=["Inferences"],
|
|
description="List Inferences with Parameters",
|
|
parameters=[
|
|
header_user_name,
|
|
query_limit,
|
|
query_exclusive_start_key,
|
|
Parameter(name="type", description="Inference task type: txt2img, img2img", location="query"),
|
|
]
|
|
),
|
|
"CreateInferenceJob": APISchema(
|
|
summary="Create Inference Job",
|
|
tags=["Inferences"],
|
|
description="Create a new Inference Job",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"DeleteInferenceJobs": APISchema(
|
|
summary="Delete Inference Jobs",
|
|
tags=["Inferences"],
|
|
description="Delete specify Inference Jobs",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"ListEndpoints": APISchema(
|
|
summary="List Endpoints",
|
|
tags=["Endpoints"],
|
|
description="List Endpoints with Parameters",
|
|
parameters=[
|
|
header_user_name,
|
|
query_limit,
|
|
query_exclusive_start_key
|
|
]
|
|
),
|
|
"CreateEndpoint": APISchema(
|
|
summary="Create Endpoint",
|
|
tags=["Endpoints"],
|
|
description="Create a new Endpoint",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"DeleteEndpoints": APISchema(
|
|
summary="Delete Endpoints",
|
|
tags=["Endpoints"],
|
|
description="Delete specify Endpoints",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"SyncMessage": APISchema(
|
|
summary="Sync Message",
|
|
tags=["Sync"],
|
|
description="Sync Message to Endpoint",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"GetSyncMessage": APISchema(
|
|
summary="Get Sync Message",
|
|
description="Get Sync Message from Endpoint",
|
|
tags=["Sync"],
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"CreatePrepare": APISchema(
|
|
summary="Create Prepare",
|
|
tags=["Prepare"],
|
|
description="Create a new Prepare",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"GetPrepare": APISchema(
|
|
summary="Get Prepare",
|
|
tags=["Prepare"],
|
|
description="Get Prepare by ID",
|
|
parameters=[
|
|
header_user_name
|
|
]
|
|
),
|
|
"CreateWorkflow": APISchema(
|
|
summary="Release new Workflow",
|
|
tags=["Workflows"],
|
|
description="Create a new Workflow",
|
|
),
|
|
"ListWorkflows": APISchema(
|
|
summary="List Workflows",
|
|
tags=["Workflows"],
|
|
description="List Workflows with Parameters",
|
|
),
|
|
"DeleteWorkflows": APISchema(
|
|
summary="Delete Workflows",
|
|
tags=["Workflows"],
|
|
description="Delete specify Workflows",
|
|
),
|
|
'GetWorkflow': APISchema(
|
|
summary="Get Workflow",
|
|
tags=["Workflows"],
|
|
description="Get Workflow by Name",
|
|
parameters=[
|
|
path_name
|
|
]
|
|
),
|
|
"CreateSchema": APISchema(
|
|
summary="Release new Schema",
|
|
tags=["Schemas"],
|
|
description="Create a new Schema",
|
|
),
|
|
"ListSchemas": APISchema(
|
|
summary="List Schemas",
|
|
tags=["Schemas"],
|
|
description="List Schemas with Parameters",
|
|
),
|
|
"DeleteSchemas": APISchema(
|
|
summary="Delete Schemas",
|
|
tags=["Schemas"],
|
|
description="Delete specify Schemas",
|
|
),
|
|
'GetSchema': APISchema(
|
|
summary="Get Schema",
|
|
tags=["Schemas"],
|
|
description="Get Schema by Name",
|
|
parameters=[
|
|
path_name
|
|
]
|
|
),
|
|
'UpdateSchema': APISchema(
|
|
summary="Update Schema",
|
|
tags=["Schemas"],
|
|
description="Update Schema by Name",
|
|
parameters=[
|
|
path_name
|
|
]
|
|
),
|
|
}
|
|
|
|
|
|
@tracer.capture_lambda_handler
|
|
def handler(event: dict, context: LambdaContext):
|
|
logger.info(f'event: {event}')
|
|
logger.info(f'ctx: {context}')
|
|
|
|
api_id = event['requestContext']['apiId']
|
|
|
|
try:
|
|
response = client.get_export(
|
|
restApiId=api_id,
|
|
stageName='prod',
|
|
exportType='oas30',
|
|
accepts='application/json',
|
|
# parameters={
|
|
# 'extensions': 'apigateway'
|
|
# }
|
|
)
|
|
|
|
oas = response['body'].read()
|
|
json_schema = json.loads(oas)
|
|
json_schema = replace_null(json_schema)
|
|
json_schema['info']['version'] = esd_version.split('-')[0]
|
|
|
|
json_schema['servers'] = [
|
|
{
|
|
"url": "https://{ApiId}.execute-api.{Region}.{Domain}/prod/",
|
|
"variables": {
|
|
"ApiId": {
|
|
"default": "xxxxxx"
|
|
},
|
|
"Region": {
|
|
"default": "ap-northeast-1"
|
|
},
|
|
"Domain": {
|
|
"default": "amazonaws.com"
|
|
},
|
|
}
|
|
}
|
|
]
|
|
|
|
json_schema['info']['license'] = {
|
|
"name": "Apache 2.0",
|
|
"url": "https://www.apache.org/licenses/LICENSE-2.0.html",
|
|
}
|
|
|
|
json_schema['info']['description'] = (
|
|
"This is a ESD Server based on the OpenAPI 3.0 specification. \n"
|
|
"Some useful links: \n"
|
|
"\n- [The ESD Repository](https://github.com/awslabs/stable-diffusion-aws-extension)"
|
|
"\n- [Implementation Guide](https://awslabs.github.io/stable-diffusion-aws-extension/en/)")
|
|
|
|
json_schema['tags'] = tags
|
|
|
|
for path in json_schema['paths']:
|
|
for method in json_schema['paths'][path]:
|
|
meta = supplement_schema(json_schema['paths'][path][method])
|
|
json_schema['paths'][path][method]['description'] = meta.description
|
|
json_schema['paths'][path][method]['summary'] = meta.summary
|
|
json_schema['paths'][path][method]['tags'] = meta.tags
|
|
json_schema['paths'][path][method]['parameters'] = merge_parameters(meta,
|
|
json_schema['paths'][path][method])
|
|
|
|
json_schema['paths'] = dict(sorted(json_schema['paths'].items(), key=lambda x: x[0]))
|
|
|
|
payload = {
|
|
'isBase64Encoded': False,
|
|
'statusCode': 200,
|
|
'headers': {
|
|
'Content-Type': 'application/json',
|
|
'Access-Control-Allow-Origin': '*',
|
|
'Access-Control-Allow-Headers': '*',
|
|
'Access-Control-Allow-Methods': '*',
|
|
'Access-Control-Allow-Credentials': True,
|
|
},
|
|
'body': json.dumps(json_schema, cls=DecimalEncoder, indent=2)
|
|
}
|
|
|
|
return payload
|
|
except Exception as e:
|
|
return response_error(e)
|
|
|
|
|
|
def merge_parameters(schema: APISchema, item: dict):
|
|
if not schema.parameters:
|
|
return []
|
|
|
|
if 'parameters' not in item or len(item['parameters']) == 0:
|
|
item['parameters'] = []
|
|
for param in schema.parameters:
|
|
item['parameters'].append(param.to_dict())
|
|
return item['parameters']
|
|
|
|
for param in schema.parameters:
|
|
|
|
update = False
|
|
for original_para in item['parameters']:
|
|
if param.name == original_para['name'] and param.location == original_para['in']:
|
|
update = True
|
|
original_para.update(param.to_dict())
|
|
|
|
if update is False:
|
|
item['parameters'].append(param.to_dict())
|
|
|
|
return item['parameters']
|
|
|
|
|
|
def replace_null(data):
|
|
if isinstance(data, dict):
|
|
for key, value in data.items():
|
|
if value is None:
|
|
data[key] = {
|
|
"type": "null",
|
|
"description": "Last Key for Pagination"
|
|
}
|
|
else:
|
|
data[key] = replace_null(value)
|
|
elif isinstance(data, list):
|
|
for i, item in enumerate(data):
|
|
if item is None:
|
|
data[i] = {
|
|
"type": "null",
|
|
"description": "Last Key for Pagination"
|
|
}
|
|
else:
|
|
data[i] = replace_null(item)
|
|
|
|
return data
|
|
|
|
|
|
def supplement_schema(method: any):
|
|
if 'operationId' in method:
|
|
if method['operationId'] in operations:
|
|
item: APISchema = operations[method['operationId']]
|
|
if item.parameters:
|
|
parameters = item.parameters
|
|
else:
|
|
parameters = []
|
|
|
|
return APISchema(
|
|
summary=item.summary + f" ({method['operationId']})",
|
|
tags=item.tags,
|
|
description=item.description,
|
|
parameters=parameters
|
|
)
|
|
|
|
return APISchema(
|
|
summary=method['operationId'],
|
|
tags=["Others"],
|
|
parameters=[]
|
|
)
|
|
|
|
return APISchema(
|
|
summary="",
|
|
tags=["Others"],
|
|
parameters=[]
|
|
)
|