stable-diffusion-aws-extension/middleware_api/lambda/inference/common/response_wrapper.py

79 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# todo will remove
import json
import logging
import os
from typing import Optional, Generic, TypeVar
from common.enum import MessageEnum
from fastapi import Response
from fastapi.responses import JSONResponse
from pydantic.generics import GenericModel
from .constant import const
logger = logging.getLogger(const.LOGGER_API)
DataT = TypeVar("DataT")
class BaseResponse(GenericModel, Generic[DataT]):
status: str
code: int
message: Optional[str]
data: Optional[DataT] = None
class Config:
orm_mode = True
class DbWrapper:
def __init__(self, status, data, code, message):
self.status = status
self.data = data
self.code = code
self.message = message
# Be called when a correct response is returned, e.g.resp_okdata=xxxx)
def resp_ok(data,
code: int = MessageEnum.BIZ_DEFAULT_OK.get_code(),
message: str = MessageEnum.BIZ_DEFAULT_OK.get_msg()) -> DbWrapper:
return DbWrapper(const.RESPONSE_SUCCESS, data, code, message)
# Be called when a exception response is returned, e.g.resp_ngcode=XXX, message=XXXXX)
def resp_err(code: int = MessageEnum.BIZ_DEFAULT_ERR.get_code(),
message: str = MessageEnum.BIZ_DEFAULT_ERR.get_msg(),
ref: list = None) -> Response:
logger.info('END >>> RESPONSE FAILED')
headers = {}
if os.getenv(const.MODE) == const.MODE_DEV:
headers = {
'Access-Control-Allow-Origin': '*',
'Access-Control-Allow-Headers': '*',
'Access-Control-Allow-Methods': '*',
'Access-Control-Allow-Credentials': 'true'
}
return JSONResponse(
content={
'status': const.RESPONSE_FAIL, # success|fail
'code': code,
'message': message,
'ref': ref,
},
headers=headers,
)
class S3WrapEncoder(json.JSONEncoder):
def convert(obj, properties: list[str]):
fields = {}
for field in [x for x in dir(obj) if not x.startswith('_') and x != 'metadata' and x in properties]:
data = obj.__getattribute__(field)
try:
json.dumps(data)
fields[field] = data
except TypeError:
fields[field] = None
return fields