# Source file: stable-diffusion-aws-extension/middleware_api/endpoints/endpoint_event.py
#
# Repository viewer metadata captured with this copy: 327 lines, 13 KiB, Python.

import json
import logging
import os
from datetime import datetime
import boto3
from aws_lambda_powertools import Tracer
from common.util import record_seconds_metrics, endpoint_clean
from inferences.inference_libs import update_table_by_pk
from libs.data_types import Endpoint
from libs.enums import EndpointStatus, EndpointType, ServiceType
from libs.utils import get_endpoint_by_name
# Logging: the level comes from LOG_LEVEL and falls back to ERROR when the
# variable is unset or empty.
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)

# Powertools tracer used by the capture_* decorators in this module.
tracer = Tracer()

# Settings read from the Lambda environment (each is None when unset).
sagemaker_endpoint_table = os.getenv('ENDPOINT_TABLE_NAME')
s3_bucket_name = os.getenv('S3_BUCKET_NAME')
aws_region = os.getenv('AWS_REGION')
esd_version = os.getenv("ESD_VERSION")

# Alarm period, in seconds, applied to the "LessThanThreshold" alarms below.
cool_down_period = 15 * 60  # 15 minutes

# AWS SDK handles, created once at import time and shared by every invocation.
lambda_client = boto3.client('lambda')
autoscaling_client = boto3.client('application-autoscaling')
cw_client = boto3.client('cloudwatch')
sagemaker = boto3.client('sagemaker')
s3 = boto3.resource('s3')
bucket = s3.Bucket(s3_bucket_name)
# lambda: handle sagemaker events
@tracer.capture_lambda_handler
def handler(event, context):
    """Handle a SageMaker endpoint status event and sync it to the endpoint table.

    Reads ``EndpointName`` and ``EndpointStatus`` from ``event['detail']``,
    maps the status with ``get_business_status`` and stores it on the matching
    endpoint record. Depending on the mapped status it also:

    * UPDATING: stores the current time as ``startTime``;
    * IN_SERVICE: stores ``endTime``, records the ``EndpointReadySeconds``
      metric, and enables autoscaling when the record had no ``endTime`` yet;
    * DELETING / DELETED: calls ``endpoint_clean``; any other status refreshes
      ``current_instance_count`` from ``describe_endpoint``;
    * FAILED: stores the failure reason in the ``error`` field.

    Any exception raised while processing is logged and swallowed, so the
    function always returns ``{'statusCode': 200}``.

    :param event: event payload; must contain a ``detail`` mapping
    :param context: Lambda context object (unused)
    :return: ``{'statusCode': 200}``
    """
    logger.info(json.dumps(event))
    detail = event['detail']
    endpoint_name = detail['EndpointName']
    endpoint_status = detail['EndpointStatus']
    current_time = datetime.now().isoformat()

    try:
        # The record is loaded before any update, so the attributes read from
        # `endpoint` below (startTime, endTime, ...) are the previously stored
        # values, not the ones written by this invocation.
        endpoint = get_endpoint_by_name(endpoint_name)
        business_status = get_business_status(endpoint_status)
        update_endpoint_field(endpoint, 'endpoint_status', business_status)

        if business_status == EndpointStatus.UPDATING.value:
            update_endpoint_field(endpoint, 'startTime', current_time)

        if business_status == EndpointStatus.IN_SERVICE.value:
            update_endpoint_field(endpoint, 'endTime', current_time)
            if endpoint.service_type == 'sd':
                service_type = ServiceType.SD.value
            else:
                service_type = ServiceType.Comfy.value
            record_seconds_metrics(start_time=endpoint.startTime,
                                   metric_name='EndpointReadySeconds',
                                   service=service_type)
            # if it is the first time in service
            if not endpoint.endTime:
                check_and_enable_autoscaling(endpoint, 'prod')

        # update the instance count if the endpoint is not deleting or deleted
        described = None
        if business_status not in [EndpointStatus.DELETING.value, EndpointStatus.DELETED.value]:
            described = sagemaker.describe_endpoint(EndpointName=endpoint_name)
            logger.info(f"Endpoint status: {described}")
            if 'ProductionVariants' in described:
                instance_count = described['ProductionVariants'][0]['CurrentInstanceCount']
                update_endpoint_field(endpoint, 'current_instance_count', instance_count)
        else:
            endpoint_clean(endpoint)

        if business_status == EndpointStatus.FAILED.value:
            # Service fields of the event live under `detail`; indexing the
            # top-level envelope raised KeyError, which the handler below
            # swallowed, so the `error` field was never written. The top-level
            # key is still honoured for callers that supplied it there, and the
            # describe_endpoint response is used as a last resort.
            failure_reason = (
                detail.get('FailureReason')
                or event.get('FailureReason')
                or (described or {}).get('FailureReason')
            )
            if failure_reason:
                update_endpoint_field(endpoint, 'error', failure_reason)
            else:
                logger.error(f"Endpoint {endpoint_name} failed but no FailureReason was found")
    except Exception as e:
        # Deliberate best-effort: log the failure and still report success.
        logger.error(e, exc_info=True)

    return {'statusCode': 200}
def check_and_enable_autoscaling(ep: Endpoint, variant_name):
    """Enable autoscaling for ``ep`` only when its ``autoscaling`` flag is truthy.

    :param ep: endpoint record to inspect
    :param variant_name: production variant name forwarded to ``enable_autoscaling``
    """
    if not ep.autoscaling:
        logger.info(f'no need to enable autoscaling')
        return

    enable_autoscaling(ep, variant_name)
@tracer.capture_method
def enable_autoscaling(ep: Endpoint, variant_name):
    """Register the endpoint variant as a scalable target and attach its policies.

    The maximum capacity is ``ep.max_instance_number``. The minimum capacity is
    ``ep.min_instance_number`` when it is set; otherwise it is 1 for real-time
    endpoints and 0 for every other endpoint type. After registration the
    endpoint type selects the async or real-time policy setup.

    :param ep: endpoint record describing name, type and instance limits
    :param variant_name: production variant to scale
    """
    tracer.put_annotation("variant_name", variant_name)

    capacity_ceiling = int(ep.max_instance_number)

    # Default floor depends on the endpoint type; an explicit value wins.
    is_real_time = ep.endpoint_type == EndpointType.RealTime.value
    capacity_floor = 1 if is_real_time else 0
    if ep.min_instance_number is not None:
        capacity_floor = int(ep.min_instance_number)

    # Register scalable target
    scalable_resource = 'endpoint/' + ep.endpoint_name + '/variant/' + variant_name
    response = autoscaling_client.register_scalable_target(
        ServiceNamespace='sagemaker',
        ResourceId=scalable_resource,
        ScalableDimension='sagemaker:variant:DesiredInstanceCount',
        MinCapacity=capacity_floor,
        MaxCapacity=capacity_ceiling,
    )
    logger.info(f"Register scalable target response: {response}")

    # Hand over to the type-specific policy setup.
    if ep.endpoint_type == EndpointType.Async.value:
        enable_autoscaling_async(ep, variant_name)

    if ep.endpoint_type == EndpointType.RealTime.value:
        enable_autoscaling_real_time(ep, variant_name)
def enable_autoscaling_async(ep: Endpoint, variant_name):
    """Attach scaling policies and alarms to an async endpoint variant.

    Steps, in order:

    1. Put a target-tracking policy on the custom metric
       ``ApproximateBacklogSizePerInstance`` (target value 3).
    2. For every alarm returned with that policy, look it up and overwrite it
       under the same name and alarm actions but with this module's own
       period, evaluation settings and threshold.
    3. Put a step-scaling policy that adds one instance, and create the
       ``<endpoint>-HasBacklogWithoutCapacity-Alarm`` alarm that triggers it.

    :param ep: endpoint record; ``endpoint_name`` is used for naming and dimensions
    :param variant_name: production variant the policies apply to
    """
    resource_id = 'endpoint/' + ep.endpoint_name + '/variant/' + variant_name
    policy_target_value = 3

    # Define scaling policy
    policy_response = autoscaling_client.put_scaling_policy(
        PolicyName=f"{ep.endpoint_name}-Invocations-ScalingPolicy",
        ServiceNamespace="sagemaker",  # The namespace of the AWS service that provides the resource.
        ResourceId=resource_id,  # Endpoint variant being scaled
        ScalableDimension="sagemaker:variant:DesiredInstanceCount",  # SageMaker supports only Instance Count
        PolicyType="TargetTrackingScaling",  # 'StepScaling'|'TargetTrackingScaling'
        TargetTrackingScalingPolicyConfiguration={
            # Target for the tracked metric, here ApproximateBacklogSizePerInstance.
            "TargetValue": policy_target_value,
            "CustomizedMetricSpecification": {
                "MetricName": "ApproximateBacklogSizePerInstance",
                "Namespace": "AWS/SageMaker",
                "Dimensions": [{"Name": "EndpointName", "Value": ep.endpoint_name}],
                "Statistic": "Average",
            },
            # Seconds to wait after a scale-in / scale-out activity before
            # another one of the same kind can start.
            "ScaleInCooldown": 180,
            "ScaleOutCooldown": 60,
        },
    )
    logger.info("Put scaling policy response")
    logger.info(json.dumps(policy_response))

    # Tolerate a response without alarms instead of failing on iteration.
    for alarm in policy_response.get('Alarms') or []:
        alarm_name = alarm.get('AlarmName')
        logger.info(f"Alarm name: {alarm_name}")

        described = cw_client.describe_alarms(AlarmNames=[alarm_name])
        logger.info("Describe alarm response")
        logger.info(described)

        existing_alarm = described.get('MetricAlarms')[0]
        comparison_operator = existing_alarm['ComparisonOperator']
        if comparison_operator == "LessThanThreshold":
            # NOTE(review): presumably the scale-in alarm of the policy — confirm.
            period = cool_down_period  # 15 minutes
            evaluation_periods = 4
            datapoints_to_alarm = 4
            threshold = 1
        else:
            period = 30
            evaluation_periods = 1
            datapoints_to_alarm = 1
            threshold = 3

        # Re-create the alarm under the same name so it keeps triggering the
        # policy (same AlarmActions) but with the settings chosen above.
        alarm_response = cw_client.put_metric_alarm(
            AlarmName=alarm_name,
            Namespace='AWS/SageMaker',
            MetricName='ApproximateBacklogSizePerInstance',
            Statistic="Average",
            Period=period,
            EvaluationPeriods=evaluation_periods,
            DatapointsToAlarm=datapoints_to_alarm,
            Threshold=threshold,
            ComparisonOperator=comparison_operator,
            AlarmActions=existing_alarm['AlarmActions'],
            Dimensions=[{'Name': 'EndpointName', 'Value': ep.endpoint_name}]
        )
        logger.info("Put metric alarm response")
        logger.info(alarm_response)

    step_policy_response = autoscaling_client.put_scaling_policy(
        PolicyName=f"{ep.endpoint_name}-HasBacklogWithoutCapacity-ScalingPolicy",
        ServiceNamespace="sagemaker",  # The namespace of the service that provides the resource.
        ResourceId=resource_id,
        ScalableDimension="sagemaker:variant:DesiredInstanceCount",  # SageMaker supports only Instance Count
        PolicyType="StepScaling",  # 'StepScaling' or 'TargetTrackingScaling'
        StepScalingPolicyConfiguration={
            # ScalingAdjustment below is an absolute change in instance count.
            "AdjustmentType": "ChangeInCapacity",
            "MetricAggregationType": "Average",  # The aggregation type for the CloudWatch metrics.
            "Cooldown": 180,  # Seconds to wait for a previous scaling activity to take effect.
            # Adjustments applied depending on the size of the alarm breach.
            "StepAdjustments": [
                {
                    "MetricIntervalLowerBound": 0,
                    "ScalingAdjustment": 1
                }
            ]
        },
    )
    logger.info(f"Put step scaling policy response: {step_policy_response}")

    backlog_alarm_response = cw_client.put_metric_alarm(
        AlarmName=f'{ep.endpoint_name}-HasBacklogWithoutCapacity-Alarm',
        MetricName='HasBacklogWithoutCapacity',
        Namespace='AWS/SageMaker',
        Statistic='Average',
        Period=30,
        EvaluationPeriods=1,
        DatapointsToAlarm=1,
        Threshold=1,
        ComparisonOperator='GreaterThanOrEqualToThreshold',
        TreatMissingData='missing',
        Dimensions=[
            {'Name': 'EndpointName', 'Value': ep.endpoint_name},
        ],
        AlarmActions=[step_policy_response['PolicyARN']]
    )
    # Log the alarm call's own response; previously the step policy response
    # was logged here a second time under this label.
    logger.info(f"Put metric alarm response: {backlog_alarm_response}")

    logger.info(f"Autoscaling has been enabled for the endpoint: {ep.endpoint_name}")
@tracer.capture_method
def enable_autoscaling_real_time(ep: Endpoint, variant_name):
    """Attach a target-tracking policy and tuned alarms to a real-time variant.

    Puts a target-tracking policy on the predefined metric
    ``SageMakerVariantInvocationsPerInstance`` (target value 5). Every alarm
    returned with that policy is then looked up and overwritten under the same
    name and alarm actions but with this module's own period, evaluation
    settings and threshold.

    :param ep: endpoint record; ``endpoint_name`` is used for naming and dimensions
    :param variant_name: production variant the policy and alarms apply to
    """
    tracer.put_annotation("variant_name", variant_name)
    policy_target_value = 5

    # Define scaling policy
    policy_response = autoscaling_client.put_scaling_policy(
        PolicyName=f"{ep.endpoint_name}-Invocations-ScalingPolicy",
        ServiceNamespace="sagemaker",  # The namespace of the AWS service that provides the resource.
        ResourceId='endpoint/' + ep.endpoint_name + '/variant/' + variant_name,  # Endpoint variant being scaled
        ScalableDimension="sagemaker:variant:DesiredInstanceCount",  # SageMaker supports only Instance Count
        PolicyType="TargetTrackingScaling",  # 'StepScaling'|'TargetTrackingScaling'
        TargetTrackingScalingPolicyConfiguration={
            "TargetValue": policy_target_value,
            "PredefinedMetricSpecification": {
                "PredefinedMetricType": "SageMakerVariantInvocationsPerInstance"
            },
            # Seconds to wait after a scale-in / scale-out activity before
            # another one of the same kind can start.
            "ScaleInCooldown": 180,
            "ScaleOutCooldown": 60,
        },
    )
    logger.info("Put scaling policy response")
    logger.info(json.dumps(policy_response))

    # Tolerate a response without alarms instead of failing on iteration.
    for alarm in policy_response.get('Alarms') or []:
        alarm_name = alarm.get('AlarmName')
        logger.info(f"Alarm name: {alarm_name}")

        described = cw_client.describe_alarms(AlarmNames=[alarm_name])
        logger.info("Describe alarm response")
        logger.info(described)

        existing_alarm = described.get('MetricAlarms')[0]
        comparison_operator = existing_alarm['ComparisonOperator']
        if comparison_operator == "LessThanThreshold":
            # NOTE(review): presumably the scale-in alarm of the policy — confirm.
            period = cool_down_period  # 15 minutes
            evaluation_periods = 4
            datapoints_to_alarm = 4
            threshold = 1
        else:
            period = 30
            evaluation_periods = 1
            datapoints_to_alarm = 1
            threshold = 5

        # Re-create the alarm under the same name so it keeps triggering the
        # policy (same AlarmActions) but with the settings chosen above.
        alarm_response = cw_client.put_metric_alarm(
            AlarmName=alarm_name,
            Namespace='AWS/SageMaker',
            MetricName='InvocationsPerInstance',
            Statistic="Sum",
            Period=period,
            EvaluationPeriods=evaluation_periods,
            DatapointsToAlarm=datapoints_to_alarm,
            Threshold=threshold,
            ComparisonOperator=comparison_operator,
            AlarmActions=existing_alarm['AlarmActions'],
            Dimensions=[
                {'Name': 'EndpointName', 'Value': ep.endpoint_name},
                # Follow the variant the policy was attached to instead of a
                # hard-coded 'prod', so the alarm watches the same variant.
                {'Name': 'VariantName', 'Value': variant_name},
            ]
        )
        logger.info("Put metric alarm response")
        logger.info(alarm_response)

    logger.info(f"Autoscaling has been enabled for the endpoint: {ep.endpoint_name}")
def update_endpoint_field(ep: Endpoint, field_name, field_value):
    """Store one field of an endpoint record via ``update_table_by_pk``.

    The record is addressed by its ``EndpointDeploymentJobId`` in the table
    named by ``sagemaker_endpoint_table``.

    :param ep: endpoint record whose ``EndpointDeploymentJobId`` identifies the row
    :param field_name: name of the field to write
    :param field_value: value to store in that field
    """
    update_args = {
        'table_name': sagemaker_endpoint_table,
        'pk': 'EndpointDeploymentJobId',
        'id': ep.EndpointDeploymentJobId,
        'key': field_name,
        'value': field_value,
    }
    update_table_by_pk(**update_args)
def get_business_status(status):
    """
    Translate a SageMaker endpoint status into the business status value.

    Statuses without a known translation are returned unchanged.

    :param status: EventBridge event status(upper case)
    :return: business status
    """
    known_statuses = (
        "IN_SERVICE",
        "CREATING",
        "DELETED",
        "FAILED",
        "UPDATING",
        "DELETING",
    )
    # Each upper-case status shares its name with an EndpointStatus member.
    translations = {
        name: getattr(EndpointStatus, name).value
        for name in known_statuses
    }
    return translations.get(status, status)