181 lines
5.6 KiB
Python
181 lines
5.6 KiB
Python
from __future__ import print_function
|
|
|
|
import logging
|
|
import os
|
|
import threading
|
|
import time
|
|
import uuid
|
|
|
|
import pytest
|
|
|
|
import config as config
|
|
from utils.api import Api
|
|
from utils.helper import wget_file, comfy_execute_create, get_endpoint_comfy_async, get_endpoint_sd_async, \
|
|
sd_inference_create
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logger.setLevel(logging.INFO)
|
|
|
|
headers = {
|
|
"x-api-key": config.api_key,
|
|
"username": config.username
|
|
}
|
|
|
|
id = str(uuid.uuid4())
|
|
|
|
|
|
@pytest.mark.skipif(not config.is_local, reason="local test only")
|
|
class TestLatencyCompareTasks:
|
|
def setup_class(self):
|
|
self.api = Api(config)
|
|
self.api.feat_oas_schema()
|
|
self.endpoint_name = get_endpoint_comfy_async(self.api)
|
|
self.endpoint_name_sd = get_endpoint_sd_async(self.api)
|
|
|
|
@classmethod
|
|
def teardown_class(self):
|
|
pass
|
|
|
|
def test_1_download_file(self):
|
|
local_path = f"./data/comfy/models/checkpoints/v1-5-pruned-emaonly.ckpt"
|
|
wget_file(
|
|
local_path,
|
|
'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt'
|
|
)
|
|
|
|
def test_2_sync_files_to_comfy_endpoint(self):
|
|
local = "'./data/comfy/models/*'"
|
|
target = f"'s3://{config.bucket}/comfy/{self.endpoint_name}/{id}/models/'"
|
|
logger.info(f"Syncing {local} to {target}")
|
|
os.system(f"rm -rf ./s5cmd")
|
|
os.system(f"wget -q ./ https://raw.githubusercontent.com/elonniu/s5cmd/main/s5cmd")
|
|
os.system(f"chmod +x ./s5cmd")
|
|
os.system(f"./s5cmd sync {local} {target}")
|
|
|
|
def test_3_comfy_sync_files(self):
|
|
headers = {
|
|
"x-api-key": config.api_key,
|
|
"username": config.username
|
|
}
|
|
data = {"endpoint_name": f"{self.endpoint_name}",
|
|
"need_reboot": True,
|
|
"prepare_id": id,
|
|
"prepare_type": "models"}
|
|
resp = self.api.prepare(data=data, headers=headers)
|
|
assert resp.status_code == 200, resp.dumps()
|
|
logger.info(resp.json())
|
|
logger.info(f"wait 20s for endpoint sync files...")
|
|
time.sleep(20)
|
|
|
|
def test_4_clean_all_executes(self):
|
|
headers = {
|
|
"x-api-key": config.api_key,
|
|
"username": config.username
|
|
}
|
|
|
|
while True:
|
|
resp = self.api.list_executes(headers=headers, params={"limit": 20})
|
|
executes = resp.json()['data']['executes']
|
|
if len(executes) == 0:
|
|
break
|
|
|
|
execute_id_list = []
|
|
i = 0
|
|
for execute in executes:
|
|
i = i + 1
|
|
prompt_id = execute['prompt_id']
|
|
execute_id_list.append(prompt_id)
|
|
logger.info(f"delete execute {i} {prompt_id}")
|
|
|
|
data = {
|
|
"execute_id_list": execute_id_list,
|
|
}
|
|
resp = self.api.delete_executes(headers=headers, data=data)
|
|
if resp.status_code == 400:
|
|
logger.info(resp.json()['message'])
|
|
time.sleep(5)
|
|
continue
|
|
|
|
def test_6_clean_all_inferences(self):
|
|
headers = {
|
|
"x-api-key": config.api_key,
|
|
"username": config.username
|
|
}
|
|
|
|
while True:
|
|
|
|
resp = self.api.list_inferences(headers=headers, params={"limit": 20})
|
|
inferences = resp.json()['data']['inferences']
|
|
if len(inferences) == 0:
|
|
break
|
|
|
|
ids = []
|
|
i = 0
|
|
for inference in inferences:
|
|
i = i + 1
|
|
inference_id = inference['InferenceJobId']
|
|
ids.append(inference_id)
|
|
logger.info(f"delete execute {i} {inference_id}")
|
|
|
|
data = {
|
|
"inference_id_list": ids,
|
|
}
|
|
resp = self.api.delete_inferences(headers=headers, data=data)
|
|
if resp.status_code == 400:
|
|
logger.info(resp.json()['message'])
|
|
time.sleep(5)
|
|
continue
|
|
|
|
def test_7_update_api_roles(self):
|
|
headers = {
|
|
"x-api-key": config.api_key,
|
|
"username": config.username,
|
|
}
|
|
|
|
data = {
|
|
"username": "api",
|
|
"password": "admin",
|
|
"creator": "api",
|
|
"roles": [
|
|
'IT Operator',
|
|
'byoc',
|
|
config.role_sd_real_time,
|
|
config.role_sd_async,
|
|
config.role_comfy_async,
|
|
config.role_comfy_real_time,
|
|
],
|
|
}
|
|
|
|
resp = self.api.create_user(headers=headers, data=data)
|
|
|
|
assert resp.status_code == 201, resp.dumps()
|
|
assert resp.json()["statusCode"] == 201
|
|
|
|
def create_batch_executes(self, n, api, endpoint_name):
|
|
for i in range(n):
|
|
comfy_execute_create(n=i, api=api, endpoint_name=endpoint_name, wait_succeed=True,
|
|
workflow='./data/api_params/latency-comfy.json')
|
|
|
|
def create_batch_inferences(self, n, api, endpoint_name):
|
|
for i in range(n):
|
|
sd_inference_create(n=i, api=api, endpoint_name=endpoint_name, workflow='./data/api_params/latency-sd.json')
|
|
|
|
def test_8_latency_compare_start(self):
|
|
self.test_7_update_api_roles()
|
|
|
|
threads = []
|
|
|
|
batch = 1000
|
|
|
|
thread = threading.Thread(target=self.create_batch_executes, args=(batch, self.api, self.endpoint_name))
|
|
threads.append(thread)
|
|
|
|
thread = threading.Thread(target=self.create_batch_inferences, args=(batch, self.api, self.endpoint_name_sd))
|
|
threads.append(thread)
|
|
|
|
for thread in threads:
|
|
thread.start()
|
|
|
|
for thread in threads:
|
|
thread.join()
|