stable-diffusion-aws-extension/test/config.py

119 lines
3.6 KiB
Python

import logging
import os
from datetime import datetime, timezone

from dotenv import load_dotenv
# Let python-dotenv populate the process environment first, so that the
# os.environ lookups further down can see values supplied that way.
load_dotenv()

# Module-level logger used below to report the resolved settings.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Base URL of the API under test, read from API_GATEWAY_URL (required).
# NOTE(review): presumably of the form
# "https://<api-id>.execute-api.<region>.amazonaws.com/prod/" - confirm.
host_url = os.environ.get("API_GATEWAY_URL")
if not host_url:
    raise Exception("API_GATEWAY_URL is empty")

# The region is the third dot-separated component of the URL. Check the
# component count explicitly so that a malformed URL is reported with the
# "invalid" message below instead of a bare IndexError.
_host_url_parts = host_url.split(".")
region_name = _host_url_parts[2] if len(_host_url_parts) > 2 else ""
if not region_name:
    raise Exception("API_GATEWAY_URL is invalid")

# Remove "/prod" or "/prod/" from the end of the host_url. Only a trailing
# suffix is stripped: a blanket str.replace would also mangle URLs that merely
# contain "/prod" somewhere else (for example ".../production").
host_url = host_url.rstrip("/")
if host_url.endswith("/prod"):
    host_url = host_url[: -len("/prod")]
host_url = host_url.rstrip("/")
logger.info(f"config.host_url: {host_url}")
# API key, read from API_GATEWAY_URL_TOKEN (required).
api_key = os.environ.get("API_GATEWAY_URL_TOKEN")
if not api_key:
    raise Exception("API_GATEWAY_URL_TOKEN is empty")

# The key is a credential, so it must not end up in plaintext in test or CI
# logs. Log a masked form instead: only the last four characters of a
# sufficiently long key are shown, which is enough to tell keys apart.
if len(api_key) > 8:
    _api_key_for_log = "*" * (len(api_key) - 4) + api_key[-4:]
else:
    _api_key_for_log = "*" * len(api_key)
logger.info(f"config.api_key: {_api_key_for_log}")
# Fixed user name.
username = "api"
logger.info(f"config.username: {username}")

# Bucket name, read from API_BUCKET (required, must be non-empty).
bucket = os.getenv("API_BUCKET")
if bucket is None or bucket == "":
    raise Exception("API_BUCKET is empty")
logger.info(f"config.bucket: {bucket}")

# True only when TEST_FAST is exactly the string "true".
test_fast = os.getenv("TEST_FAST") == "true"
logger.info(f"config.test_fast: {test_fast}")

# True for regions whose name begins with "cn-".
is_gcr = region_name.startswith("cn-")
logger.info(f"config.is_gcr: {is_gcr}")

# True when the SNS_ARN environment variable is not set at all.
is_local = "SNS_ARN" not in os.environ
logger.info(f"config.is_local: {is_local}")

# Fixed role name.
role_name = "role_name"
logger.info(f"config.role_name: {role_name}")
# Endpoint name derived from the current UTC time (month, day, hour, minute,
# second), so each run gets a different name. datetime.utcnow() is deprecated
# since Python 3.12; an aware UTC datetime formats to the same string.
endpoint_name = datetime.now(timezone.utc).strftime("%m%d%H%M%S")
logger.info(f"config.endpoint_name: {endpoint_name}")
# Fixed dataset and model names.
dataset_name = "dataset_name"
train_model_name = "train_model_name"
train_wd14_model_name = "wd14_model_name"
model_name = "test-model"

logger.info(f"config.dataset_name: {dataset_name}")
logger.info(f"config.train_model_name: {train_model_name}")
logger.info(f"config.train_wd14_model_name: {train_wd14_model_name}")
logger.info(f"config.model_name: {model_name}")
# Instance type for async endpoints. Resolution order (a later step wins):
#   1. ASYNC_INSTANCE_TYPE from the environment, default "ml.g5.2xlarge";
#   2. a fixed type when is_gcr is set;
#   3. a fixed type for the specific regions in the table below.
# NOTE: steps 2 and 3 take precedence over the environment variable.
_async_region_overrides = {
    "us-east-1": "ml.g4dn.2xlarge",
    "ap-southeast-1": "ml.g5.2xlarge",
    "us-west-2": "ml.g6.2xlarge",
}
async_instance_type = os.environ.get("ASYNC_INSTANCE_TYPE", "ml.g5.2xlarge")
if is_gcr:
    async_instance_type = "ml.g4dn.2xlarge"
# Per-region special cases; regions not in the table keep the value from above.
async_instance_type = _async_region_overrides.get(region_name, async_instance_type)
logger.info(f"config.async_instance_type: {async_instance_type}")
# Instance type for real-time endpoints: fixed when is_gcr is set, otherwise
# REAL_TIME_INSTANCE_TYPE from the environment (default "ml.g5.2xlarge").
real_time_instance_type = (
    "ml.g4dn.4xlarge"
    if is_gcr
    else os.environ.get("REAL_TIME_INSTANCE_TYPE", "ml.g5.2xlarge")
)
logger.info(f"config.real_time_instance_type: {real_time_instance_type}")

# Initial instance count, kept as a string: "1" when is_gcr is set, else "2".
initial_instance_count = "1" if is_gcr else "2"
logger.info(f"config.initial_instance_count: {initial_instance_count}")
# Fixed default model id and the placeholder message for the checkpoint
# upload test.
default_model_id = "v1-5-pruned-emaonly.safetensors"
ckpt_message = "placeholder for chkpts upload test"

logger.info(f"config.default_model_id: {default_model_id}")
logger.info(f"config.ckpt_message: {ckpt_message}")
# Instance type for training jobs. Both ap-southeast-1 and is_gcr force the
# same fixed type; everywhere else TRAIN_INSTANCE_TYPE from the environment
# is used (default "ml.g5.2xlarge").
if is_gcr or region_name == "ap-southeast-1":
    train_instance_type = "ml.g4dn.2xlarge"
else:
    train_instance_type = os.environ.get("TRAIN_INSTANCE_TYPE", "ml.g5.2xlarge")
logger.info(f"config.train_instance_type: {train_instance_type}")
# Comfy endpoint names, both built from the shared endpoint_name.
comfy_async_ep_name = f"comfy-async-{endpoint_name}"
comfy_real_time_ep_name = f"comfy-real-time-{endpoint_name}"

# Fixed stack names.
webui_stack = "webui-stack"
comfy_stack = "comfy-stack"

# Fixed role names, one per endpoint flavour.
role_sd_async = "sd_async"
role_sd_real_time = "sd_real_time"
role_comfy_async = "comfy_async"
role_comfy_real_time = "comfy_real_time"

# COMPARE_CONTENT is kept as the raw string from the environment
# (default "true"); it is not converted to a bool here.
compare_content = os.environ.get("COMPARE_CONTENT", "true")
logger.info(f"config.compare_content: {compare_content}")

# Optional custom image URI; None when CUSTOM_DOCKER_IMAGE_URI is unset.
custom_docker_image_uri = os.environ.get("CUSTOM_DOCKER_IMAGE_URI")
logger.info(f"config.custom_docker_image_uri: {custom_docker_image_uri}")