# stable-diffusion-aws-extension/middleware_api/lambda/check_endpoint_deployment/app.py
# Lambda that polls SageMaker endpoint deployment status and records it in DynamoDB.
import boto3
import os
from datetime import datetime
# Name of the DynamoDB table that tracks endpoint deployment jobs (injected via env var).
DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME = os.environ.get('DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME')
# Module-level AWS clients so warm Lambda invocations reuse connections.
sagemaker = boto3.client('sagemaker')
ddb_client = boto3.resource('dynamodb')
endpoint_deployment_table = ddb_client.Table(DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME)
def lambda_handler(event, context):
    """Step Functions task: poll the status of a SageMaker endpoint deployment.

    Reads ``event['Payload']`` (must contain ``stage``,
    ``endpoint_deployment_id`` and ``endpoint_name``). When the stage is
    'Deployment', queries DescribeEndpoint, records the observed state in the
    endpoint-deployment DynamoDB table, and (on InService) conditionally
    enables autoscaling.

    Args:
        event (dict): Step Functions event wrapping the job payload.
        context: Lambda context object (unused).

    Returns:
        dict: The input payload augmented with ``status`` and, for terminal
        states, a human-readable ``message``.
    """
    print(f"event is {event}")
    event_payload = event["Payload"]
    stage = event_payload['stage']
    endpoint_deployment_job_id = event_payload['endpoint_deployment_id']
    endpoint_name = event_payload['endpoint_name']
    # Keep 'status' bound on every path (e.g. the unimplemented 'Training'
    # stage) so the assignment at the bottom cannot raise UnboundLocalError.
    status = None
    if stage == 'Training':
        print("Status check for training not implemented yet!")
    elif stage == 'Deployment':
        name = event_payload['endpoint_name']
        endpoint_details = describe_endpoint(name)
        status = endpoint_details['EndpointStatus']
        # Job-table fields to persist for this status transition.
        updates = {}
        if status == 'InService':
            event_payload['message'] = 'Deployment completed for endpoint "{}".'.format(name)
            # Turn on autoscaling only when the job row requested it.
            check_and_enable_autoscaling(
                DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME,
                {'EndpointDeploymentJobId': endpoint_deployment_job_id},
                'autoscaling',
                endpoint_name,
                'prod',
            )
            updates = {
                'endpoint_name': endpoint_name,
                'endpoint_status': status,
                'endTime': str(datetime.now()),
                'status': 'success',
            }
        elif status == 'Creating':
            updates = {
                'endpoint_name': endpoint_name,
                'endpoint_status': status,
            }
        elif status == 'Failed':
            # FailureReason may be absent from the response; don't crash on it.
            failure_reason = endpoint_details.get('FailureReason', 'Unknown reason')
            event_payload['message'] = 'Deployment failed for endpoint "{}". {}'.format(name, failure_reason)
            updates = {
                'endpoint_name': endpoint_name,
                'endpoint_status': status,
                'status': 'failed',
            }
        elif status == 'RollingBack':
            event_payload['message'] = 'Deployment failed for endpoint "{}", rolling back to previously deployed version.'.format(name)
            updates = {
                'endpoint_name': endpoint_name,
                'endpoint_status': status,
            }
        # Dicts preserve insertion order, so writes happen in the same order
        # as the original per-field calls.
        for key, value in updates.items():
            update_endpoint_job_table(endpoint_deployment_job_id, key, value)
    event_payload['status'] = status
    return event_payload
def update_endpoint_job_table(endpoint_deployment_job_id, key, value):
    """Set a single attribute on an endpoint-deployment job record.

    Args:
        endpoint_deployment_job_id (str): Partition key of the job row.
        key (str): Attribute name to set.
        value: New attribute value.

    Raises:
        Exception: If no job row exists for the given id.
    """
    response = endpoint_deployment_table.get_item(
        Key={
            "EndpointDeploymentJobId": endpoint_deployment_job_id,
        })
    # get_item omits the 'Item' key entirely when the row is absent, so use
    # .get() — indexing would raise KeyError before the existence check below.
    endpoint_resp = response.get('Item')
    if not endpoint_resp:
        raise Exception(f"Failed to get the endpoint deployment job item with endpoint deployment job id: {endpoint_deployment_job_id}")
    endpoint_deployment_table.update_item(
        Key={
            "EndpointDeploymentJobId": endpoint_deployment_job_id,
        },
        # Alias the attribute name so reserved words like 'status' are legal.
        UpdateExpression="set #k = :r",
        ExpressionAttributeNames={'#k': key},
        ExpressionAttributeValues={':r': value},
        ReturnValues="UPDATED_NEW"
    )
def get_ddb_value(table_name, key, field_name):
    """Best-effort read of one field from a DynamoDB item.

    Args:
        table_name (str): DynamoDB table name.
        key (dict): Primary-key dict passed to ``get_item``.
        field_name (str): Attribute to extract from the item.

    Returns:
        The attribute value, or None when the read fails, the item does not
        exist, or the attribute is absent.
    """
    # Reuse the module-level resource instead of constructing one per call.
    table = ddb_client.Table(table_name)
    try:
        response = table.get_item(Key=key)
    except Exception as e:
        # Deliberate best-effort: callers treat None as "not configured".
        print(str(e))
        return None
    # 'Item' is omitted when the key does not exist; avoid an uncaught KeyError.
    item = response.get('Item')
    if item is None:
        return None
    return item.get(field_name, None)
def enable_autoscaling(endpoint_name, variant_name, low_value, high_value):
    """Register the endpoint variant as a scalable target and attach policies.

    Configures:
      * a target-tracking policy on average CPU utilization (target 50%), and
      * a step-scaling policy driven by a CloudWatch alarm on the
        HasBacklogWithoutCapacity metric, so the variant can scale out from
        zero instances when requests start queuing.

    Args:
        endpoint_name (str): SageMaker endpoint name.
        variant_name (str): Production variant name on the endpoint.
        low_value (int): Minimum instance count for the scalable target.
        high_value (int): Maximum instance count for the scalable target.
    """
    client = boto3.client('application-autoscaling')
    resource_id = 'endpoint/' + endpoint_name + '/variant/' + variant_name
    # Register scalable target
    client.register_scalable_target(
        ServiceNamespace='sagemaker',
        ResourceId=resource_id,
        ScalableDimension='sagemaker:variant:DesiredInstanceCount',
        MinCapacity=low_value,
        MaxCapacity=high_value,
    )
    # Target-tracking policy on the variant's average CPU utilization.
    client.put_scaling_policy(
        PolicyName='CPUUtil-ScalingPolicy',
        ServiceNamespace='sagemaker',
        ResourceId=resource_id,
        ScalableDimension='sagemaker:variant:DesiredInstanceCount',
        PolicyType='TargetTrackingScaling',
        TargetTrackingScalingPolicyConfiguration={
            'TargetValue': 50.0,
            'CustomizedMetricSpecification': {
                'MetricName': 'CPUUtilization',
                'Namespace': '/aws/sagemaker/Endpoints',
                'Dimensions': [
                    {'Name': 'EndpointName', 'Value': endpoint_name},
                    # Use the variant passed in rather than a hard-coded
                    # 'prod', so the metric matches the registered target.
                    {'Name': 'VariantName', 'Value': variant_name},
                ],
                'Statistic': 'Average',  # Possible: 'Average'|'Minimum'|'Maximum'|'SampleCount'|'Sum'
                'Unit': 'Percent'
            },
            'ScaleInCooldown': 300,
            'ScaleOutCooldown': 300
        }
    )
    # Step-scaling policy: add one instance each time the backlog alarm fires.
    step_policy_response = client.put_scaling_policy(
        PolicyName="HasBacklogWithoutCapacity-ScalingPolicy",
        ServiceNamespace="sagemaker",
        ResourceId=resource_id,
        ScalableDimension="sagemaker:variant:DesiredInstanceCount",  # SageMaker supports only Instance Count
        PolicyType="StepScaling",
        StepScalingPolicyConfiguration={
            "AdjustmentType": "ChangeInCapacity",  # ScalingAdjustment is an absolute instance delta
            "MetricAggregationType": "Average",  # Aggregation for the CloudWatch metrics
            "Cooldown": 300,  # Seconds to wait for a previous scaling activity to take effect
            "StepAdjustments": [
                {
                    "MetricIntervalLowerBound": 0,
                    "ScalingAdjustment": 1
                }
            ]
        },
    )
    # Alarm that triggers the step policy when requests queue with no capacity.
    cw_client = boto3.client('cloudwatch')
    cw_client.put_metric_alarm(
        AlarmName='stable-diffusion-hasbacklogwithoutcapacity-alarm',
        MetricName='HasBacklogWithoutCapacity',
        Namespace='AWS/SageMaker',
        Statistic='Average',
        EvaluationPeriods=2,
        DatapointsToAlarm=2,
        Threshold=1,
        ComparisonOperator='GreaterThanOrEqualToThreshold',
        TreatMissingData='missing',
        Dimensions=[
            {'Name': 'EndpointName', 'Value': endpoint_name},
        ],
        Period=60,
        AlarmActions=[step_policy_response['PolicyARN']]
    )
    print(f"Autoscaling has been enabled for the endpoint: {endpoint_name}")
def check_and_enable_autoscaling(table_name, key, field_name, endpoint_name, variant_name):
    """Enable autoscaling for an endpoint when its job record requests it.

    Args:
        table_name (str): DynamoDB table holding the deployment job row.
        key (dict): Primary key of the job row.
        field_name (str): Attribute holding the autoscaling flag.
        endpoint_name (str): SageMaker endpoint to scale.
        variant_name (str): Variant on the endpoint to scale.
    """
    autoscaling_enabled = get_ddb_value(table_name, key, field_name)
    if str(autoscaling_enabled) != 'True':
        print(f'autoscaling_enabled is {autoscaling_enabled}, no need to enable autoscaling')
        return
    raw_max = get_ddb_value(table_name, key, 'max_instance_number')
    # raw_max may be None (field missing) or a Decimal (DynamoDB numbers),
    # neither of which has .isdigit() — normalize through str() first.
    if raw_max is not None and str(raw_max).isdigit():
        max_capacity = int(str(raw_max))
    else:
        print("the max_number field is not digit, just fallback to 1")
        max_capacity = 1
    enable_autoscaling(endpoint_name, variant_name, 0, max_capacity)
def describe_endpoint(name):
    """Describe the SageMaker endpoint identified by *name*.

    Args:
        name (string): Name of SageMaker endpoint to describe.

    Returns:
        (dict) Dictionary containing metadata and details about the status
        of the endpoint.

    Raises:
        Whatever ``DescribeEndpoint`` raises, after logging it.
    """
    try:
        response = sagemaker.describe_endpoint(
            EndpointName=name
        )
    except Exception as e:
        print(e)
        print('Unable to describe endpoint.')
        # Bare raise preserves the original traceback; `raise(e)` restarts it here.
        raise
    return response