stable-diffusion-aws-extension/aws_extension/cloud_infer_service/simple_sagemaker_infer.py

150 lines
5.9 KiB
Python

import json
import logging
import re
import requests
import utils
from aws_extension.cloud_api_manager.api_logger import ApiLogger
from aws_extension.cloud_infer_service.utils import InferManager
from utils import get_variable_from_json
# Per-module logger; its threshold follows the extension-wide utils.LOGGING_LEVEL.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=utils.LOGGING_LEVEL)
class SimpleSagemakerInfer(InferManager):
    """Submit WebUI txt2img/img2img jobs to the cloud inference API gateway."""

    @staticmethod
    def _json_body(response):
        """Return the decoded JSON object in *response*, or ``{}`` if there is none.

        A proxy or gateway can answer with a non-JSON body (e.g. an HTML error
        page); decoding it must not hide the HTTP status behind a parse error.
        """
        try:
            body = response.json()
        except ValueError:  # requests' JSONDecodeError derives from ValueError
            return {}
        return body if isinstance(body, dict) else {}

    @staticmethod
    def _error_message(response, body):
        """Return the API's ``message`` field, else the status code and raw text."""
        message = body.get('message')
        if message:
            return message
        return f'HTTP {response.status_code}: {response.text}'

    def parse_lora(self, json_string: str, models):
        """Record the LoRA files referenced by the prompt in ``models['Lora']``.

        Every ``<lora:NAME...>`` tag in the ``prompt`` field of *json_string*
        contributes ``NAME.safetensors``. Repeated tags are listed once, in
        first-seen order. A missing or empty prompt yields an empty list.

        Args:
            json_string: JSON object text containing a ``prompt`` field.
            models: Dict updated IN PLACE; any existing ``'Lora'`` entry is replaced.

        Returns:
            The same *models* dict.
        """
        prompt = json.loads(json_string).get('prompt') or ''
        names = re.findall(r"<lora:([^:>]+)", prompt)
        # dict.fromkeys drops duplicates while preserving first-seen order.
        models['Lora'] = [f"{name}.safetensors" for name in dict.fromkeys(names)]
        return models

    def run(self, userid, models, sd_param, is_txt2img, endpoint_type):
        """Create a cloud inference job, upload its parameters and start it.

        Args:
            userid: Sent as the ``username`` header on gateway requests.
            models: Model selection dict; ``parse_lora`` sets its ``'Lora'`` key.
            sd_param: Processing parameters, serialized by
                ``_parse_api_param_to_json``.
            is_txt2img: Chooses ``"txt2img"`` or ``"img2img"`` as the task type.
            endpoint_type: Sent as ``inference_type`` when creating the job.

        Returns:
            ``None`` if the gateway URL or API token is not configured.
            The start response's ``data`` field when starting answers 200
            (real-time). Otherwise the inference id, which is ``None`` when the
            create response carried no ``api_params_s3_upload_url``.

        Raises:
            Exception: Creating the job did not return 201, or starting it did
                not return 200/202.
            requests.HTTPError: Uploading the parameters to S3 failed.
        """
        task_type = "txt2img" if is_txt2img else "img2img"

        # finished construct api payload
        sd_api_param_json = _parse_api_param_to_json(api_param=sd_param)
        models = self.parse_lora(sd_api_param_json, models)
        # NOTE(review): this tests the *root* logger (and only for exactly DEBUG),
        # not this module's `logger`; confirm that is intended.
        if logging.getLogger().getEffectiveLevel() == logging.DEBUG:
            # debug only, may delete later: dump the payload into the working directory
            with open(f'api_{task_type}_param.json', 'w') as f:
                f.write(sd_api_param_json)

        url = get_variable_from_json('api_gateway_url')
        api_key = get_variable_from_json('api_token')
        if not url or not api_key:
            logger.debug("Url or API-Key is not setting.")
            return

        payload = {
            'inference_type': endpoint_type,
            'task_type': task_type,
            'models': models,
        }
        logger.debug(payload)
        inference_id = None
        headers = {'x-api-key': api_key, 'username': userid}

        # Step 1: create the inference job. Parse the body once, defensively,
        # so non-JSON error responses still reach the status-code check.
        create_url = f'{url}inferences'
        response = requests.post(create_url, json=payload, headers=headers)
        body = self._json_body(response)
        data = body.get('data')
        inference = data.get('inference') if isinstance(data, dict) else None
        if not isinstance(inference, dict):
            inference = {}
        api_logger = ApiLogger(
            action='inference',
            infer_id=inference.get('id', "")
        )
        api_logger.req_log(sub_action="CreateInference",
                           method='POST',
                           path=create_url,
                           headers=headers,
                           response=response,
                           data=payload,
                           desc="Create inference job on cloud")
        if response.status_code != 201:
            raise Exception(self._error_message(response, body))

        if 'api_params_s3_upload_url' in inference:
            # Step 2: upload the serialized parameters through the presigned URL.
            api_params_s3_upload_url = inference['api_params_s3_upload_url']
            upload_s3_resp = requests.put(api_params_s3_upload_url, data=sd_api_param_json)
            upload_s3_resp.raise_for_status()
            api_logger.req_log(sub_action="UploadParameterToS3",
                               method='PUT',
                               path=api_params_s3_upload_url,
                               data=sd_api_param_json,
                               desc="Upload inference parameter to S3 by presigned URL, "
                                    "URL from previous step: CreateInference -> data -> inference -> api_params_s3_upload_url"
                                    "<br/>Just use code to request, not use API tools to upload because they will change the headers to make the request invalid"
                               )
            inference_id = inference['id']

            # Step 3: start the job. Build a fresh header dict rather than reusing
            # `headers`: it was passed to ApiLogger.req_log above, which may modify
            # it (e.g. to mask the key) — TODO confirm before reusing it.
            start_url = f'{url}inferences/{inference_id}/start'
            response = requests.put(start_url, headers={'x-api-key': api_key, 'username': userid})
            api_logger.req_log(sub_action="StartInference",
                               method='PUT',
                               path=start_url,
                               headers=headers,
                               response=response,
                               desc=f"Start inference job on cloud by ID ({inference_id}), ID from previous step: "
                                    "CreateInference -> data -> inference -> id")
            body = self._json_body(response)
            if response.status_code not in (200, 202):
                logger.error(body or response.text)
                raise Exception(self._error_message(response, body))
            # if real-time, return inference data
            if response.status_code == 200:
                return body['data']

        return inference_id
def _parse_api_param_to_json(api_param):
    """Serialize WebUI processing parameters to a JSON string.

    Values the ``json`` module cannot encode natively are converted by
    ``encode_no_json``:

    * ``numpy.ndarray`` -> PNG data URL (via ``PIL.Image.fromarray``)
    * NumPy scalars (``numpy.generic``) -> the equivalent Python scalar
    * ``PIL.Image.Image`` -> PNG data URL, keeping text-only metadata
    * ``enum.Enum`` -> its ``value``
    * objects with ``__dict__`` -> that dict (serialized recursively)
    * anything else -> ``str(obj)``, with a debug log entry
    """
    from PIL import Image, PngImagePlugin
    from io import BytesIO
    import base64
    import numpy
    import enum

    def get_pil_metadata(pil_image):
        # Copy any text-only metadata
        metadata = PngImagePlugin.PngInfo()
        for key, value in pil_image.info.items():
            if isinstance(key, str) and isinstance(value, str):
                metadata.add_text(key, value)
        return metadata

    def encode_pil_to_base64(pil_image):
        # Encode the image as PNG and wrap it in a data URL.
        with BytesIO() as output_bytes:
            pil_image.save(output_bytes, "PNG", pnginfo=get_pil_metadata(pil_image))
            bytes_data = output_bytes.getvalue()
        base64_str = str(base64.b64encode(bytes_data), "utf-8")
        return "data:image/png;base64," + base64_str

    def encode_no_json(obj):
        # json.dumps `default` hook: called only for objects json cannot encode.
        if isinstance(obj, numpy.ndarray):
            return encode_pil_to_base64(Image.fromarray(obj))
        if isinstance(obj, numpy.generic):
            # numpy.int64 / float32 / bool_ etc. have no __dict__; without this they
            # would be emitted as strings instead of JSON numbers/booleans.
            native = obj.item()
            # Some scalars (e.g. longdouble) may return a NumPy scalar again; only
            # accept a real conversion to avoid re-entering this hook forever.
            if not isinstance(native, numpy.generic):
                return native
        if isinstance(obj, Image.Image):
            return encode_pil_to_base64(obj)
        if isinstance(obj, enum.Enum):
            return obj.value
        if hasattr(obj, '__dict__'):
            return obj.__dict__
        logger.debug('may not able to json dumps %s: %s', type(obj), obj)
        return str(obj)

    return json.dumps(api_param, default=encode_no_json)