# stable-diffusion-aws-extension/aws_extension/cloud_infer_service/simple_sagemaker_infer.py
#
# 105 lines
# 3.8 KiB
# Python

import logging
import requests
from datetime import datetime
import utils
from aws_extension.cloud_infer_service.utils import InferManager
from utils import get_variable_from_json
# Module-level logger named after this module; its level comes from utils.LOGGING_LEVEL.
logger = logging.getLogger(__name__)
logger.setLevel(utils.LOGGING_LEVEL)
class SimpleSagemakerInfer(InferManager):
    """Inference manager that creates and starts an inference job over HTTP."""

    def run(self, userid, models, sd_param, is_txt2img):
        """Create an inference job, upload its parameters and start it.

        Args:
            userid: Sent as ``user_id`` in the request payload.
            models: Sent as ``models`` in the request payload.
            sd_param: API parameter object; serialized with
                ``_parse_api_param_to_json`` and uploaded as the job parameters.
            is_txt2img: Truthy selects task type ``txt2img``, otherwise
                ``img2img``.

        Returns:
            The inference id taken from the create response, or ``None`` when
            the gateway URL / API token is not configured or the create
            response carries no parameter upload URL.

        Raises:
            requests.HTTPError: If any HTTP call returns an error status.
            Exception: If the create response reports a ``statusCode`` other
                than 200.
        """
        task_type = "txt2img" if is_txt2img else "img2img"

        # finished construct api payload
        sd_api_param_json = _parse_api_param_to_json(api_param=sd_param)
        if logging.getLogger().getEffectiveLevel() == logging.DEBUG:
            # debug only, may delete later
            with open(f'api_{task_type}_param.json', 'w', encoding='utf-8') as f:
                f.write(sd_api_param_json)

        # create an inference and upload to s3
        url = get_variable_from_json('api_gateway_url')
        api_key = get_variable_from_json('api_token')
        if not url or not api_key:
            logger.debug("Url or API-Key is not setting.")
            return None

        # Endpoint paths are appended directly to the base URL, so guarantee
        # that a '/' separates them even if the configured value lacks one.
        base_url = f'{url}'
        if not base_url.endswith('/'):
            base_url += '/'
        headers = {'x-api-key': api_key}

        payload = {
            # 'sagemaker_endpoint_name': sagemaker_endpoint,
            'user_id': userid,
            'task_type': task_type,
            'models': models,
            'filters': {
                'createAt': datetime.now().timestamp(),
                'creator': 'sd-webui',
            },
        }
        logger.debug(payload)

        response = requests.post(f'{base_url}inference/v2', json=payload, headers=headers)
        response.raise_for_status()
        upload_param_response = response.json()
        if upload_param_response['statusCode'] != 200:
            # Fall back to the whole body so a failure without 'errMsg' still
            # surfaces something useful instead of a KeyError.
            raise Exception(upload_param_response.get('errMsg', upload_param_response))

        inference = upload_param_response.get('inference') or {}
        if 'api_params_s3_upload_url' not in inference:
            # Without an upload URL no job parameters can be delivered, so
            # there is no job to start (and no id to build the run URL from).
            return None

        upload_s3_resp = requests.put(inference['api_params_s3_upload_url'],
                                      data=sd_api_param_json)
        upload_s3_resp.raise_for_status()
        inference_id = inference['id']

        # start run infer
        response = requests.put(f'{base_url}inference/v2/{inference_id}/run', json=payload,
                                headers=headers)
        response.raise_for_status()
        return inference_id
def _parse_api_param_to_json(api_param):
    """Serialize *api_param* to a JSON string.

    Values that ``json`` cannot encode on its own go through a fallback hook:
    numpy arrays and PIL images become base64 PNG data URIs, enum members
    become their ``value``, objects with a ``__dict__`` become that dict, and
    anything else is converted with ``str``.
    """
    import base64
    import enum
    import json
    from io import BytesIO

    import numpy
    from PIL import Image, PngImagePlugin

    def _text_metadata(image):
        """Build a PngInfo holding the str-keyed, str-valued ``info`` entries."""
        png_info = PngImagePlugin.PngInfo()
        for name, text in image.info.items():
            if not (isinstance(name, str) and isinstance(text, str)):
                continue
            png_info.add_text(name, text)
        return png_info

    def _to_data_uri(image):
        """Save *image* as PNG in memory and return it as a data URI string."""
        with BytesIO() as buffer:
            image.save(buffer, "PNG", pnginfo=_text_metadata(image))
            raw_png = buffer.getvalue()
        encoded = base64.b64encode(raw_png).decode("utf-8")
        return "data:image/png;base64," + encoded

    def _fallback(obj):
        """``json.dumps`` default hook for otherwise unserializable values."""
        if isinstance(obj, numpy.ndarray):
            return _to_data_uri(Image.fromarray(obj))
        if isinstance(obj, Image.Image):
            return _to_data_uri(obj)
        if isinstance(obj, enum.Enum):
            return obj.value
        if hasattr(obj, '__dict__'):
            return obj.__dict__
        # Last resort: note the unexpected type and send its string form.
        logger.debug(f'may not able to json dumps {type(obj)}: {str(obj)}')
        return str(obj)

    return json.dumps(api_param, default=_fallback)