diff --git a/aws_extension/cloud_dataset_manager/__init__.py b/aws_extension/cloud_dataset_manager/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aws_extension/cloud_dataset_manager/dataset_manager.py b/aws_extension/cloud_dataset_manager/dataset_manager.py new file mode 100644 index 00000000..c637a7c9 --- /dev/null +++ b/aws_extension/cloud_dataset_manager/dataset_manager.py @@ -0,0 +1,40 @@ +import base64 +import logging +import sys + +import requests + +from utils import get_variable_from_json + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def get_sorted_cloud_dataset(username): + url = get_variable_from_json("api_gateway_url") + "datasets?dataset_status=Enabled" + api_key = get_variable_from_json("api_token") + if not url or not api_key: + logger.debug("Url or API-Key is not setting.") + return [] + + try: + encode_type = "utf-8" + raw_response = requests.get( + url=url, + headers={ + "x-api-key": api_key, + "Authorization": f"Bearer {base64.b16encode(username.encode(encode_type)).decode(encode_type)}", + }, + ) + raw_response.raise_for_status() + response = raw_response.json() + logger.info(f"datasets response: {response}") + datasets = response["data"]["datasets"] + datasets.sort( + key=lambda t: t["timestamp"] if "timestamp" in t else sys.float_info.max, + reverse=True, + ) + return datasets + except Exception as e: + logger.error(f"exception {e}") + return [] diff --git a/aws_extension/sagemaker_ui_tab.py b/aws_extension/sagemaker_ui_tab.py index 88f245e4..d5ea69b0 100644 --- a/aws_extension/sagemaker_ui_tab.py +++ b/aws_extension/sagemaker_ui_tab.py @@ -16,7 +16,7 @@ from aws_extension.cloud_api_manager.api import api from aws_extension.cloud_api_manager.api_manager import api_manager from aws_extension.sagemaker_ui import checkpoint_type from aws_extension.sagemaker_ui_utils import create_refresh_button_by_user -from dreambooth_on_cloud.train import get_sorted_cloud_dataset +from aws_extension.cloud_dataset_manager.dataset_manager import get_sorted_cloud_dataset from utils import get_variable_from_json, save_variable_to_json, has_config, is_gcr logger = logging.getLogger(__name__) diff --git a/dreambooth_on_cloud/train.py b/dreambooth_on_cloud/train.py deleted file mode 100644 index 50757e85..00000000 --- a/dreambooth_on_cloud/train.py +++ /dev/null @@ -1,34 +0,0 @@ -import base64 -import logging -import sys -import requests - -from utils import get_variable_from_json - -logging.basicConfig(filename='sd-aws-ext.log', level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -def get_sorted_cloud_dataset(username): - url = get_variable_from_json('api_gateway_url') + 'datasets?dataset_status=Enabled' - api_key = get_variable_from_json('api_token') - if not url or not api_key: - logger.debug("Url or API-Key is not setting.") - return [] - - try: - encode_type = "utf-8" - raw_response = requests.get(url=url, headers={ - 'x-api-key': api_key, - 'Authorization': f'Bearer {base64.b16encode(username.encode(encode_type)).decode(encode_type)}', - }) - raw_response.raise_for_status() - response = raw_response.json() - logger.info(f"datasets response: {response}") - datasets = response['data']['datasets'] - datasets.sort(key=lambda t: t['timestamp'] if 'timestamp' in t else sys.float_info.max, reverse=True) - return datasets - except Exception as e: - logger.error(f"exception {e}") - return []