stable-diffusion-aws-extension/middleware_api/service/oas.py

656 lines
18 KiB
Python

import json
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional
import boto3
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.utilities.typing import LambdaContext
from libs.common_tools import DecimalEncoder
from libs.utils import response_error
# API Gateway client; `handler` uses it to export the deployed REST API definition.
client = boto3.client('apigateway')
# Powertools tracer whose `capture_lambda_handler` decorator wraps `handler`.
tracer = Tracer()
logger = logging.getLogger(__name__)
# LOG_LEVEL overrides the level; an unset or empty value falls back to ERROR.
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
# Raw ESD_VERSION value (None when the variable is unset); `handler` publishes
# the part before the first '-' as the API version.
esd_version = os.environ.get("ESD_VERSION")
@dataclass
class Schema:
    """Minimal OpenAPI schema object attached to a parameter.

    Attributes:
        type: Value emitted under the "type" key.
        default: Optional default value; emitted whenever it is not None.
        description: Optional description; emitted only when non-empty.
    """

    type: str
    default: Optional[str] = None
    description: Optional[str] = None

    def to_dict(self):
        """Return the schema as a plain dict, leaving out unset optional keys."""
        data = {
            "type": self.type,
        }
        # Compare against None rather than truthiness: an empty string is a
        # legitimate default and must not be dropped from the output.
        if self.default is not None:
            data["default"] = self.default
        # An empty description carries no information, so it is still omitted.
        if self.description:
            data["description"] = self.description
        return data
@dataclass
class ExternalDocs:
    """Link to external documentation, with a human-readable label.

    Attributes:
        url: Address of the external document.
        description: Short text describing what the link points to.
    """

    url: str
    description: str

    def to_dict(self):
        """Return the link as a dict with "url" and "description" keys."""
        return dict(url=self.url, description=self.description)
@dataclass
class Tag:
    """Named tag with a description and an optional documentation link.

    Attributes:
        name: Tag name.
        description: Text describing the tag.
        externalDocs: Optional link rendered under the "externalDocs" key.
    """

    name: str
    description: str
    externalDocs: Optional[ExternalDocs] = None

    def to_dict(self):
        """Return the tag as a dict; "externalDocs" is added only when set."""
        rendered = {"name": self.name, "description": self.description}
        if self.externalDocs:
            rendered["externalDocs"] = self.externalDocs.to_dict()
        return rendered
@dataclass
class Parameter:
    """Description of a single request parameter.

    Attributes:
        name: Parameter name.
        description: Text describing the parameter.
        location: Where the parameter lives; emitted under the "in" key.
        required: Whether the parameter is mandatory (defaults to False).
        schema: Optional schema rendered under the "schema" key.
    """

    name: str
    description: str
    location: str
    required: bool = False
    schema: Optional[Schema] = None

    def to_dict(self):
        """Return the parameter as a dict; "schema" is added only when set."""
        rendered = {
            "name": self.name,
            "description": self.description,
            "in": self.location,
            "required": self.required,
        }
        if self.schema:
            rendered["schema"] = self.schema.to_dict()
        return rendered
@dataclass
class APISchema:
    """Documentation metadata for one API operation.

    Attributes:
        summary: Short summary of the operation.
        tags: Names of the tags the operation is listed under.
        parameters: Parameters documented for the operation; every instance
            gets its own empty list when none are supplied.
        description: Longer description; defaults to an empty string.
    """

    summary: str
    tags: List[str]
    # default_factory gives each instance a fresh list instead of a shared one.
    parameters: Optional[List[Parameter]] = field(default_factory=list)
    description: str = ""
# Reusable parameter definitions referenced from the `operations` table.

# Required "username" request header; its schema is a string defaulting to "api".
header_user_name = Parameter(
    name="username", description="Username", location="header", required=True,
    schema=Schema(type="string", default="api"),
)

# Required path parameters. Both are named "id" and differ only in description.
path_id = Parameter(
    name="id",
    description="ID",
    location="path",
    required=True,
)
path_dataset_name = Parameter(
    name="id",
    description="Dataset Name",
    location="path",
    required=True,
)

# Optional query-string parameters used by the list operations.
query_limit = Parameter(
    name="limit",
    description="Limit Per Page",
    location="query",
)
query_page = Parameter(
    name="page",
    description="Page Index",
    location="query",
)
query_per_page = Parameter(
    name="per_page",
    description="Limit Per Page",
    location="query",
)
query_exclusive_start_key = Parameter(
    name="exclusive_start_key",
    description="Exclusive Start Key",
    location="query",
)
# Tag definitions, already rendered to plain dicts; `handler` stores this list
# under the "tags" key of the exported document.
tags = [
    tag.to_dict()
    for tag in (
        Tag(name="Service", description="Service API"),
        Tag(name="Roles", description="Manage Roles"),
        Tag(name="Users", description="Manage Users"),
        Tag(name="Endpoints", description="Manage Endpoints"),
        Tag(
            name="Checkpoints",
            description="Manage Checkpoints",
            externalDocs=ExternalDocs(
                url="https://awslabs.github.io/stable-diffusion-aws-extension/en/developer-guide/api_upload_ckpt/",
                description="Upload Checkpoint Process",
            ),
        ),
        Tag(
            name="Inferences",
            description="Manage Inferences",
            externalDocs=ExternalDocs(
                url="https://awslabs.github.io/stable-diffusion-aws-extension/en/developer-guide/api_inference_process/",
                description="Inference Process",
            ),
        ),
        Tag(name="Executes", description="Manage Executes"),
        Tag(name="Datasets", description="Manage Datasets"),
        Tag(name="Trainings", description="Manage Trainings"),
        Tag(name="Prepare", description="Sync files to Endpoint"),
        Tag(name="Sync", description="Sync Message from Endpoint"),
        Tag(name="Workflows", description="Manage Workflows"),
        Tag(name="Others", description="Others API"),
    )
]
# Documentation overrides keyed by operationId. `supplement_schema` looks an
# operation up here to obtain its summary, tags, description and parameters;
# operation ids that are missing from this table are filed under "Others".
operations = {
    # --- Service ---
    "RootAPI": APISchema(
        summary="Root API",
        tags=["Service"],
        description="The Root API of ESD",
    ),
    "Ping": APISchema(
        summary="Ping API",
        tags=["Service"],
        description="The Ping API for Health Check",
    ),
    "ListRoles": APISchema(
        summary="List Roles",
        tags=["Roles"],
        description="List all roles",
        parameters=[header_user_name],
    ),
    "GetInferenceJob": APISchema(
        summary="Get Inference Job",
        tags=["Inferences"],
        description="Get Inference Job",
        parameters=[header_user_name, path_id],
    ),
    "CreateRole": APISchema(
        summary="Create Role",
        tags=["Roles"],
        description="Create a new role",
        parameters=[header_user_name],
    ),
    "DeleteRoles": APISchema(
        summary="Delete Roles",
        tags=["Roles"],
        description="Delete specify Roles",
        parameters=[header_user_name],
    ),
    # NOTE(review): the description says "List" although the operation takes a
    # path id; the string is left unchanged here - confirm the intended wording.
    "GetTraining": APISchema(
        summary="Get Training",
        tags=["Trainings"],
        description="Get Training List",
        parameters=[header_user_name, path_id],
    ),
    "ListCheckpoints": APISchema(
        summary="List Checkpoints",
        tags=["Checkpoints"],
        description="List Checkpoints with Parameters",
        parameters=[
            header_user_name,
            query_page,
            query_per_page,
            # Query-string filter; distinct from the "username" header above.
            Parameter(name="username", description="Filter by username", location="query"),
        ],
    ),
    "CreateCheckpoint": APISchema(
        summary="Create Checkpoint",
        tags=["Checkpoints"],
        description="Create a new Checkpoint",
        parameters=[header_user_name],
    ),
    "DeleteCheckpoints": APISchema(
        summary="Delete Checkpoints",
        tags=["Checkpoints"],
        description="Delete specify Checkpoints",
        parameters=[header_user_name],
    ),
    "StartInferences": APISchema(
        summary="Start Inference Job",
        tags=["Inferences"],
        description="Start specify Inference Job by ID",
        parameters=[header_user_name, path_id],
    ),
    "ListExecutes": APISchema(
        summary="List Executes",
        tags=["Executes"],
        description="List Executes with Parameters",
        parameters=[header_user_name, query_limit, query_exclusive_start_key],
    ),
    "CreateExecute": APISchema(
        summary="Create Execute",
        tags=["Executes"],
        description="Create a new Execute for Comfy",
        parameters=[header_user_name],
    ),
    "DeleteExecutes": APISchema(
        summary="Delete Executes",
        tags=["Executes"],
        description="Delete specify Executes",
        parameters=[header_user_name],
    ),
    "GetApiOAS": APISchema(
        summary="Get OAS",
        tags=["Service"],
        description="Get OAS json schema",
    ),
    "ListUsers": APISchema(
        summary="List Users",
        tags=["Users"],
        description="List all users",
        parameters=[header_user_name],
    ),
    "CreateUser": APISchema(
        summary="Create User",
        tags=["Users"],
        description="Create a new user",
        parameters=[header_user_name],
    ),
    "DeleteUsers": APISchema(
        summary="Delete Users",
        tags=["Users"],
        description="Delete specify Users",
        parameters=[header_user_name],
    ),
    "ListTrainings": APISchema(
        summary="List Trainings",
        tags=["Trainings"],
        description="List Trainings with Parameters",
        parameters=[header_user_name, query_limit, query_exclusive_start_key],
    ),
    "CreateTraining": APISchema(
        summary="Create Training",
        tags=["Trainings"],
        description="Create a new Training Job",
        parameters=[header_user_name],
    ),
    "DeleteTrainings": APISchema(
        summary="Delete Trainings",
        tags=["Trainings"],
        description="Delete specify Trainings",
        parameters=[header_user_name],
    ),
    "GetExecute": APISchema(
        summary="Get Execute",
        tags=["Executes"],
        description="Get Execute by ID",
        parameters=[header_user_name, path_id],
    ),
    "GetExecuteLogs": APISchema(
        summary="Get Execute Logs",
        tags=["Executes"],
        description="Get Execute Logs by ID",
        parameters=[header_user_name, path_id],
    ),
    "ListDatasets": APISchema(
        summary="List Datasets",
        tags=["Datasets"],
        description="List Datasets with Parameters",
        parameters=[header_user_name, query_limit, query_exclusive_start_key],
    ),
    "CropDataset": APISchema(
        summary="Create new Crop Dataset",
        tags=["Datasets"],
        description="Create new Crop Dataset",
        parameters=[header_user_name, path_dataset_name],
    ),
    "GetDataset": APISchema(
        summary="Get Dataset",
        tags=["Datasets"],
        description="Get Dataset by ID",
        parameters=[header_user_name, path_dataset_name],
    ),
    "UpdateCheckpoint": APISchema(
        summary="Update Checkpoint",
        tags=["Checkpoints"],
        description="Update Checkpoint by ID",
        # Uses the generic id parameter: the dataset variant would document
        # the checkpoint's path id as "Dataset Name".
        parameters=[header_user_name, path_id],
    ),
    "CreateDataset": APISchema(
        summary="Create Dataset",
        tags=["Datasets"],
        description="Create a new Dataset",
        parameters=[header_user_name],
    ),
    "DeleteDatasets": APISchema(
        summary="Delete Datasets",
        tags=["Datasets"],
        description="Delete specify Datasets",
        parameters=[header_user_name],
    ),
    "UpdateDataset": APISchema(
        summary="Update Dataset",
        tags=["Datasets"],
        description="Update Dataset by ID",
        parameters=[header_user_name, path_dataset_name],
    ),
    "ListInferences": APISchema(
        summary="List Inferences",
        tags=["Inferences"],
        description="List Inferences with Parameters",
        parameters=[
            header_user_name,
            query_limit,
            query_exclusive_start_key,
            Parameter(name="type", description="Inference task type: txt2img, img2img", location="query"),
        ],
    ),
    "CreateInferenceJob": APISchema(
        summary="Create Inference Job",
        tags=["Inferences"],
        description="Create a new Inference Job",
        parameters=[header_user_name],
    ),
    "DeleteInferenceJobs": APISchema(
        summary="Delete Inference Jobs",
        tags=["Inferences"],
        description="Delete specify Inference Jobs",
        parameters=[header_user_name],
    ),
    "ListEndpoints": APISchema(
        summary="List Endpoints",
        tags=["Endpoints"],
        description="List Endpoints with Parameters",
        parameters=[header_user_name, query_limit, query_exclusive_start_key],
    ),
    "CreateEndpoint": APISchema(
        summary="Create Endpoint",
        tags=["Endpoints"],
        description="Create a new Endpoint",
        parameters=[header_user_name],
    ),
    "DeleteEndpoints": APISchema(
        summary="Delete Endpoints",
        tags=["Endpoints"],
        description="Delete specify Endpoints",
        parameters=[header_user_name],
    ),
    "SyncMessage": APISchema(
        summary="Sync Message",
        tags=["Sync"],
        description="Sync Message to Endpoint",
        parameters=[header_user_name],
    ),
    "GetSyncMessage": APISchema(
        summary="Get Sync Message",
        tags=["Sync"],
        description="Get Sync Message from Endpoint",
        parameters=[header_user_name],
    ),
    "CreatePrepare": APISchema(
        summary="Create Prepare",
        tags=["Prepare"],
        description="Create a new Prepare",
        parameters=[header_user_name],
    ),
    "GetPrepare": APISchema(
        summary="Get Prepare",
        tags=["Prepare"],
        description="Get Prepare by ID",
        parameters=[header_user_name],
    ),
    # The workflow operations declare no parameters, not even the username header.
    # NOTE(review): "newWorkflow" in the summary below looks like a missing
    # space; the string is left unchanged here - confirm before editing.
    "CreateWorkflow": APISchema(
        summary="Release newWorkflow",
        tags=["Workflows"],
        description="Create a new Workflow",
    ),
    "ListWorkflows": APISchema(
        summary="List Workflows",
        tags=["Workflows"],
        description="List Workflows with Parameters",
    ),
    "DeleteWorkflows": APISchema(
        summary="Delete Workflows",
        tags=["Workflows"],
        description="Delete specify Workflows",
    ),
}
@tracer.capture_lambda_handler
def handler(event: dict, context: LambdaContext):
    """Return the deployed API's OpenAPI 3.0 document, enriched with ESD metadata.

    The definition of the 'prod' stage is exported from API Gateway, null
    values are replaced, the info/servers/tags sections are overwritten, and
    each operation's summary, description, tags and parameters are replaced
    with the values produced by `supplement_schema` and `merge_parameters`.
    Paths are sorted by name before the document is serialised.

    Args:
        event: Invocation event; `event['requestContext']['apiId']` names the
            REST API to export.
        context: Lambda context object (only logged).

    Returns:
        A response dict with status 200 and the JSON document as body, or the
        result of `response_error` when anything inside the export/transform
        step raises.

    Raises:
        KeyError: If the event carries no `requestContext.apiId`; that lookup
            happens before the guarded section.
    """
    # Lazy %-style arguments: the message is only formatted when INFO is enabled.
    logger.info('event: %s', event)
    logger.info('ctx: %s', context)

    api_id = event['requestContext']['apiId']

    try:
        # NOTE(review): the optional parameters={'extensions': 'apigateway'}
        # export argument is intentionally not passed here (it was commented
        # out in the original) - confirm it should stay disabled.
        response = client.get_export(
            restApiId=api_id,
            stageName='prod',
            exportType='oas30',
            accepts='application/json',
        )

        json_schema = replace_null(json.loads(response['body'].read()))
        info = json_schema['info']

        # ESD_VERSION may be unset; keep the exported version in that case
        # instead of failing the whole request on None.split().
        if esd_version:
            info['version'] = esd_version.split('-')[0]

        json_schema['servers'] = [
            {
                "url": "https://{ApiId}.execute-api.{Region}.{Domain}/prod/",
                "variables": {
                    "ApiId": {
                        "default": "xxxxxx"
                    },
                    "Region": {
                        "default": "ap-northeast-1"
                    },
                    "Domain": {
                        "default": "amazonaws.com"
                    },
                }
            }
        ]

        info['contact'] = {
            "email": "elonniu@amazon.com",
        }
        info['license'] = {
            "name": "Apache 2.0",
            "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
        }
        info['description'] = (
            "This is a ESD Server based on the OpenAPI 3.0 specification. \n"
            "Some useful links: \n"
            "\n- [The ESD Repository](https://github.com/awslabs/stable-diffusion-aws-extension)"
            "\n- [Implementation Guide](https://awslabs.github.io/stable-diffusion-aws-extension/en/)")

        json_schema['tags'] = tags

        for path_item in json_schema['paths'].values():
            for operation in path_item.values():
                # A Path Item may also hold non-operation members (for example
                # a path-level "parameters" list); only dict members can be
                # operations, so anything else is left untouched.
                if not isinstance(operation, dict):
                    continue
                meta = supplement_schema(operation)
                operation['description'] = meta.description
                operation['summary'] = meta.summary
                operation['tags'] = meta.tags
                operation['parameters'] = merge_parameters(meta, operation)

        # Emit the paths in alphabetical order.
        json_schema['paths'] = dict(sorted(json_schema['paths'].items(), key=lambda x: x[0]))

        payload = {
            'isBase64Encoded': False,
            'statusCode': 200,
            'headers': {
                'Content-Type': 'application/json',
                'Access-Control-Allow-Origin': '*',
                'Access-Control-Allow-Headers': '*',
                'Access-Control-Allow-Methods': '*',
                'Access-Control-Allow-Credentials': True,
            },
            'body': json.dumps(json_schema, cls=DecimalEncoder, indent=2)
        }

        return payload
    except Exception as e:
        # Top-level boundary: every failure is converted into an error response.
        return response_error(e)
def merge_parameters(schema: APISchema, item: dict):
    """Combine the parameters declared in `schema` with those already on `item`.

    Args:
        schema: Metadata whose `parameters` should be documented.
        item: Operation dict; its 'parameters' list is modified in place.

    Returns:
        The resulting parameter list. When `schema` declares no parameters an
        empty list is returned and parameters already present on `item` are
        not carried over.
    """
    declared = schema.parameters
    if not declared:
        return []

    # Nothing to merge into: the declared parameters become the whole list.
    if len(item.get('parameters', ())) == 0:
        item['parameters'] = [param.to_dict() for param in declared]
        return item['parameters']

    existing = item['parameters']
    for param in declared:
        # A declared parameter matches an existing one on (name, location).
        matched = [
            current for current in existing
            if param.name == current['name'] and param.location == current['in']
        ]
        for current in matched:
            # Every match is overwritten with the declared definition.
            current.update(param.to_dict())
        if not matched:
            existing.append(param.to_dict())

    return existing
def replace_null(data):
    """Recursively replace every None inside nested dicts and lists.

    Each None found as a dict value or list element is swapped, in place, for
    a fresh placeholder dict. Values that are neither dicts nor lists are
    returned unchanged, including a None passed at the top level.

    Args:
        data: Any value, typically a decoded JSON document.

    Returns:
        The same `data` object after modification.
    """
    if isinstance(data, dict):
        slots = data.items()
    elif isinstance(data, list):
        slots = enumerate(data)
    else:
        return data

    # Only existing keys/indices are reassigned, so the container never
    # changes size while it is being iterated.
    for slot, value in slots:
        if value is None:
            data[slot] = {
                "type": "null",
                "description": "Last Key for Pagination"
            }
        else:
            data[slot] = replace_null(value)
    return data
def supplement_schema(method: dict):
    """Build the documentation metadata for one exported operation.

    Args:
        method: Operation dict from the exported document; its 'operationId',
            when present, selects an entry of the `operations` table.

    Returns:
        An `APISchema`:
        - for a known operationId, the table entry with " (<operationId>)"
          appended to its summary;
        - for an unknown operationId, the id itself as summary under the
          "Others" tag with no parameters;
        - without an operationId, an empty summary under the "Others" tag.
    """
    if 'operationId' not in method:
        return APISchema(
            summary="",
            tags=["Others"],
            parameters=[]
        )

    operation_id = method['operationId']
    if operation_id not in operations:
        return APISchema(
            summary=operation_id,
            tags=["Others"],
            parameters=[]
        )

    item: APISchema = operations[operation_id]
    return APISchema(
        summary=item.summary + f" ({operation_id})",
        tags=item.tags,
        description=item.description,
        # Normalise a missing/empty parameter list to a fresh empty list.
        parameters=item.parameters or []
    )