import json import logging.config import logging.config import os import traceback import uuid from datetime import datetime import boto3 from boto3.dynamodb.conditions import Attr, Key from botocore.exceptions import ClientError from fastapi import FastAPI, Request, HTTPException from fastapi.responses import JSONResponse from fastapi_pagination import add_pagination from mangum import Mangum from sagemaker.deserializers import JSONDeserializer from sagemaker.predictor import Predictor from sagemaker.predictor_async import AsyncPredictor from sagemaker.serializers import JSONSerializer from common.constant import const from parse.parameter_parser import json_convert_to_payload logging.config.fileConfig('logging.conf', disable_existing_loggers=False) logger = logging.getLogger(const.LOGGER_API) INFERENCE_JOB_TABLE = os.environ.get('INFERENCE_JOB_TABLE') DDB_TRAINING_TABLE_NAME = os.environ.get('DDB_TRAINING_TABLE_NAME') DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME = os.environ.get('DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME') REGION_NAME = os.environ['AWS_REGION'] S3_BUCKET_NAME = os.environ.get('S3_BUCKET') ddb_client = boto3.resource('dynamodb') s3 = boto3.client('s3', region_name=REGION_NAME) sagemaker = boto3.client('sagemaker') inference_table = ddb_client.Table(INFERENCE_JOB_TABLE) endpoint_deployment_table = ddb_client.Table(DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME) async def custom_exception_handler(request: Request, exc: HTTPException): headers = { "Access-Control-Allow-Headers": "Content-Type", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "OPTIONS,POST,GET" } return JSONResponse( status_code=exc.status_code, content={"detail": exc.detail}, headers=headers ) app = FastAPI( title="API List of SageMaker Inference", version="0.9", ) app.exception_handler(HTTPException)(custom_exception_handler) def get_uuid(): uuid_str = str(uuid.uuid4()) return uuid_str def getInferenceJobList(): response = inference_table.scan() logger.info(f"inference job list response is {str(response)}") return response['Items'] def build_filter_expression(end_time, endpoint, start_time, status, task_type): filter_expression = None if status: filter_expression = Attr('status').eq(status) if task_type: if filter_expression: filter_expression &= Attr('taskType').eq(task_type) else: filter_expression = Attr('taskType').eq(task_type) if start_time: if filter_expression: filter_expression &= Attr('startTime').gte(start_time) else: filter_expression = Attr('startTime').gte(start_time) if end_time: if filter_expression: filter_expression &= Attr('startTime').lte(end_time) else: filter_expression = Attr('startTime').lte(end_time) if endpoint: if filter_expression: filter_expression &= Attr('params.sagemaker_inference_endpoint_name').eq(endpoint) else: filter_expression = Attr('params.sagemaker_inference_endpoint_name').eq(endpoint) return filter_expression def query_inference_job_list(status: str, task_type: str, start_time: str, end_time: str, endpoint: str, checkpoint: str, limit: int): print(f"query_inference_job_list params are:{status},{task_type},{start_time},{end_time},{checkpoint},{endpoint}") try: response = None filter_expression = build_filter_expression(end_time, endpoint, start_time, status, task_type) if limit != const.PAGE_LIMIT_ALL and limit <= 0: logger.info(f"query inference job list error because of limit <0 {limit}") return "" if filter_expression: response = inference_table.scan( FilterExpression=filter_expression ) else: response = inference_table.scan() logger.info(f"query inference job list response is {str(response)}") if response: return filter_checkpoint_items(limit, checkpoint, response['Items']) return response except Exception as e: logger.info(f"query inference job list error ") logger.info(e) return "" def sort_by_start_time(item): return item.get("startTime", "") def filter_checkpoint_items(limit, checkpoint, items): items = sorted(items, key=sort_by_start_time, reverse=True) if checkpoint: filtered_data = [] for item in items: if "params" in item and "used_models" in item["params"]: used_models = item["params"]["used_models"].get("Stable-diffusion", []) for model in used_models: if "model_name" in model and model["model_name"] == checkpoint: filtered_data.append(item) if limit == const.PAGE_LIMIT_ALL: return filtered_data else: if len(filtered_data) >= limit: return filtered_data[0: limit] else: return filtered_data if limit == const.PAGE_LIMIT_ALL: return items else: if len(items) >= limit: return items[0: limit] else: return items def getInferenceJob(inference_job_id): if not inference_job_id: logger.error("Invalid inference job id") raise ValueError("Inference job id must not be None or empty") try: resp = inference_table.query( KeyConditionExpression=Key('InferenceJobId').eq(inference_job_id) ) # logger.info(resp) record_list = resp['Items'] if len(record_list) == 0: logger.error(f"No inference job info item for id: {inference_job_id}") raise ValueError(f"There is no inference job info item for id: {inference_job_id}") return record_list[0] except Exception as e: logger.error( f"Exception occurred when trying to query inference job with id: {inference_job_id}, exception is {str(e)}") raise def getEndpointDeploymentJobList(): try: sagemaker = boto3.client('sagemaker') ddb = boto3.resource('dynamodb') endpoint_deployment_table = ddb.Table(DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME) response = endpoint_deployment_table.scan() logger.info(f"endpoint deployment job list response is {str(response)}") # Get the list of SageMaker endpoints list_results = sagemaker.list_endpoints() sagemaker_endpoints = [ep_info['EndpointName'] for ep_info in list_results['Endpoints']] logger.info(str(sagemaker_endpoints)) # Filter the endpoint job list filtered_endpoint_jobs = [] for job in response['Items']: if 'endpoint_name' in job: endpoint_name = job['endpoint_name'] deployment_job_id = job['EndpointDeploymentJobId'] if endpoint_name in sagemaker_endpoints: filtered_endpoint_jobs.append(job) else: # Remove the job item from the DynamoDB table if the endpoint doesn't exist in SageMaker endpoint_deployment_table.delete_item(Key={'EndpointDeploymentJobId': deployment_job_id}) else: filtered_endpoint_jobs.append(job) return filtered_endpoint_jobs except ClientError as e: print(f"An error occurred: {e}") return [] except Exception as e: print(f"An unexpected error occurred: {e}") return [] def getEndpointDeployJob(endpoint_deploy_job_id): try: resp = endpoint_deployment_table.query( KeyConditionExpression=Key('EndpointDeploymentJobId').eq(endpoint_deploy_job_id) ) logger.info(resp) except Exception as e: logger.error(e) record_list = resp['Items'] if len(record_list) == 0: logger.error("There is no endpoint deployment job info item for id:" + endpoint_deploy_job_id) return {} return record_list[0] def getEndpointDeployJob_with_endpoint_name(endpoint_name): try: resp = endpoint_deployment_table.scan( FilterExpression=Attr('endpoint_name').eq(endpoint_name) ) logger.info(resp) except Exception as e: logger.error(e) record_list = resp['Items'] if len(record_list) == 0: logger.error("There is no endpoint deployment job info item with endpoint name:" + endpoint_name) return {} return record_list[0] def get_s3_objects(bucket_name, folder_name): # Ensure the folder name ends with a slash if not folder_name.endswith('/'): folder_name += '/' # List objects in the specified bucket and folder response = s3.list_objects_v2(Bucket=bucket_name, Prefix=folder_name) # Extract object names from the response object_names = [obj['Key'][len(folder_name):] for obj in response.get('Contents', []) if obj['Key'] != folder_name] return object_names def load_json_from_s3(bucket_name, key): # Get the JSON file from the specified bucket and key response = s3.get_object(Bucket=bucket_name, Key=key) json_file = response['Body'].read().decode('utf-8') # Load the JSON file into a dictionary data = json.loads(json_file) return data # Global exception capture stepf_client = boto3.client('stepfunctions') # def get_curent_time(): # # Get the current time # now = datetime.now() # formatted_time = now.strftime("%Y-%m-%d-%H-%M-%S") # return formatted_time @app.post("/inference/run-sagemaker-inference") @app.post("/inference-api/inference") async def run_sagemaker_inference(request: Request): try: logger.info('entering the run_sage_maker_inference function!') inference_id = get_uuid() payload_checkpoint_info = await request.json() print(f"!!!!!!!!!!input in json format {payload_checkpoint_info}") task_type = payload_checkpoint_info.get('task_type') print(f"Task Type: {task_type}") path = request.url.path logger.info(f'Path: {path}') if path == '/inference-api/inference': # Invoke by API logger.info('invoked by api') params_dict = load_json_from_s3(S3_BUCKET_NAME, 'template/inferenceTemplate.json') else: # Invoke by UI params_dict = load_json_from_s3(S3_BUCKET_NAME, 'config/aigc.json') # logger.info(json.dumps(params_dict)) payload = json_convert_to_payload(params_dict, payload_checkpoint_info, task_type) print(f"input in json format:") checkpoint_name = None if task_type == 'img2img': checkpoint_name = params_dict['img2img_sagemaker_stable_diffusion_checkpoint'] elif task_type == 'txt2img': checkpoint_name = params_dict['txt2img_sagemaker_stable_diffusion_checkpoint'] def show_slim_dict(payload): pay_type = type(payload) if pay_type is dict: for k, v in payload.items(): print(f"{k}") show_slim_dict(v) elif pay_type is list: for v in payload: print(f"list") show_slim_dict(v) elif pay_type is str: if len(payload) > 100: print(f" : {len(payload)} contents") else: print(f" : {payload}") else: print(f" : {payload}") show_slim_dict(payload) endpoint_name = payload["endpoint_name"] predictor = Predictor(endpoint_name) # adjust time out time to 1 hour initial_args = {"InvocationTimeoutSeconds": 3600} predictor = AsyncPredictor(predictor, name=endpoint_name) predictor.serializer = JSONSerializer() predictor.deserializer = JSONDeserializer() prediction = predictor.predict_async(data=payload, initial_args=initial_args, inference_id=inference_id) output_path = prediction.output_path # put the item to inference DDB for later check status current_time = str(datetime.now()) response = inference_table.put_item( Item={ 'InferenceJobId': inference_id, 'startTime': current_time, 'status': 'inprogress', 'endpoint': endpoint_name, 'checkpoint': checkpoint_name, 'taskType': task_type }) print(f"output_path is {output_path}") headers = { "Access-Control-Allow-Headers": "Content-Type", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "OPTIONS,POST,GET" } response = JSONResponse( content={"inference_id": inference_id, "status": "inprogress", "endpoint_name": endpoint_name, "output_path": output_path}, headers=headers) return response except Exception as e: traceback.print_exc() logger.error(f"Error occurred: {str(e)}") # raise HTTPException(status_code=500, detail=f"An error occurred during processing.{str(e)}") headers = { "Access-Control-Allow-Headers": "Content-Type", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "OPTIONS,POST,GET" } current_time = str(datetime.now()) response = inference_table.put_item( Item={ 'InferenceJobId': inference_id, 'startTime': current_time, 'completeTime': current_time, 'status': 'failure', 'endpoint': endpoint_name, 'checkpoint': checkpoint_name, 'taskType': task_type or "unknown", 'error': f"error info {str(e)}"} ) response = JSONResponse( content={"inference_id": inference_id, "status": "failure", "error": f"error info {str(e)}"}, headers=headers) return response # todo will remove @app.post("/inference/query-inference-jobs") async def query_inference_jobs(request: Request): logger.info(f"entering query-inference-jobs") query_params = await request.json() logger.info(query_params) status = query_params.get('status') task_type = query_params.get('task_type') start_time = query_params.get('start_time') end_time = query_params.get('end_time') endpoint = query_params.get('endpoint') checkpoint = query_params.get('checkpoint') limit = query_params.get("limit") if query_params.get("limit") else const.PAGE_LIMIT_ALL logger.info( f"entering query-inference-jobs {status},{task_type},{start_time},{end_time},{checkpoint},{endpoint},{limit}") return query_inference_job_list(status, task_type, start_time, end_time, endpoint, checkpoint, limit) def generate_presigned_url(bucket_name: str, key: str, expiration=3600) -> str: try: response = s3.generate_presigned_url( 'get_object', Params={'Bucket': bucket_name, 'Key': key}, ExpiresIn=expiration ) except Exception as e: logger.error(f"Error generating presigned URL: {e}") raise return response @app.get("/inference/generate-s3-presigned-url-for-uploading") async def generate_s3_presigned_url_for_uploading(s3_bucket_name: str = None, key: str = None): if not s3_bucket_name: s3_bucket_name = S3_BUCKET_NAME if not key: raise HTTPException(status_code=400, detail="Key parameter is required") try: presigned_url = s3.generate_presigned_url( 'put_object', Params={ 'Bucket': s3_bucket_name, 'Key': key, 'ContentType': 'text/plain;charset=UTF-8' }, ExpiresIn=3600, HttpMethod='PUT' ) except Exception as e: headers = { "Access-Control-Allow-Headers": "*", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "OPTIONS,POST,GET,PUT" } return JSONResponse(content=str(e), status_code=500, headers=headers) headers = { "Access-Control-Allow-Headers": "*", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "OPTIONS,POST,GET,PUT" } response = JSONResponse(content=presigned_url, headers=headers) return response @app.post("/inference/run-model-merge") async def run_model_merge(request: Request): try: logger.info('entering the run_model_merge function!') # TODO: add logic for inference id merge_id = get_uuid() payload_checkpoint_info = await request.json() print(f"!!!!!!!!!!input in json format {payload_checkpoint_info}") params_dict = load_json_from_s3(S3_BUCKET_NAME, 'config/aigc.json') logger.info(json.dumps(params_dict)) payload = json_convert_to_payload(params_dict, payload_checkpoint_info) print(f"input in json format {payload}") task_type = payload_checkpoint_info.get('task_type') endpoint_name = payload["endpoint_name"] checkpoint_name = None if task_type == 'img2img': checkpoint_name = params_dict['img2img_sagemaker_stable_diffusion_checkpoint'] elif task_type == 'txt2img': checkpoint_name = params_dict['txt2img_sagemaker_stable_diffusion_checkpoint'] predictor = Predictor(endpoint_name) predictor = AsyncPredictor(predictor, name=endpoint_name) predictor.serializer = JSONSerializer() predictor.deserializer = JSONDeserializer() prediction = predictor.predict_async(data=payload, inference_id=inference_id) output_path = prediction.output_path # put the item to inference DDB for later check status current_time = str(datetime.now()) response = inference_table.put_item( Item={ 'InferenceJobId': inference_id, 'startTime': current_time, 'status': 'inprogress', 'endpoint': endpoint_name, 'checkpoint': checkpoint_name, }) print(f"output_path is {output_path}") headers = { "Access-Control-Allow-Headers": "Content-Type", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "OPTIONS,POST,GET" } response = JSONResponse( content={"inference_id": inference_id, "status": "inprogress", "endpoint_name": endpoint_name, "output_path": output_path}, headers=headers) # response = JSONResponse(content={"inference_id": '6fa743f0-cb7a-496f-8205-dbd67df08be2', "status": "succeed", "output_path": ""}, headers=headers) return response except Exception as e: logger.error(f"Error occurred: {str(e)}") # raise HTTPException(status_code=500, detail=f"An error occurred during processing.{str(e)}") headers = { "Access-Control-Allow-Headers": "Content-Type", "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "OPTIONS,POST,GET" } response = JSONResponse( content={"inference_id": inference_id, "status": "failure", "error": f"error info {str(e)}"}, headers=headers) return response # app.include_router(search) TODO: adding sub router for future handler = Mangum(app) add_pagination(app)