# stable-diffusion-aws-extension/aws_extension/cloud_api_manager/api_manager.py
#
# 315 lines
# 12 KiB
# Python

import base64
import json
import logging
import requests
import utils
from aws_extension.auth_service.simple_cloud_auth import cloud_auth_manager
# Encoding used when turning user tokens into bytes and back.
encode_type = "utf-8"

# Module-level logger; its level is taken from utils.LOGGING_LEVEL.
logger = logging.getLogger(__name__)
logger.setLevel(utils.LOGGING_LEVEL)
class CloudApiManager:
    """Client for the extension's cloud REST API.

    Every request carries the API key held by the auth manager and, when a
    user token is supplied, a bearer token derived from that user token.
    Most read/write helpers short-circuit with an empty result when
    authentication is disabled on the auth manager.
    """

    def __init__(self):
        # NOTE(review): the attribute name is misspelled ("manger"); it is kept
        # unchanged in case code outside this class reads it.
        self.auth_manger = cloud_auth_manager

    # todo: not sure how to get current login user's password from gradio
    # todo: use username only for authorize checking for now only, e.g. user_token = username
    def _get_headers_by_user(self, user_token):
        """Build the HTTP headers for a request made on behalf of ``user_token``.

        Args:
            user_token: the caller's token (currently the username). When
                falsy, no ``Authorization`` header is added.

        Returns:
            dict: headers containing the API key and JSON content type, plus
            a ``Bearer`` authorization header (base16 of the token) when a
            token is given.
        """
        headers = {
            'x-api-key': self.auth_manger.api_key,
            'Content-Type': 'application/json',
        }
        if not user_token:
            return headers

        encoded_token = base64.b16encode(user_token.encode(encode_type)).decode(encode_type)
        return {'Authorization': f'Bearer {encoded_token}', **headers}

    def sagemaker_deploy(self, endpoint_name, instance_type, initial_instance_count=1,
                         autoscaling_enabled=True, user_roles=None, user_token=""):
        """Request creation of a SageMaker endpoint for GPU inference.

        Args:
            endpoint_name (string): sent as ``endpoint_name`` in the payload.
            instance_type (string): the ML compute instance type.
            initial_instance_count (integer): Number of instances to launch initially.
            autoscaling_enabled (bool): sent as ``autoscaling_enabled``.
            user_roles: sent as ``assign_to_roles``.
            user_token (string): sent as ``creator`` and used for the auth headers.

        Returns:
            The ``message`` field of the API response, or a failure string
            when the request or the response parsing raised.
        """
        logger.debug(
            f"start deploying instance type: {instance_type} with count {initial_instance_count} with autoscaling {autoscaling_enabled}............")
        payload = {
            "endpoint_name": endpoint_name,
            "instance_type": instance_type,
            "initial_instance_count": initial_instance_count,
            "autoscaling_enabled": autoscaling_enabled,
            'assign_to_roles': user_roles,
            "creator": user_token,
        }
        deployment_url = f"{self.auth_manger.api_url}endpoints"

        # Deliberately best-effort: any failure is reported back as a string.
        try:
            response = requests.post(deployment_url, json=payload, headers=self._get_headers_by_user(user_token))
            r = response.json()
            logger.debug(f"response for rest api {r}")
            return r['message']
        except Exception as e:
            logger.error(e)
            return f"Failed to start endpoint deployment with exception: {e}"

    def sagemaker_endpoint_delete(self, delete_endpoint_list, user_token=""):
        """Request deletion of the given SageMaker endpoints.

        Args:
            delete_endpoint_list: iterable of strings; only the part before
                the first ``+`` of each item is sent as the endpoint name.
            user_token (string): sent as ``username`` and used for the auth headers.

        Returns:
            The ``message`` field of the API response, or a failure string
            when the request or the response parsing raised.
        """
        logger.debug("start delete sagemaker endpoint delete function")
        logger.debug(f"delete endpoint list: {delete_endpoint_list}")
        delete_endpoint_list = [item.split('+')[0] for item in delete_endpoint_list]
        logger.debug(f"delete endpoint list: {delete_endpoint_list}")
        payload = {
            "endpoint_name_list": delete_endpoint_list,
            "username": user_token,
        }
        deployment_url = f"{self.auth_manger.api_url}endpoints"

        # Deliberately best-effort: any failure is reported back as a string.
        try:
            response = requests.delete(deployment_url, json=payload, headers=self._get_headers_by_user(user_token))
            response = response.json()
            logger.debug(f"response for rest api {response}")
            return response['message']
        except Exception as e:
            logger.error(e)
            return f"Failed to delete sagemaker endpoint with exception: {e}"

    def list_all_sagemaker_endpoints_raw(self, username=None, user_token=""):
        """Return the raw endpoint objects reported by the API.

        Returns an empty list when auth is enabled but no token is given,
        when no API URL is configured, or when the response body is empty or
        does not carry ``statusCode`` 200.

        Raises:
            requests.HTTPError: when the HTTP status is 4xx/5xx.
        """
        if self.auth_manger.enableAuth and not user_token:
            return []
        if not self.auth_manger.api_url:
            return []

        response = requests.get(f'{self.auth_manger.api_url}endpoints',
                                params={
                                    'username': username,
                                },
                                headers=self._get_headers_by_user(user_token))
        response.raise_for_status()
        r = response.json()
        if not r or r.get('statusCode') != 200:
            # The body may be empty here, so never index into it directly.
            message = r.get('message') if r else None
            logger.info(f"The API response is empty for update_sagemaker_endpoints().{message}")
            return []
        return r['data']['endpoints']

    @staticmethod
    def _format_endpoint_rows(endpoints):
        """Convert raw endpoint objects to sorted ``name+status+time`` strings.

        Entries without ``EndpointDeploymentJobId`` are ignored, as are
        entries whose status is ``'Deleted'`` or whose legacy ``status`` field
        is ``'deleted'``. The result is sorted by the trailing time component
        in descending (string) order.
        """
        rows = []
        for obj in endpoints:
            if "EndpointDeploymentJobId" not in obj:
                continue

            if "endpoint_name" in obj:
                endpoint_name = obj["endpoint_name"]
                endpoint_status = obj["endpoint_status"]
            else:
                endpoint_name = obj["EndpointDeploymentJobId"]
                endpoint_status = obj["status"]

            # Skip if status is 'Deleted'
            if endpoint_status == 'Deleted':
                continue
            # Compatible with fields used in older versions. New-format
            # entries may not carry "status" at all, hence .get().
            if obj.get("status") == 'deleted':
                continue

            endpoint_time = obj.get("endTime", "N/A")
            rows.append(f"{endpoint_name}+{endpoint_status}+{endpoint_time}")

        # Sort on the time component (last '+'-separated field), newest first.
        return sorted(rows, key=lambda row: row.split('+')[-1], reverse=True)

    def list_all_sagemaker_endpoints(self, username=None, user_token=""):
        """Return endpoints as ``name+status+time`` strings, newest first.

        Any failure (network, HTTP status, unexpected payload) is logged and
        results in an empty list.
        """
        try:
            if self.auth_manger.enableAuth and not user_token:
                return []
            # Same guard as list_all_sagemaker_endpoints_raw: nothing to query.
            if not self.auth_manger.api_url:
                return []

            response = requests.get(f'{self.auth_manger.api_url}endpoints',
                                    params={
                                        'username': username,
                                    },
                                    headers=self._get_headers_by_user(user_token))
            response.raise_for_status()
            r = response.json()
            if not r:
                logger.info("The API response is empty for update_sagemaker_endpoints().")
                return []

            return self._format_endpoint_rows(r['data']['endpoints'])
        except Exception as e:
            logger.error(f"An error occurred while updating SageMaker endpoints: {e}")
            return []

    def get_user_by_username(self, username='', user_token='', show_password=False):
        """Fetch a single user record by username.

        Returns:
            ``{'users': []}`` when auth is disabled; otherwise the first
            entry of the ``users`` list in the response.

        Raises:
            requests.HTTPError: when the HTTP status is 4xx/5xx.
            IndexError: when the API returns an empty ``users`` list.
        """
        if not self.auth_manger.enableAuth:
            return {
                'users': []
            }

        raw_resp = requests.get(url=f'{self.auth_manger.api_url}users',
                                params={
                                    'username': username,
                                    'show_password': show_password
                                },
                                headers=self._get_headers_by_user(user_token))
        raw_resp.raise_for_status()
        body = raw_resp.json()  # parse once; reused for logging and the result
        logger.debug(body)
        return body['data']['users'][0]

    def list_users(self, user_token=""):
        """Return the ``data`` object of the users listing.

        Returns ``{'users': []}`` when auth is disabled.

        Raises:
            requests.HTTPError: when the HTTP status is 4xx/5xx.
        """
        if not self.auth_manger.enableAuth:
            return {
                'users': []
            }

        raw_resp = requests.get(url=f'{self.auth_manger.api_url}users',
                                params={},
                                headers=self._get_headers_by_user(user_token))
        raw_resp.raise_for_status()
        return raw_resp.json()['data']

    def list_roles(self, user_token=""):
        """Return the ``data`` object of the roles listing.

        Returns ``{'roles': []}`` when auth is disabled.

        Raises:
            requests.HTTPError: when the HTTP status is 4xx/5xx.
        """
        if not self.auth_manger.enableAuth:
            return {
                'roles': []
            }

        raw_resp = requests.get(url=f'{self.auth_manger.api_url}roles', headers=self._get_headers_by_user(user_token))
        raw_resp.raise_for_status()
        return raw_resp.json()['data']

    def upsert_role(self, role_name, permissions, creator, user_token=""):
        """Create or update a role.

        Returns:
            ``{}`` when auth is disabled, otherwise ``True`` on success.

        Raises:
            requests.HTTPError: when the HTTP status is 4xx/5xx.
            Exception: with the API ``message`` for any other non-200 status.
        """
        if not self.auth_manger.enableAuth:
            return {}

        payload = {
            "role_name": role_name,
            "permissions": permissions,
            "creator": creator
        }
        raw_resp = requests.post(f'{self.auth_manger.api_url}roles', json=payload,
                                 headers=self._get_headers_by_user(user_token))
        raw_resp.raise_for_status()
        resp = raw_resp.json()
        # raise_for_status() already covers 4xx/5xx; this rejects any other
        # status that is not exactly 200.
        if raw_resp.status_code != 200:
            raise Exception(resp['message'])
        return True

    def upsert_user(self, username, password, roles, creator, initial=False, user_token=""):
        """Create or update a user, then refresh the gradio auth data.

        Args:
            initial: when true the call proceeds even if auth is disabled,
                and the auth manager is refreshed before the request.

        Returns:
            ``{}`` when auth is disabled and ``initial`` is false, otherwise
            ``True`` on success.

        Raises:
            Exception: when ``password`` is empty, or with the API
                ``message`` for a non-200 success status.
            requests.HTTPError: when the HTTP status is 4xx/5xx.
        """
        if not self.auth_manger.enableAuth and not initial:
            return {}
        if not password:
            raise Exception('password should not be none')

        payload = {
            "initial": initial,
            "username": username,
            "password": password,
            "roles": roles,
            "creator": creator,
        }
        if initial:
            self.auth_manger.refresh()

        raw_resp = requests.post(f'{self.auth_manger.api_url}users',
                                 json=payload,
                                 headers=self._get_headers_by_user(user_token))
        raw_resp.raise_for_status()
        resp = raw_resp.json()
        # raise_for_status() already covers 4xx/5xx; this rejects any other
        # status that is not exactly 200.
        if raw_resp.status_code != 200:
            raise Exception(resp['message'])

        self.auth_manger.update_gradio_auth()
        return True

    def delete_user(self, username, user_token=""):
        """Delete a single user.

        Returns:
            ``{}`` when auth is disabled, otherwise ``True`` on success.

        Raises:
            requests.HTTPError: when the HTTP status is 4xx/5xx.
            Exception: with the API ``message`` for any other non-200 status.
        """
        if not self.auth_manger.enableAuth:
            return {}

        payload = {
            "user_name_list": [username]
        }
        raw_resp = requests.delete(f'{self.auth_manger.api_url}users',
                                   json=payload,
                                   headers=self._get_headers_by_user(user_token))
        raw_resp.raise_for_status()
        resp = raw_resp.json()
        # raise_for_status() already covers 4xx/5xx; this rejects any other
        # status that is not exactly 200.
        if raw_resp.status_code != 200:
            raise Exception(resp['message'])
        return True

    @staticmethod
    def _flatten_checkpoints(raw_checkpoints):
        """Expand checkpoint records into one dict per distinct model name.

        Each raw record lists its names under ``name``; records that are
        empty or have no names are skipped. Only the first occurrence of a
        given name is kept.
        """
        seen_names = set()
        checkpoints = []
        for ckpt in raw_checkpoints:
            if not ckpt or not ckpt.get('name'):
                continue

            for name in ckpt['name']:
                # Track names separately: the result holds dicts, so testing
                # `name in checkpoints` would never match.
                if name in seen_names:
                    continue
                seen_names.add(name)
                checkpoints.append({
                    'name': name,
                    'id': ckpt['id'],
                    's3Location': ckpt['s3Location'],
                    'type': ckpt['type'],
                    'status': ckpt['status'],
                    'created': float(ckpt['created']),
                    'allowed_roles_or_users': ckpt['allowed_roles_or_users'],
                })
        return checkpoints

    def list_models_on_cloud(self, username, user_token="", types='Stable-diffusion', status='Active'):
        """List cloud checkpoints visible to ``username``, one entry per name.

        Returns:
            list[dict]: empty when auth is disabled; otherwise one dict per
            distinct checkpoint name with ``name``, ``id``, ``s3Location``,
            ``type``, ``status``, ``created`` (float) and
            ``allowed_roles_or_users``.

        Raises:
            requests.HTTPError: when the HTTP status is 4xx/5xx.
        """
        if not self.auth_manger.enableAuth:
            return []

        raw_resp = requests.get(url=f'{self.auth_manger.api_url}checkpoints', params={
            'username': username,
            'types': types,
            'status': status
        }, headers=self._get_headers_by_user(user_token))
        raw_resp.raise_for_status()
        resp = raw_resp.json()['data']
        return self._flatten_checkpoints(resp['checkpoints'])

    def list_all_inference_jobs_on_cloud(self, username, user_token=""):
        """Return the ``inferences`` list from the API for ``username``.

        Returns an empty list when auth is disabled.

        Raises:
            requests.HTTPError: when the HTTP status is 4xx/5xx.
        """
        if not self.auth_manger.enableAuth:
            return []

        raw_resp = requests.get(url=f'{self.auth_manger.api_url}inferences', params={
            'username': username,
        }, headers=self._get_headers_by_user(user_token))
        raw_resp.raise_for_status()
        resp = raw_resp.json()
        return resp['data']['inferences']

    def get_dataset_items_from_dataset(self, dataset_name, user_token=""):
        """Return the ``data`` object the API reports for ``dataset_name``.

        Returns an empty list when auth is disabled.

        Raises:
            requests.HTTPError: when the HTTP status is 4xx/5xx.
        """
        if not self.auth_manger.enableAuth:
            return []

        raw_response = requests.get(url=f'{self.auth_manger.api_url}datasets/{dataset_name}',
                                    headers=self._get_headers_by_user(user_token))
        raw_response.raise_for_status()
        # todo: the s3 presign url is not ready as content type to img
        return raw_response.json()['data']
# Module-level CloudApiManager instance, created when this module is imported.
api_manager = CloudApiManager()