import json import logging import requests from modules import shared from aws_extension.sagemaker_ui_utils import warning from utils import host_url, api_key logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) client_api_version = "1.5.0" def upgrade_info(resp): if 'x-api-version' not in resp.headers: warning(f"Client api version {client_api_version} is not compatible api version. " f"Please update the client or api.") return shared.demo.server_app.api_version = api_version = resp.headers['x-api-version'] if api_version < client_api_version: warning(f"extension version {client_api_version} is not compatible api version {api_version}. " f"Please update the api.") return if api_version > client_api_version: warning(f"extension version {client_api_version} is not compatible api version {api_version}. " f"Please update the extension.") return class Api: username = None def set_username(self, username): self.username = username return self.username def __init__(self, debug: bool = True): self.debug = debug def req(self, method: str, path: str, headers=None, data=None, params=None): url = f"{host_url()}{path}" if data is not None: data = json.dumps(data) if headers is None: headers = { 'x-api-key': api_key(), 'Content-Type': 'application/json', } else: headers['x-api-key'] = api_key() headers['Content-Type'] = 'application/json' if self.username: headers['username'] = self.username if self.debug: logger.info(f"{method} {url}") if headers: logger.info(f"headers: {headers}") if data: logger.info(f"data: {data}") if params: logger.info(f"params: {params}") resp = requests.request( method=method, url=url, headers=headers, data=data, params=params, timeout=(20, 30) ) upgrade_info(resp) if self.debug: logger.info(f"resp headers: {resp.headers}") logger.info(f"{resp.status_code} {resp.text}") return resp def ping(self, headers=None): return self.req( "GET", "ping", headers=headers ) def list_roles(self, headers=None, params=None): return self.req( "GET", "roles", headers=headers, params=params ) def delete_roles(self, headers=None, data=None): return self.req( "DELETE", "roles", headers=headers, data=data ) def delete_datasets(self, headers=None, data=None): return self.req( "DELETE", "datasets", headers=headers, data=data ) def delete_models(self, headers=None, data=None): return self.req( "DELETE", "models", headers=headers, data=data ) def delete_trainings(self, headers=None, data=None): return self.req( "DELETE", "trainings", headers=headers, data=data ) def delete_inferences(self, headers=None, data=None): return self.req( "DELETE", "inferences", headers=headers, data=data ) def delete_checkpoints(self, headers=None, data=None): return self.req( "DELETE", "checkpoints", headers=headers, data=data ) def create_role(self, headers=None, data=None): return self.req( "POST", "roles", headers=headers, data=data ) def list_users(self, headers=None, params=None): return self.req( "GET", "users", headers=headers, params=params ) def delete_users(self, headers=None, data=None): return self.req( "DELETE", f"users", headers=headers, data=data ) def create_user(self, headers=None, data=None): return self.req( "POST", "users", headers=headers, data=data ) def list_checkpoints(self, headers=None, params=None): return self.req( "GET", "checkpoints", headers=headers, params=params ) def create_checkpoint(self, headers=None, data=None): return self.req( "POST", "checkpoints", headers=headers, data=data ) def update_checkpoint(self, checkpoint_id: str, headers=None, data=None): return self.req( "PUT", f"checkpoints/{checkpoint_id}", headers=headers, data=data ) def delete_endpoints(self, headers=None, data=None): return self.req( "DELETE", "endpoints", headers=headers, data=data ) def list_endpoints(self, headers=None, params=None): return self.req( "GET", "endpoints", headers=headers, params=params ) def create_endpoint(self, headers=None, data=None): return self.req( "POST", "endpoints", headers=headers, data=data ) def create_inference(self, headers=None, data=None): return self.req( "POST", "inferences", headers=headers, data=data ) def start_inference_job(self, job_id: str, headers=None): return self.req( "PUT", f"inferences/{job_id}/start", headers=headers, ) def get_training_job(self, job_id: str, headers=None): return self.req( "GET", f"trainings/{job_id}", headers=headers, ) def get_inference_job(self, job_id: str, headers=None): return self.req( "GET", f"inferences/{job_id}", headers=headers ) def list_datasets(self, headers=None, params=None): return self.req( "GET", "datasets", headers=headers, params=params ) def get_dataset(self, name: str, headers=None): return self.req( "GET", f"datasets/{name}", headers=headers ) def create_dataset(self, headers=None, data=None): return self.req( "POST", "datasets", headers=headers, data=data ) def update_dataset(self, dataset_id: str, headers=None, data=None): return self.req( "PUT", f"datasets/{dataset_id}", headers=headers, data=data ) def crop_dataset(self, dataset_name: str, headers=None, data=None): return self.req( "POST", f"datasets/{dataset_name}/crop", headers=headers, data=data ) def create_model(self, headers=None, data=None): return self.req( "POST", "models", headers=headers, data=data ) def update_model(self, model_id: str, headers=None, data=None): return self.req( "PUT", f"models/{model_id}", headers=headers, data=data ) def list_models(self, headers=None, params=None): return self.req( "GET", "models", headers=headers, params=params ) def start_training_job(self, training_id: str, headers=None, data=None): return self.req( "PUT", f"trainings/{training_id}/start", headers=headers, data=data ) def create_training_job(self, headers=None, data=None): return self.req( "POST", "trainings", headers=headers, data=data ) def list_trainings(self, headers=None, params=None): return self.req( "GET", "trainings", headers=headers, params=params ) def list_inferences(self, headers=None, params=None): return self.req( "GET", "inferences", headers=headers, params=params ) api = Api()