chore: move train to extension

pull/467/head
Ning 2024-01-25 13:29:30 +08:00
parent 87920e570c
commit 0108f2ff3a
4 changed files with 41 additions and 35 deletions

View File

@ -0,0 +1,40 @@
import base64
import logging
import sys
import requests
from utils import get_variable_from_json
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def get_sorted_cloud_dataset(username):
url = get_variable_from_json("api_gateway_url") + "datasets?dataset_status=Enabled"
api_key = get_variable_from_json("api_token")
if not url or not api_key:
logger.debug("Url or API-Key is not setting.")
return []
try:
encode_type = "utf-8"
raw_response = requests.get(
url=url,
headers={
"x-api-key": api_key,
"Authorization": f"Bearer {base64.b16encode(username.encode(encode_type)).decode(encode_type)}",
},
)
raw_response.raise_for_status()
response = raw_response.json()
logger.info(f"datasets response: {response}")
datasets = response["data"]["datasets"]
datasets.sort(
key=lambda t: t["timestamp"] if "timestamp" in t else sys.float_info.max,
reverse=True,
)
return datasets
except Exception as e:
logger.error(f"exception {e}")
return []

View File

@ -16,7 +16,7 @@ from aws_extension.cloud_api_manager.api import api
from aws_extension.cloud_api_manager.api_manager import api_manager
from aws_extension.sagemaker_ui import checkpoint_type
from aws_extension.sagemaker_ui_utils import create_refresh_button_by_user
from dreambooth_on_cloud.train import get_sorted_cloud_dataset
from aws_extension.cloud_dataset_manager.dataset_manager import get_sorted_cloud_dataset
from utils import get_variable_from_json, save_variable_to_json, has_config, is_gcr
logger = logging.getLogger(__name__)

View File

@ -1,34 +0,0 @@
import base64
import logging
import sys
import requests
from utils import get_variable_from_json
logging.basicConfig(filename='sd-aws-ext.log', level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def get_sorted_cloud_dataset(username):
url = get_variable_from_json('api_gateway_url') + 'datasets?dataset_status=Enabled'
api_key = get_variable_from_json('api_token')
if not url or not api_key:
logger.debug("Url or API-Key is not setting.")
return []
try:
encode_type = "utf-8"
raw_response = requests.get(url=url, headers={
'x-api-key': api_key,
'Authorization': f'Bearer {base64.b16encode(username.encode(encode_type)).decode(encode_type)}',
})
raw_response.raise_for_status()
response = raw_response.json()
logger.info(f"datasets response: {response}")
datasets = response['data']['datasets']
datasets.sort(key=lambda t: t['timestamp'] if 'timestamp' in t else sys.float_info.max, reverse=True)
return datasets
except Exception as e:
logger.error(f"exception {e}")
return []