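"""Shared utilities for the ESD Lambda functions: CloudWatch metric recording,
API Gateway query-parameter parsing, S3 presigned URLs and uploads, JSON/image
serialization helpers, and SageMaker endpoint cleanup."""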
import base64
import datetime
import enum
import json
import logging
import os
from functools import reduce
from io import BytesIO
from typing import Dict

import boto3
import numpy
from PIL import Image, PngImagePlugin
from aws_lambda_powertools import Tracer

from common.ddb_service.client import DynamoDbUtilsService
from libs.comfy_data_types import InferenceResult
from libs.data_types import Endpoint
from libs.enums import ServiceType
from libs.utils import log_json

tracer = Tracer()
s3 = boto3.client('s3')
s3_resource = boto3.resource('s3')
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
sns_client = boto3.client('sns')
s3_client = boto3.client('s3')
bucket_name = os.environ.get('S3_BUCKET_NAME')
s3_bucket = s3_resource.Bucket(bucket_name)
ddb_service = DynamoDbUtilsService(logger=logger)
cloudwatch = boto3.client('cloudwatch')
sagemaker = boto3.client('sagemaker')
endpoint_name = os.getenv('ENDPOINT_NAME')
endpoint_instance_id = os.getenv('ENDPOINT_INSTANCE_ID')
sagemaker_endpoint_table = os.environ.get('ENDPOINT_TABLE_NAME')
esd_version = os.environ.get("ESD_VERSION")
logs = boto3.client('logs')


def record_count_metrics(ep_name: str,
                         metric_name='InferenceSucceed',
                         service=ServiceType.SD.value,
                         workflow: str = None
                         ):
    data = [
        {
            'MetricName': metric_name,
            'Dimensions': [
                {
                    'Name': 'Service',
                    'Value': service
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': 1,
            'Unit': 'Count'
        },
        {
            'MetricName': metric_name,
            'Dimensions': [
                {
                    'Name': 'Endpoint',
                    'Value': ep_name
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': 1,
            'Unit': 'Count'
        },
    ]

    if workflow:
        data.append({
            'MetricName': metric_name,
            'Dimensions': [
                {
                    'Name': 'Workflow',
                    'Value': workflow
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': 1,
            'Unit': 'Count'
        })

    response = cloudwatch.put_metric_data(
        Namespace='ESD',
        MetricData=data
    )

    logger.info(f"record_metric response: {response}")
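
# Illustrative call for record_count_metrics (argument values are hypothetical,
# not fixtures from this codebase):
#   record_count_metrics(ep_name='my-endpoint', metric_name='InferenceFailed',
#                        service=ServiceType.SD.value, workflow='my-workflow')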


def record_seconds_metrics(start_time: str, metric_name='Inference', service=ServiceType.SD.value):
    start_time = datetime.datetime.fromisoformat(start_time)
    # total_seconds() covers the full elapsed time; timedelta.seconds alone
    # would drop the days component of long-running jobs.
    latency = (datetime.datetime.now() - start_time).total_seconds()

    response = cloudwatch.put_metric_data(
        Namespace='ESD',
        MetricData=[
            {
                'MetricName': metric_name,
                'Dimensions': [
                    {
                        'Name': 'Service',
                        'Value': service
                    },
                ],
                'Timestamp': datetime.datetime.utcnow(),
                'Value': latency,
                'Unit': 'Seconds'
            },
        ]
    )
    logger.info(f"record_metric response: {response}")
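
# record_seconds_metrics expects an ISO-8601 timestamp string, e.g.:
#   record_seconds_metrics(datetime.datetime.now().isoformat())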


def record_latency_metrics(start_time,
                           ep_name: str,
                           metric_name='InferenceLatency',
                           service=ServiceType.SD.value,
                           workflow: str = None
                           ):
    logger.info(f"start {start_time}")

    end_time = datetime.datetime.now().isoformat()
    logger.info(f"end {end_time}")

    # fromisoformat tolerates timestamps with or without a fractional-seconds
    # part, unlike strptime with a fixed "%Y-%m-%dT%H:%M:%S.%f" format.
    time1 = datetime.datetime.fromisoformat(start_time)
    time2 = datetime.datetime.fromisoformat(end_time)

    time_difference = time2 - time1

    latency = time_difference.total_seconds() * 1000

    logger.info(f"{service} {metric_name}: {latency} Milliseconds")

    data = [
        {
            'MetricName': metric_name,
            'Dimensions': [
                {
                    'Name': 'Service',
                    'Value': service
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': latency,
            'Unit': 'Milliseconds'
        },
        {
            'MetricName': metric_name,
            'Dimensions': [
                {
                    'Name': 'Endpoint',
                    'Value': ep_name
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': latency,
            'Unit': 'Milliseconds'
        },
    ]

    if workflow:
        data.append({
            'MetricName': metric_name,
            'Dimensions': [
                {
                    'Name': 'Workflow',
                    'Value': workflow
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': latency,
            'Unit': 'Milliseconds'
        })

    response = cloudwatch.put_metric_data(
        Namespace='ESD',
        MetricData=data
    )
    logger.info(f"record_metric response: {response}")


def record_queue_latency_metrics(create_time: str, start_time, ep_name: str, service=ServiceType.SD.value):
    metric_name = 'QueueLatency'

    logger.info(f"create_time {create_time}")
    logger.info(f"start_time {start_time}")

    # fromisoformat tolerates timestamps with or without a fractional-seconds
    # part, unlike strptime with a fixed "%Y-%m-%dT%H:%M:%S.%f" format.
    time1 = datetime.datetime.fromisoformat(create_time)
    time2 = datetime.datetime.fromisoformat(start_time)

    time_difference = time2 - time1

    latency = time_difference.total_seconds() * 1000

    # Clock skew between the writers of the two timestamps can make the
    # difference slightly negative; clamp to zero.
    if latency < 0:
        latency = 0

    logger.info(f"{service} {metric_name}: {latency} Milliseconds")

    data = [
        {
            'MetricName': metric_name,
            'Dimensions': [
                {
                    'Name': 'Service',
                    'Value': service
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': latency,
            'Unit': 'Milliseconds'
        },
        {
            'MetricName': metric_name,
            'Dimensions': [
                {
                    'Name': 'Endpoint',
                    'Value': ep_name
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': latency,
            'Unit': 'Milliseconds'
        },
    ]

    response = cloudwatch.put_metric_data(
        Namespace='ESD',
        MetricData=data
    )
    logger.info(f"record_metric response: {response}")


def get_multi_query_params(event, param_name: str, default=None):
    value = default
    if 'multiValueQueryStringParameters' in event:
        multi_query = event['multiValueQueryStringParameters']
        if multi_query and param_name in multi_query and len(multi_query[param_name]) > 0:
            value = multi_query[param_name]

    return value
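
# Illustrative API Gateway proxy event fragment for get_multi_query_params
# (the parameter name is hypothetical):
#   event = {'multiValueQueryStringParameters': {'status': ['created', 'failed']}}
#   get_multi_query_params(event, 'status')  # -> ['created', 'failed']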


def get_query_param(event, param_name: str, default=None):
    if 'queryStringParameters' in event:
        queries = event['queryStringParameters']
        if queries and param_name in queries:
            return queries[param_name]

    return default


def resolve_instance_invocations_num(instance_type: str, service_type: str):
    if service_type == "sd":
        return 1

    if instance_type == 'ml.g5.12xlarge':
        return 4

    if instance_type == 'ml.p4d.24xlarge':
        return 8

    return 1
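
# Illustrative results for resolve_instance_invocations_num ('comfy' stands in
# for any non-"sd" service type and is hypothetical here):
#   resolve_instance_invocations_num('ml.g5.12xlarge', 'comfy')  # -> 4
#   resolve_instance_invocations_num('ml.g5.12xlarge', 'sd')     # -> 1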


def query_data(data, paths):
    value = data
    for path in paths:
        value = value.get(path)
        if not value:
            path_string = reduce(lambda x, y: f"{x}.{y}", paths)
            raise ValueError(f"Missing {path_string}")

    return value
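
# Illustrative use of query_data (the keys are hypothetical):
#   query_data(event, ['body', 'prompt'])
# returns event['body']['prompt'], or raises ValueError("Missing body.prompt")
# if any step along the path is missing or falsy.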


def publish_msg(topic_arn, msg, subject):
    sns_client.publish(
        TopicArn=topic_arn,
        Message=str(msg),
        Subject=subject
    )


def get_s3_presign_urls(bucket_name, base_key, filenames) -> Dict[str, str]:
    return _get_s3_presign_urls(bucket_name, base_key, filenames, expires=3600 * 24 * 7, method='put_object')


def get_s3_get_presign_urls(bucket_name, base_key, filenames) -> Dict[str, str]:
    return _get_s3_presign_urls(bucket_name, base_key, filenames, expires=3600 * 24, method='get_object')


def _get_s3_presign_urls(bucket_name, base_key, filenames, expires=3600, method='put_object') -> Dict[str, str]:
    presign_url_map = {}
    for filename in filenames:
        key = f'{base_key}/{filename}'
        url = s3.generate_presigned_url(method,
                                        Params={'Bucket': bucket_name,
                                                'Key': key,
                                                },
                                        ExpiresIn=expires)
        presign_url_map[filename] = url

    return presign_url_map
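
# Illustrative use (bucket and key values are hypothetical). PUT URLs from
# get_s3_presign_urls are valid for 7 days; GET URLs from
# get_s3_get_presign_urls for 24 hours:
#   get_s3_presign_urls('my-bucket', 'outputs/job-123', ['image.png'])
#   # -> {'image.png': 'https://...presigned PUT URL...'}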


@tracer.capture_method
def generate_presign_url(bucket_name, key, expires=3600, method='put_object') -> str:
    return s3.generate_presigned_url(method,
                                     Params={'Bucket': bucket_name,
                                             'Key': key,
                                             },
                                     ExpiresIn=expires)


@tracer.capture_method
def load_json_from_s3(key: str):
    key = key.replace(f"s3://{bucket_name}/", '')
    response = s3.get_object(Bucket=bucket_name, Key=key)
    json_file = response['Body'].read().decode('utf-8')
    data = json.loads(json_file)

    return data


def save_json_to_file(json_string: str, folder_path: str, file_name: str):
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, file_name)

    with open(file_path, 'w') as file:
        # json_string is expected to be a serialized JSON document already;
        # re-encoding a str with json.dumps would double-escape it, so only
        # serialize non-string payloads.
        file.write(json_string if isinstance(json_string, str) else json.dumps(json_string))

    return file_path


def get_pil_metadata(pil_image):
    # Copy any text-only metadata
    metadata = PngImagePlugin.PngInfo()
    for key, value in pil_image.info.items():
        if isinstance(key, str) and isinstance(value, str):
            metadata.add_text(key, value)

    return metadata


def encode_pil_to_base64(pil_image):
    with BytesIO() as output_bytes:
        pil_image.save(output_bytes, "PNG", pnginfo=get_pil_metadata(pil_image))
        bytes_data = output_bytes.getvalue()

    base64_str = str(base64.b64encode(bytes_data), "utf-8")
    return "data:image/png;base64," + base64_str


def encode_no_json(obj):
    if isinstance(obj, numpy.ndarray):
        return encode_pil_to_base64(Image.fromarray(obj))
    elif isinstance(obj, Image.Image):
        return encode_pil_to_base64(obj)
    elif isinstance(obj, enum.Enum):
        return obj.value
    elif hasattr(obj, '__dict__'):
        return obj.__dict__
    else:
        logger.debug(f'may not be able to JSON-serialize {type(obj)}: {str(obj)}')
        return str(obj)
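
# encode_no_json is meant as a json.dumps fallback, as upload_json_to_s3 below
# uses it, e.g.:
#   json.dumps(payload, default=encode_no_json)  # payload is hypothetical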


@tracer.capture_method
def upload_json_to_s3(file_key: str, json_data: dict):
    try:
        file_key = file_key.replace(f"s3://{bucket_name}/", '')
        s3.put_object(Body=json.dumps(json_data, indent=4, default=encode_no_json), Bucket=bucket_name, Key=file_key)
        logger.info(f"Dictionary uploaded to s3://{bucket_name}/{file_key}")
    except Exception as e:
        logger.error(f"Error uploading dictionary: {e}")


@tracer.capture_method
def upload_file_to_s3(file_name, bucket, directory=None, object_name=None):
    # If S3 object_name was not specified, use file_name
    if object_name is None:
        object_name = file_name

    # Add the directory to the object_name
    if directory:
        object_name = f"{directory}/{object_name}"

    # Upload the file
    try:
        s3_client.upload_file(file_name, bucket, object_name)
        log_json(f"File {file_name} uploaded to {bucket}/{object_name}")
    except Exception as e:
        logger.error(f"Error occurred while uploading {file_name} to {bucket}/{object_name}: {e}")
        return False
    return True


def split_s3_path(s3_path):
    path_parts = s3_path.replace("s3://", "").split("/")
    bucket = path_parts.pop(0)
    key = "/".join(path_parts)
    return bucket, key
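
# Example (hypothetical path):
#   split_s3_path("s3://my-bucket/outputs/image.png")
#   # -> ("my-bucket", "outputs/image.png")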


@tracer.capture_method
def s3_scan_files(job: InferenceResult):
    if not job.output_path and not job.temp_path:
        job.status = "failed"

    # Normalize the shorthand 'fail' status to 'failed'
    if job.status == 'fail':
        job.status = "failed"

    if job.output_path:
        job.output_files = s3_scan_files_in_patch(job.output_path)
    else:
        job.output_files = []
        job.output_path = ''

    if job.temp_path:
        job.temp_files = s3_scan_files_in_patch(job.temp_path)
    else:
        job.temp_files = []
        job.temp_path = ''

    return job


@tracer.capture_method
def s3_scan_files_in_patch(patch: str):
    files = []
    prefix = patch.replace(f"s3://{bucket_name}/", '')
    for obj in s3_bucket.objects.filter(Prefix=prefix):
        file = obj.key.replace(prefix, '')
        if file:
            files.append(file)

    return files


def generate_presigned_url_for_key(key, expiration=3600):
    key = key.replace(f"s3://{bucket_name}/", '')

    return s3.generate_presigned_url(
        'get_object',
        Params={'Bucket': bucket_name, 'Key': key},
        ExpiresIn=expiration
    )


@tracer.capture_method
def generate_presigned_url_for_keys(prefix, keys, expiration=3600):
    if not prefix or not keys:
        return []

    new_list = []

    prefix = prefix.replace(f"s3://{bucket_name}/", '')

    for key in keys:
        new_list.append(generate_presigned_url_for_key(f"{prefix}{key}", expiration))

    return new_list


@tracer.capture_method
def generate_presigned_url_for_job(job):
    if 'output_path' in job and 'output_files' in job and job['output_path'] and job['output_files']:
        job['output_files'] = generate_presigned_url_for_keys(job['output_path'], job['output_files'])

    if 'temp_path' in job and 'temp_files' in job and job['temp_path'] and job['temp_files']:
        job['temp_files'] = generate_presigned_url_for_keys(job['temp_path'], job['temp_files'])

    return job


def endpoint_clean(ep: Endpoint):
    try:
        sagemaker.delete_endpoint_config(EndpointConfigName=ep.endpoint_name)
        logger.info(f"Delete {ep.endpoint_name} endpoint config")
    except Exception as e:
        logger.error(e, exc_info=True)

    try:
        sagemaker.delete_model(ModelName=ep.endpoint_name)
        logger.info(f"Delete {ep.endpoint_name} model")
    except Exception as e:
        logger.error(e, exc_info=True)

    try:
        s3_bucket.objects.filter(Prefix=f"endpoint-{esd_version}-{ep.endpoint_name}").delete()
        logger.info(f"Delete {ep.endpoint_name} artifacts")
    except Exception as e:
        logger.error(e, exc_info=True)

    try:
        logs.delete_log_group(logGroupName=f"/aws/sagemaker/Endpoints/{ep.endpoint_name}")
        logger.info(f"Delete {ep.endpoint_name} log group")
    except Exception as e:
        logger.error(e, exc_info=True)

    try:
        cloudwatch.delete_dashboards(DashboardNames=[ep.endpoint_name])
        logger.info(f"Delete {ep.endpoint_name} dashboard")
    except Exception as e:
        logger.error(e, exc_info=True)

    try:
        response = cloudwatch.delete_alarms(AlarmNames=[f'{ep.endpoint_name}-HasBacklogWithoutCapacity-Alarm'])
        logger.info(f"delete_metric_alarm response: {response}")
    except Exception as e:
        logger.error(e, exc_info=True)

    ddb_service.delete_item(
        table=sagemaker_endpoint_table,
        keys={'EndpointDeploymentJobId': ep.EndpointDeploymentJobId},
    )
    logger.info(f"Delete {ep.endpoint_name} endpoint from DDB")


def get_workflow_name(workflow, instance_type: str):
    if not workflow:
        return None

    return f"{workflow} ({instance_type})"
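
# Example (hypothetical values):
#   get_workflow_name("my-workflow", "ml.g5.2xlarge")
#   # -> "my-workflow (ml.g5.2xlarge)"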