chore: move train to extension
parent
87920e570c
commit
0108f2ff3a
|
|
@ -0,0 +1,40 @@
|
|||
import base64
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import requests
|
||||
|
||||
from utils import get_variable_from_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def get_sorted_cloud_dataset(username):
|
||||
url = get_variable_from_json("api_gateway_url") + "datasets?dataset_status=Enabled"
|
||||
api_key = get_variable_from_json("api_token")
|
||||
if not url or not api_key:
|
||||
logger.debug("Url or API-Key is not setting.")
|
||||
return []
|
||||
|
||||
try:
|
||||
encode_type = "utf-8"
|
||||
raw_response = requests.get(
|
||||
url=url,
|
||||
headers={
|
||||
"x-api-key": api_key,
|
||||
"Authorization": f"Bearer {base64.b16encode(username.encode(encode_type)).decode(encode_type)}",
|
||||
},
|
||||
)
|
||||
raw_response.raise_for_status()
|
||||
response = raw_response.json()
|
||||
logger.info(f"datasets response: {response}")
|
||||
datasets = response["data"]["datasets"]
|
||||
datasets.sort(
|
||||
key=lambda t: t["timestamp"] if "timestamp" in t else sys.float_info.max,
|
||||
reverse=True,
|
||||
)
|
||||
return datasets
|
||||
except Exception as e:
|
||||
logger.error(f"exception {e}")
|
||||
return []
|
||||
|
|
@ -16,7 +16,7 @@ from aws_extension.cloud_api_manager.api import api
|
|||
from aws_extension.cloud_api_manager.api_manager import api_manager
|
||||
from aws_extension.sagemaker_ui import checkpoint_type
|
||||
from aws_extension.sagemaker_ui_utils import create_refresh_button_by_user
|
||||
from dreambooth_on_cloud.train import get_sorted_cloud_dataset
|
||||
from aws_extension.cloud_dataset_manager.dataset_manager import get_sorted_cloud_dataset
|
||||
from utils import get_variable_from_json, save_variable_to_json, has_config, is_gcr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -1,34 +0,0 @@
|
|||
import base64
|
||||
import logging
|
||||
import sys
|
||||
import requests
|
||||
|
||||
from utils import get_variable_from_json
|
||||
|
||||
logging.basicConfig(filename='sd-aws-ext.log', level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def get_sorted_cloud_dataset(username):
|
||||
url = get_variable_from_json('api_gateway_url') + 'datasets?dataset_status=Enabled'
|
||||
api_key = get_variable_from_json('api_token')
|
||||
if not url or not api_key:
|
||||
logger.debug("Url or API-Key is not setting.")
|
||||
return []
|
||||
|
||||
try:
|
||||
encode_type = "utf-8"
|
||||
raw_response = requests.get(url=url, headers={
|
||||
'x-api-key': api_key,
|
||||
'Authorization': f'Bearer {base64.b16encode(username.encode(encode_type)).decode(encode_type)}',
|
||||
})
|
||||
raw_response.raise_for_status()
|
||||
response = raw_response.json()
|
||||
logger.info(f"datasets response: {response}")
|
||||
datasets = response['data']['datasets']
|
||||
datasets.sort(key=lambda t: t['timestamp'] if 'timestamp' in t else sys.float_info.max, reverse=True)
|
||||
return datasets
|
||||
except Exception as e:
|
||||
logger.error(f"exception {e}")
|
||||
return []
|
||||
Loading…
Reference in New Issue