stable-diffusion-aws-extension/aws_extension/cloud_api_manager/api_manager.py

531 lines
19 KiB
Python

import logging
import requests
import utils
from aws_extension.auth_service.simple_cloud_auth import cloud_auth_manager, Admin_Role
from aws_extension.cloud_api_manager.api import api
from utils import has_config
# Logger for this module; its level comes from the project-wide utils setting.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=utils.LOGGING_LEVEL)

# Delimiter placed between the name, status and id fields of a checkpoint
# option string (see CloudApiManager.list_all_ckpts).
string_separator = "___"

# Pagination state for CloudApiManager.list_all_inference_jobs_on_cloud,
# keyed by "<username>_<task type>_<previous|cur|cur_key|next>".
last_evaluated_key = dict()
class CloudApiManager:
    """Client for the extension's cloud REST API.

    Endpoint, training, user, role, inference and dataset calls are made with
    ``requests`` against ``self.auth_manger.api_url``; checkpoint calls go
    through the shared ``api`` client. Error handling differs per method:
    some return an error message string, some return an empty result and
    some raise -- see each method's docstring.
    """

    def __init__(self):
        # The attribute name is misspelled ("auth_manger"), but it is kept
        # because code outside this class may read it.
        self.auth_manger = cloud_auth_manager

    def _get_headers_by_user(self, username):
        """Build the HTTP headers for an API request.

        Args:
            username: identity sent in the ``username`` header; when empty the
                header is omitted.

        Returns:
            dict with the API key and JSON content type (plus ``username``
            first, when given).
        """
        headers = {
            'x-api-key': self.auth_manger.api_key,
            'Content-Type': 'application/json',
        }
        if not username:
            return headers
        return {'username': username, **headers}

    def sagemaker_endpoint_delete(self, delete_endpoint_list, username=""):
        """Delete SageMaker endpoints.

        Args:
            delete_endpoint_list: endpoint entries; everything after the first
                '+' is dropped, so "name+status+time" strings produced by
                ``list_all_sagemaker_endpoints`` can be passed directly.
            username: identity sent in the request headers.

        Returns:
            A status message. Errors are logged and reported in the message
            instead of being raised.
        """
        if not delete_endpoint_list:
            return "No endpoint to delete"
        logger.debug("start delete sagemaker endpoint delete function")
        logger.debug(f"delete endpoint list: {delete_endpoint_list}")
        # Keep only the endpoint name from "name+status+time" entries.
        delete_endpoint_list = [item.split('+')[0] for item in delete_endpoint_list]
        logger.debug(f"delete endpoint list: {delete_endpoint_list}")
        payload = {
            "endpoint_name_list": delete_endpoint_list,
        }
        deployment_url = f"{self.auth_manger.api_url}endpoints"
        try:
            resp = requests.delete(deployment_url, json=payload, headers=self._get_headers_by_user(username))
            if resp.status_code != 204:
                raise Exception(resp.json()['message'])
            return "Delete Endpoint Successfully"
        except Exception as e:
            logger.error(e)
            return f"Failed to delete sagemaker endpoint with exception: {e}"

    def trains_delete(self, list, username=""):
        """Delete training jobs.

        Args:
            list: training ids to delete. (The name shadows the builtin but is
                kept because callers may pass it by keyword.)
            username: identity sent in the request headers.

        Returns:
            A status message. Errors are logged and reported in the message
            instead of being raised.
        """
        if not list:
            return "No trains to delete"
        payload = {
            "training_id_list": list,
        }
        url = f"{self.auth_manger.api_url}trainings"
        try:
            resp = requests.delete(url, json=payload, headers=self._get_headers_by_user(username))
            if resp.status_code != 204:
                raise Exception(resp.json()['message'])
            return "Delete Trainings Successfully"
        except Exception as e:
            logger.error(e)
            return f"Failed to delete trainings with exception: {e}"

    def sagemaker_deploy(self, endpoint_name,
                         endpoint_type,
                         instance_type,
                         initial_instance_count=1,
                         custom_docker_image_uri="",
                         custom_extensions="",
                         autoscaling_enabled=True,
                         user_roles=None,
                         min_instance_number=1,
                         username=""):
        """Request creation of a SageMaker endpoint for SD inference.

        Args:
            endpoint_name: name for the new endpoint.
            endpoint_type: forwarded as ``endpoint_type``.
            instance_type (string): the ML compute instance type.
            initial_instance_count (integer): number of instances to launch
                initially; also sent as ``max_instance_number``.
            custom_docker_image_uri: forwarded unchanged.
            custom_extensions: forwarded unchanged.
            autoscaling_enabled: forwarded unchanged.
            user_roles: roles the endpoint is assigned to (``assign_to_roles``).
            min_instance_number: forwarded unchanged.
            username: recorded as ``creator`` and sent in the request headers.

        Returns:
            The ``message`` field of the API response, or an error message if
            the request or response parsing failed.
        """
        logger.debug(
            f"start deploying instance type: {instance_type} with count {initial_instance_count} with autoscaling {autoscaling_enabled}............")
        payload = {
            "endpoint_name": endpoint_name,
            "service_type": 'sd',
            "endpoint_type": endpoint_type,
            "instance_type": instance_type,
            "initial_instance_count": initial_instance_count,
            'min_instance_number': min_instance_number,
            # use initial_instance_count for user experience
            'max_instance_number': initial_instance_count,
            "autoscaling_enabled": autoscaling_enabled,
            "custom_docker_image_uri": custom_docker_image_uri,
            "custom_extensions": custom_extensions,
            'assign_to_roles': user_roles,
            "creator": username,
        }
        deployment_url = f"{self.auth_manger.api_url}endpoints"
        try:
            response = requests.post(deployment_url, json=payload, headers=self._get_headers_by_user(username))
            r = response.json()
            logger.debug(f"response for rest api {r}")
            return r['message']
        except Exception as e:
            logger.error(e)
            return f"Failed to start endpoint deployment with exception: {e}"

    def ckpts_delete(self, ckpts, user_token=""):
        """Delete checkpoints by id through the shared ``api`` client.

        Returns:
            A status message. Errors are logged and reported in the message
            instead of being raised.
        """
        logger.debug(f"ckpts: {ckpts}")
        data = {
            "checkpoint_id_list": ckpts,
        }
        try:
            api.set_username(user_token)
            resp = api.delete_checkpoints(data=data)
            if resp.status_code != 204:
                raise Exception(resp.json()['message'])
            return "Delete Checkpoints Successfully"
        except Exception as e:
            logger.error(e)
            return f"Failed to delete checkpoint with exception: {e}"

    def ckpt_rename(self, ckpt_id, name, user_token=""):
        """Rename a checkpoint through the shared ``api`` client.

        Returns:
            The ``message`` field of the API response, or an error message.
        """
        data = {
            "name": name,
        }
        try:
            api.set_username(user_token)
            resp = api.update_checkpoint(checkpoint_id=ckpt_id, data=data)
            return resp.json()['message']
        except Exception as e:
            logger.error(e)
            return f"Failed to rename checkpoint with exception: {e}"

    def list_all_train_jobs_raw(self, username=None, last_key=None):
        """Fetch one page (limit 10) of training jobs.

        Args:
            username: user filter and header identity; required when auth is on.
            last_key: ``exclusive_start_key`` of the page to fetch.

        Returns:
            ``(trainings, last_evaluated_key)``. The key is '' when the response
            carries none, and ``([], '')`` is returned on every early-exit or
            error path so callers can always unpack the result.
        """
        if self.auth_manger.enableAuth and not username:
            return [], ''
        if not self.auth_manger.api_url:
            return [], ''
        response = requests.get(f'{self.auth_manger.api_url}trainings',
                                params={
                                    'username': username,
                                    'exclusive_start_key': last_key,
                                    'limit': 10,
                                },
                                headers=self._get_headers_by_user(username))
        r = response.json()
        if not r or r.get('statusCode') != 200:
            logger.error(f"list_trainings: {r}")
            # Previously returned a bare [], breaking callers that unpack a pair.
            return [], ''
        return r['data']['trainings'], r['data'].get('last_evaluated_key', '')

    def list_all_sagemaker_endpoints_raw(self, username=None, user_token="", last_key: str = ""):
        """Fetch one page (limit 10) of endpoint records.

        Args:
            username: user filter sent as a query parameter.
            user_token: header identity; required when auth is on.
            last_key: ``exclusive_start_key`` of the page to fetch.

        Returns:
            ``(endpoints, last_evaluated_key)``; ``([], '')`` on any early exit
            or error response.
        """
        if self.auth_manger.enableAuth and not user_token:
            return [], ''
        if not self.auth_manger.api_url:
            return [], ''
        response = requests.get(f'{self.auth_manger.api_url}endpoints',
                                params={
                                    'username': username,
                                    'exclusive_start_key': last_key,
                                    'limit': 10,
                                },
                                headers=self._get_headers_by_user(user_token))
        if response.status_code != 200:
            logger.error(f"list_endpoints: {response.json()}")
            return [], ''
        r = response.json()
        if not r or r.get('statusCode') != 200:
            # r may be empty here, so it must not be indexed unconditionally.
            message = r.get('message', '') if r else ''
            logger.info(f"The API response is empty for list_endpoints().{message}")
            return [], ''
        return r['data']['endpoints'], r['data'].get('last_evaluated_key', '')

    def list_all_sagemaker_endpoints(self, username=None, user_token=""):
        """Return endpoints as "name+status+time" strings.

        Only records containing ``EndpointDeploymentJobId`` are included. The
        result is sorted on the trailing time field in descending string order
        ("N/A" when the record has no ``endTime``).

        Returns:
            The list of strings; [] when auth is on without a token, when there
            is no config, when the response is empty, or on any error (logged).
        """
        try:
            if self.auth_manger.enableAuth and not user_token:
                return []
            if not has_config():
                return []
            response = requests.get(f'{self.auth_manger.api_url}endpoints',
                                    params={
                                        'username': username,
                                    },
                                    headers=self._get_headers_by_user(user_token))
            response.raise_for_status()
            r = response.json()
            if not r:
                logger.info("The API response is empty for update_sagemaker_endpoints().")
                return []
            sagemaker_raw_endpoints = []
            for obj in r['data']['endpoints']:
                if "EndpointDeploymentJobId" not in obj:
                    continue
                if "endpoint_name" in obj:
                    endpoint_name = obj["endpoint_name"]
                    endpoint_status = obj["endpoint_status"]
                else:
                    endpoint_name = obj["EndpointDeploymentJobId"]
                    endpoint_status = obj["status"]
                endpoint_time = obj.get("endTime", "N/A")
                sagemaker_raw_endpoints.append(f"{endpoint_name}+{endpoint_status}+{endpoint_time}")
            # Sort the list based on completeTime in descending order
            return sorted(sagemaker_raw_endpoints, key=lambda x: x.split('+')[-1], reverse=True)
        except Exception as e:
            logger.error(f"An error occurred while updating SageMaker endpoints: {e}")
            return []

    def list_all_ckpts(self, username=None, user_token=""):
        """Return checkpoints as "<name>___<status>___<id>" option strings.

        The name is the first entry of the checkpoint's ``name`` list, or
        'None' when it is missing/empty. Up to 200 checkpoints are requested.

        Returns:
            The list of strings; [] when auth is on without a token, when there
            is no config, when the response is empty, or on any error (logged).
        """
        try:
            if self.auth_manger.enableAuth and not user_token:
                return []
            if not has_config():
                return []
            params = {
                'username': username,
                'per_page': 200,
            }
            # NOTE(review): the other checkpoint calls pass user_token here;
            # this one passes username. Kept as-is -- confirm which is intended.
            api.set_username(username)
            response = api.list_checkpoints(params=params)
            r = response.json()
            if not r:
                logger.info("The API response is empty for list_all_ckpts().")
                return []
            ckpts_list = []
            for ckpt in r['data']['checkpoints']:
                ckpt_name = ckpt['name'][0] if ckpt.get('name') else 'None'
                ckpts_list.append(f"{ckpt_name}{string_separator}{ckpt['status']}{string_separator}{ckpt['id']}")
            # NOTE(review): entries are joined with string_separator, not '+',
            # so this key is normally the whole string (reverse lexical order).
            # Kept as-is; confirm the intended ordering.
            return sorted(ckpts_list, key=lambda x: x.split('+')[-1], reverse=True)
        except Exception as e:
            logger.error(f"list_all_ckpts: {e}")
            return []

    def get_user_by_username(self, username='', h_username='', show_password=False):
        """Look up a single user.

        Args:
            username: user to look up (query parameter).
            h_username: identity sent in the request headers.
            show_password: forwarded as the ``show_password`` query parameter.

        Returns:
            The first user record in the response, or ``{'users': []}`` when
            auth is disabled, the request fails, or no user is returned.
        """
        if not self.auth_manger.enableAuth:
            return {
                'users': []
            }
        raw_resp = requests.get(url=f'{self.auth_manger.api_url}users',
                                params={
                                    'username': username,
                                    'show_password': show_password
                                },
                                headers=self._get_headers_by_user(h_username))
        if raw_resp.status_code != 200:
            logger.error(f"list_users: {raw_resp.json()}")
            return {
                'users': []
            }
        body = raw_resp.json()
        logger.debug(body)
        users = body['data'].get('users') or []
        if not users:
            # Same fallback as the other failure paths, instead of IndexError.
            return {
                'users': []
            }
        return users[0]

    def list_users(self, username=""):
        """Return the ``data`` payload of the users listing.

        Returns ``{'users': []}`` when auth is disabled or the request fails.
        """
        if not self.auth_manger.enableAuth:
            return {
                'users': []
            }
        raw_resp = requests.get(url=f'{self.auth_manger.api_url}users',
                                params={},
                                headers=self._get_headers_by_user(username))
        if raw_resp.status_code != 200:
            logger.error(f"list_users: {raw_resp.json()}")
            return {
                'users': []
            }
        return raw_resp.json()['data']

    def list_roles(self, username=""):
        """Return the ``data`` payload of the roles listing.

        Returns ``{'roles': []}`` when auth is disabled, there is no config,
        or the request fails.
        """
        if not self.auth_manger.enableAuth or not has_config():
            return {
                'roles': []
            }
        raw_resp = requests.get(url=f'{self.auth_manger.api_url}roles', headers=self._get_headers_by_user(username))
        if raw_resp.status_code != 200:
            logger.error(f"list_roles: {raw_resp.json()}")
            return {
                'roles': []
            }
        return raw_resp.json()['data']

    def upsert_role(self, role_name, permissions, creator):
        """POST a role with its permissions to the ``roles`` endpoint.

        Returns:
            {} when auth is disabled; True on a 200/201 response.

        Raises:
            Exception: with the server's message on any other status.
        """
        if not self.auth_manger.enableAuth:
            return {}
        payload = {
            "role_name": role_name,
            "permissions": permissions,
            "creator": creator
        }
        raw_resp = requests.post(f'{self.auth_manger.api_url}roles', json=payload,
                                 headers=self._get_headers_by_user(creator))
        resp = raw_resp.json()
        if raw_resp.status_code not in (200, 201):
            logger.error(f"upsert_role: {resp}")
            raise Exception(resp['message'])
        return True

    def upsert_user(self, username, password, roles, creator, initial=False):
        """POST a user record to the ``users`` endpoint.

        Args:
            username: user to create/update.
            password: must be non-empty.
            roles: roles to assign (replaced by ``[Admin_Role]`` when initial).
            creator: user performing the change (payload and header identity).
            initial: bypasses the auth-enabled check, forces the admin role and
                refreshes the auth manager before sending.

        Returns:
            {} when auth is disabled and not ``initial``; True on success.

        Raises:
            Exception: on an empty password, or with the server's message when
                the response status is not 201.
        """
        if not self.auth_manger.enableAuth and not initial:
            return {}
        if not password:
            raise Exception('password should not be none')
        if initial:
            roles = [Admin_Role]
            self.auth_manger.refresh()
        payload = {
            "initial": initial,
            "username": username,
            "password": password,
            "roles": roles,
            "creator": creator,
        }
        raw_resp = requests.post(f'{self.auth_manger.api_url}users',
                                 json=payload,
                                 headers=self._get_headers_by_user(creator)
                                 )
        resp = raw_resp.json()
        if raw_resp.status_code != 201:
            raise Exception(resp['message'])
        self.auth_manger.update_gradio_auth()
        return True

    def delete_user(self, username, user_token=""):
        """Delete one user.

        Returns:
            {} when auth is disabled; True on a 204 response.

        Raises:
            Exception: when deleting the currently logged-in user, or with the
                server's message on any non-204 status.
        """
        if not self.auth_manger.enableAuth:
            return {}
        if username == self.auth_manger.username:
            raise Exception('Cannot delete current user')
        payload = {
            "user_name_list": [username]
        }
        raw_resp = requests.delete(f'{self.auth_manger.api_url}users',
                                   json=payload,
                                   headers=self._get_headers_by_user(user_token))
        if raw_resp.status_code != 204:
            raise Exception(raw_resp.json()['message'])
        return True

    def list_models_on_cloud(self, username, types='Stable-diffusion', status='Active'):
        """List cloud checkpoints, one entry per distinct checkpoint name.

        A checkpoint carrying several names yields one entry per name; a name
        already seen (from an earlier checkpoint) is skipped.

        Returns:
            list of dicts with name/id/s3Location/type/status/created/
            allowed_roles_or_users; [] when auth is disabled or on error.
        """
        if not self.auth_manger.enableAuth:
            return []
        params = {
            'username': username,
            'types': types,
            'status': status,
            'per_page': 100,
        }
        headers = self._get_headers_by_user(username)
        raw_resp = api.list_checkpoints(params=params, headers=headers)
        if raw_resp.status_code != 200:
            logger.error(f"list_checkpoints: {raw_resp.json()}")
            return []
        body = raw_resp.json()
        if 'data' not in body:
            return []
        checkpoints = []
        # The old check `name not in checkpoints` compared a str against dicts
        # and never matched, so duplicates were never filtered.
        seen_names = set()
        for ckpt in body['data']['checkpoints']:
            if not ckpt or not ckpt.get('name'):
                continue
            for name in ckpt['name']:
                if name in seen_names:
                    continue
                seen_names.add(name)
                checkpoints.append({
                    'name': name,
                    'id': ckpt['id'],
                    's3Location': ckpt['s3Location'],
                    'type': ckpt['type'],
                    'status': ckpt['status'],
                    'created': float(ckpt['created']),
                    'allowed_roles_or_users': ckpt['allowed_roles_or_users'],
                })
        return checkpoints

    @staticmethod
    def _advance_inference_cursor(state, prefix, first_load):
        """Move the inference-list pagination cursor kept in *state*.

        *state* holds four entries per *prefix*:
          ``<prefix>_previous`` -- stack of start keys of pages after the first;
              its top is the start key of the page currently shown.
          ``<prefix>_cur``      -- the last page of inferences returned.
          ``<prefix>_cur_key``  -- the start key of the last keyed request.
          ``<prefix>_next``     -- key of the following page, or None.

        Args:
            state: mutable pagination dict; missing entries are created.
            prefix: "<username>_<task type>" key prefix.
            first_load: "next", "previous", or anything else to restart.

        Returns:
            ``(start_key, use_cached)``: the ``exclusive_start_key`` to send
            (None for the first page) and whether the caller should return the
            cached current page instead of issuing a request.
        """
        state.setdefault(f"{prefix}_previous", [])
        state.setdefault(f"{prefix}_cur", [])
        state.setdefault(f"{prefix}_cur_key", None)
        state.setdefault(f"{prefix}_next", None)

        previous_keys = state[f"{prefix}_previous"]
        start_key = None
        if first_load == "next":
            next_key = state[f"{prefix}_next"]
            if not next_key:
                # No known next page: the caller serves the cached one.
                return None, True
            previous_keys.append(next_key)
            start_key = next_key
        elif first_load == "previous":
            # Drop the current page's own key, then step back to the key under
            # it. That key stays on the stack because it becomes the current
            # page's key; popping it too (as before) made a later "previous"
            # skip straight to the first page.
            if previous_keys and previous_keys[-1] == state[f"{prefix}_cur_key"]:
                previous_keys.pop()
            if previous_keys:
                start_key = previous_keys[-1]
        else:
            state[f"{prefix}_next"] = None
            state[f"{prefix}_previous"] = []

        if start_key is not None:
            state[f"{prefix}_cur_key"] = start_key
        return start_key, False

    def list_all_inference_jobs_on_cloud(self, target_task_type, username, first_load="first"):
        """Return one page (limit 10) of inference jobs for a user and task type.

        Pagination state lives in the module-level ``last_evaluated_key`` dict,
        keyed by username and task type.

        Args:
            target_task_type: inference type filter (sent as ``type``).
            username: user filter and header identity.
            first_load: "next"/"previous" move one page; any other value
                restarts from the first page.

        Returns:
            list of inference records; [] when auth is disabled or the request
            fails. "next" with no known next-page key returns the cached page.
        """
        if not self.auth_manger.enableAuth:
            return []
        params = {
            'username': username,
            'type': target_task_type,
            'limit': 10,
        }
        prefix = f"{username}_{target_task_type}"
        start_key, use_cached = self._advance_inference_cursor(last_evaluated_key, prefix, first_load)
        if use_cached:
            return last_evaluated_key[f"{prefix}_cur"]
        if start_key is not None:
            params['exclusive_start_key'] = start_key

        raw_resp = requests.get(url=f'{self.auth_manger.api_url}inferences', params=params,
                                headers=self._get_headers_by_user(username))
        if raw_resp.status_code != 200:
            logger.error(f"list_inferences: {raw_resp.json()}")
            return []
        data = raw_resp.json()['data']
        last_evaluated_key[f"{prefix}_next"] = data.get('last_evaluated_key')
        last_evaluated_key[f"{prefix}_cur"] = data['inferences']
        return data['inferences']

    def get_dataset_items_from_dataset(self, dataset_name, user_token=""):
        """Return the ``data`` payload for one dataset.

        Returns [] when auth is disabled.

        Raises:
            requests.HTTPError: via ``raise_for_status`` on a 4xx/5xx response.
        """
        if not self.auth_manger.enableAuth:
            return []
        raw_response = requests.get(url=f'{self.auth_manger.api_url}datasets/{dataset_name}',
                                    headers=self._get_headers_by_user(user_token))
        raw_response.raise_for_status()
        return raw_response.json()['data']
# Module-level CloudApiManager instance, created when this module is imported.
api_manager = CloudApiManager()