stable-diffusion-aws-extension/middleware_api/lambda/inferences/start_inference_job.py

313 lines
12 KiB
Python

import base64
import io
import json
import logging
import os
from datetime import datetime
import boto3
from PIL import Image
from sagemaker import Predictor
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer
from common.ddb_service.client import DynamoDbUtilsService
from common.response import accepted, bad_request
from get_inference_job import get_infer_data
from libs.data_types import InferenceJob, InvocationsRequest
from libs.enums import EndpointType
# Bucket that receives inference outputs (images and parameter JSON files).
S3_BUCKET_NAME = os.environ.get('S3_BUCKET_NAME')
# Name of the DynamoDB table holding inference job items, keyed by InferenceJobId.
# NOTE(review): both values are None when the environment variable is unset.
inference_table_name = os.environ.get('INFERENCE_JOB_TABLE')
# Module-level S3 client shared by every function in this file.
s3_client = boto3.client('s3')
logger = logging.getLogger(__name__)
# LOG_LEVEL overrides the level; an unset or empty value falls back to ERROR.
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
ddb_client = boto3.resource('dynamodb')
# Project DynamoDB helper used for get_item / put_items by table name.
ddb_service = DynamoDbUtilsService(logger=logger)
# boto3 Table handle used for direct get_item / update_item calls.
inference_table = ddb_client.Table(inference_table_name)
def upload_file_to_s3(file_name, bucket, directory=None, object_name=None):
    """Upload a local file to S3 and report whether the upload succeeded.

    Args:
        file_name: Path of the local file to upload.
        bucket: Name of the destination bucket.
        directory: Optional key prefix; when truthy the object is stored as
            ``{directory}/{object_name}``.
        object_name: Optional object key; defaults to ``file_name``.

    Returns:
        True when the upload call completed, False when it raised. Failures
        are printed rather than propagated.
    """
    # Default the key to the local file name, then nest it under the prefix.
    object_name = file_name if object_name is None else object_name
    object_name = f"{directory}/{object_name}" if directory else object_name

    uploaded = True
    try:
        s3_client.upload_file(file_name, bucket, object_name)
        print(f"File {file_name} uploaded to {bucket}/{object_name}")
    except Exception as e:
        print(f"Error occurred while uploading {file_name} to {bucket}/{object_name}: {e}")
        uploaded = False
    return uploaded
def decode_base64_to_image(encoding):
    """Decode a base64 string, optionally wrapped in a data URI, into a PIL image.

    Args:
        encoding: Base64 text. When it starts with ``data:image/`` the part
            after the first ``;`` and then after the following ``,`` is
            taken as the payload.

    Returns:
        The object returned by ``Image.open`` for the decoded bytes.
    """
    b64_text = encoding
    if b64_text.startswith("data:image/"):
        # "data:image/png;base64,<payload>" -> "base64,<payload>" -> "<payload>"
        after_mime = b64_text.split(";")[1]
        b64_text = after_mime.split(",")[1]

    raw_bytes = base64.b64decode(b64_text)
    return Image.open(io.BytesIO(raw_bytes))
def handler(event, _):
    """Lambda entry point: start the inference job named in the request path.

    Loads the job item from DynamoDB, builds the SageMaker invocation
    payload, records the start time, and dispatches to the real-time or
    asynchronous invocation path depending on ``job.inference_type``.

    Args:
        event: Lambda event; ``event['pathParameters']['id']`` is read as the
            inference job id.
        _: Lambda context (unused).

    Returns:
        The response produced by ``real_time`` or ``async_inference``, or a
        ``bad_request`` response when no job item exists for the id.
    """
    logger.info(json.dumps(event))

    inference_id = event['pathParameters']['id']

    # get the inference job from ddb by job id
    inference_raw = ddb_service.get_item(inference_table_name, {
        'InferenceJobId': inference_id
    })

    # Validate explicitly: `assert` is removed under `python -O` and would
    # otherwise surface to the client as an opaque internal error.
    if not inference_raw:
        return bad_request(message=f"inference job {inference_id} not found")

    job = InferenceJob(**inference_raw)
    endpoint_name = job.params['sagemaker_inference_endpoint_name']

    models = {}
    if 'used_models' in job.params:
        # Entries from used_models override the default when keys collide.
        models = {
            "space_free_size": 4e10,
            **job.params['used_models'],
        }

    # NOTE(review): username is hard-coded to "test" - confirm this is intended.
    payload = InvocationsRequest(
        task=job.taskType,
        username="test",
        models=models,
        param_s3=job.params['input_body_s3']
    )
    logger.info(f"payload: {payload}")

    update_inference_job_table(job.InferenceJobId, 'startTime', str(datetime.now()))

    if job.inference_type == EndpointType.RealTime.value:
        return real_time(payload, job, endpoint_name)

    return async_inference(payload, job, endpoint_name)
def real_time(payload, job: InferenceJob, endpoint_name):
    """Invoke the endpoint synchronously and persist its output.

    Args:
        payload: Invocation request; its ``__dict__`` is sent as JSON.
        job: The inference job being executed.
        endpoint_name: Name of the SageMaker endpoint to invoke.

    Returns:
        The result of ``get_infer_data`` for the job on success, or a
        ``bad_request`` response carrying the error text when anything in
        the invocation or output handling raises.
    """
    rt_predictor = Predictor(endpoint_name)
    rt_predictor.serializer = JSONSerializer()
    rt_predictor.deserializer = JSONDeserializer()

    try:
        started_at = datetime.now()
        prediction_sync = rt_predictor.predict(
            data=payload.__dict__,
            inference_id=job.InferenceJobId,
        )
        logger.info(prediction_sync)

        # An 'error' key marks a failed invocation; prefer its 'detail' text.
        if 'error' in prediction_sync:
            failure = prediction_sync['detail'] if 'detail' in prediction_sync else prediction_sync
            raise Exception(failure)

        cost_time = (datetime.now() - started_at).total_seconds()
        logger.info(f"Real-time inference cost_time: {cost_time}")

        handle_sagemaker_out(job, prediction_sync, endpoint_name)
        return get_infer_data(job.InferenceJobId)
    except Exception as e:
        print(e)
        return bad_request(message=str(e))
def async_inference(payload, job: InferenceJob, endpoint_name):
    """Submit the job to the endpoint asynchronously and mark it in progress.

    Args:
        payload: Invocation request; its ``__dict__`` is sent as JSON.
        job: The inference job being executed; its ``status`` and
            ``params['output_path']`` are updated and the item is written
            back to the inference table.
        endpoint_name: Name of the SageMaker endpoint to invoke.

    Returns:
        An ``accepted`` response describing the job id, status, endpoint and
        the output path reported by the async prediction.
    """
    async_predictor = AsyncPredictor(Predictor(endpoint_name), name=endpoint_name)
    async_predictor.serializer = JSONSerializer()
    async_predictor.deserializer = JSONDeserializer()

    prediction = async_predictor.predict_async(
        data=payload.__dict__,
        initial_args={"InvocationTimeoutSeconds": 3600},
        inference_id=job.InferenceJobId,
    )
    logger.info(f"prediction: {prediction}")
    result_location = prediction.output_path

    # update the ddb job status to 'inprogress' and save to ddb
    job.status = 'inprogress'
    job.params['output_path'] = result_location
    ddb_service.put_items(inference_table_name, job.__dict__)

    return accepted(data={
        'inference': {
            'inference_id': job.InferenceJobId,
            'status': job.status,
            'endpoint_name': endpoint_name,
            'output_path': result_location,
        }
    })
def _resolve_image_extension(output_img_type, image_count, index):
    """Choose the file extension for image ``index`` from the requested output types."""
    if len(output_img_type) == 1 and output_img_type[0] in ('PNG', 'TXT'):
        return 'png'
    if len(output_img_type) == 2 and 'PNG' in output_img_type and 'TXT' in output_img_type:
        return 'png'

    # Drop PNG/TXT entries; the remaining types are cycled across the images.
    other_types = [element for element in output_img_type
                   if "TXT" not in element and "PNG" not in element]
    if not other_types:
        # Only PNG/TXT variants were requested (e.g. duplicates). Treat this
        # like the explicit PNG/TXT cases above instead of dividing by zero.
        return 'png'
    if image_count % len(other_types) == 0:
        return other_types[index % len(other_types)].lower()
    return 'gif'


def _append_image_name(inference_id, image_name):
    """Append ``image_name`` to the job item's ``image_names`` list in DynamoDB."""
    inference_table.update_item(
        Key={
            'InferenceJobId': inference_id
        },
        UpdateExpression='SET image_names = list_append(if_not_exists(image_names, :empty_list), :new_image)',
        ExpressionAttributeValues={
            ':new_image': [image_name],
            ':empty_list': []
        }
    )


def _save_inference_parameters(inference_id, inference_parameters):
    """Write the parameters to a temp JSON file, upload it, and record its name on the job."""
    json_file_name = f"/tmp/{inference_id}_param.json"
    with open(json_file_name, "w") as outfile:
        json.dump(inference_parameters, outfile)
    upload_file_to_s3(json_file_name, S3_BUCKET_NAME, f"out/{inference_id}/result",
                      f"{inference_id}_param.json")
    update_inference_job_table(inference_id, 'inference_info_name', json_file_name)


def handle_sagemaker_out(job: InferenceJob, json_body, endpoint_name):
    """Persist the output of a synchronous SageMaker invocation.

    Records the completion time, then, depending on ``job.taskType``:

    * ``interrogate_clip`` / ``interrogate_deepbooru``: stores the caption on
      the job item.
    * ``txt2img`` / ``img2img``: uploads every returned image under
      ``out/{inference_id}/result/``, appends each stored file name to the
      job's ``image_names``, and uploads a JSON file with the parameters.
    * ``extra-single-image`` / ``rembg``: uploads the single image as
      ``image.png`` and uploads a JSON file with the parameters.

    Finally sets the job status to ``succeed``. On any exception the status
    is set to ``failed`` and the exception is re-raised.

    Args:
        job: The inference job whose output is being handled.
        json_body: Deserialized response body from the endpoint.
        endpoint_name: Name of the endpoint that produced the output.

    Raises:
        Exception: Whatever the processing raised, after marking the job failed.
    """
    inference_id = job.InferenceJobId
    update_inference_job_table(inference_id, 'completeTime', str(datetime.now()))
    task_type = job.taskType

    try:
        if task_type in ("interrogate_clip", "interrogate_deepbooru"):
            # Update the DynamoDB table for the caption
            inference_table.update_item(
                Key={
                    'InferenceJobId': inference_id
                },
                UpdateExpression='SET caption=:f',
                ExpressionAttributeValues={
                    ':f': json_body['caption'],
                }
            )
        elif task_type in ("txt2img", "img2img"):
            images = json_body["images"]
            logger.debug(f'image count:{len(images)}')

            # Missing or empty output_img_type means "re-encode everything as PNG".
            output_img_type = json_body.get('output_img_type') or None
            if output_img_type:
                logger.info(f"sync handle_sagemaker_out: output_img_type is not null, {output_img_type}")

            # save images
            for count, b64image in enumerate(images):
                if not output_img_type:
                    img_type = 'png'
                    image = decode_base64_to_image(b64image).convert("RGB")
                    output = io.BytesIO()
                    image.save(output, format="PNG")
                    image_bytes = output.getvalue()
                else:
                    # Take the part after any "data:...;base64," prefix; without
                    # a prefix this is simply the whole string.
                    image_bytes = base64.b64decode(b64image.split(",", 1)[-1])
                    img_type = _resolve_image_extension(output_img_type, len(images), count)
                    logger.debug(f'img_type is :{img_type} count is:{count}')

                image_name = f"image_{count}.{img_type}"
                # Upload the image to the S3 bucket
                s3_client.put_object(
                    Body=image_bytes,
                    Bucket=S3_BUCKET_NAME,
                    Key=f"out/{inference_id}/result/{image_name}"
                )
                # Record the name actually written to S3, so the stored
                # extension and the DynamoDB entry cannot disagree.
                _append_image_name(inference_id, image_name)

            # save parameters
            _save_inference_parameters(inference_id, {
                "parameters": json_body["parameters"],
                "info": json_body["info"],
                "endpont_name": endpoint_name,
                "inference_id": inference_id,
            })
        elif task_type in ("extra-single-image", "rembg"):
            if 'image' not in json_body:
                raise Exception(json_body)

            image = Image.open(io.BytesIO(base64.b64decode(json_body["image"])))
            output = io.BytesIO()
            image.save(output, format="PNG")
            # Upload the image to the S3 bucket
            s3_client.put_object(
                Body=output.getvalue(),
                Bucket=S3_BUCKET_NAME,
                Key=f"out/{inference_id}/result/image.png"
            )
            # Update the DynamoDB table
            _append_image_name(inference_id, "image.png")

            # save parameters
            inference_parameters = {}
            if task_type == "extra-single-image":
                inference_parameters["html_info"] = json_body["html_info"]
            inference_parameters["endpont_name"] = endpoint_name
            inference_parameters["inference_id"] = inference_id
            _save_inference_parameters(inference_id, inference_parameters)
            print(f"Complete inference parameters {inference_parameters}")

        update_inference_job_table(inference_id, 'status', 'succeed')
    except Exception as e:
        print(f"Error occurred: {str(e)}")
        update_inference_job_table(inference_id, 'status', 'failed')
        # Bare raise keeps the original traceback intact.
        raise
def update_inference_job_table(inference_id, key, value):
    """Set one attribute on an existing inference job item.

    Args:
        inference_id: Value of the item's ``InferenceJobId`` key.
        key: Name of the attribute to set; passed through an expression
            attribute name placeholder.
        value: New value for the attribute.

    Raises:
        Exception: When no item exists for ``inference_id``.
    """
    item_key = {
        "InferenceJobId": inference_id,
    }

    # Update the inference DDB for the job status
    response = inference_table.get_item(Key=item_key)

    # GetItem omits the "Item" element entirely when nothing matches, so use
    # .get(): indexing would raise KeyError before the intended error below.
    if not response.get('Item'):
        raise Exception(f"Failed to get the inference job item with inference id: {inference_id}")

    inference_table.update_item(
        Key=item_key,
        UpdateExpression="set #k = :r",
        ExpressionAttributeNames={'#k': key},
        ExpressionAttributeValues={':r': value},
        ReturnValues="UPDATED_NEW"
    )