98 lines
2.9 KiB
Python
98 lines
2.9 KiB
Python
from __future__ import print_function
|
|
|
|
import logging
|
|
import time
|
|
from datetime import datetime
|
|
from datetime import timedelta
|
|
|
|
import pytest
|
|
|
|
import config as config
|
|
from utils.api import Api
|
|
from utils.enums import InferenceStatus, InferenceType
|
|
from utils.helper import upload_with_put, get_inference_job_status, \
|
|
delete_inference_jobs
|
|
|
|
logger = logging.getLogger(__name__)

# Checkpoint name sent under "Stable-diffusion" in the create-inference payload.
filename = "v1-5-pruned-emaonly.safetensors"
# Local JSON file uploaded to the job's `api_params_s3_upload_url`.
api_params_filename = "./data/api_params/xyz_checkpoint_api_param.json"
# Inference record returned by the create call. It is written by the first
# test, read by the second test, and used by the class teardown for cleanup.
inference_data = {}


class TestXyzCheckpointE2E:
    """End-to-end tests for an async txt2img inference job (xyz checkpoint).

    The tests are order-dependent: the first one creates the job and stores
    the returned inference record in the module-level ``inference_data``;
    the second one starts that job and polls until it finishes. Both tests
    are currently skipped ("not ready").
    """

    @classmethod
    def setup_class(cls):
        """Build the API client shared by every test in this class."""
        cls.api = Api(config)
        cls.api.feat_oas_schema()

    @classmethod
    def teardown_class(cls):
        """Delete the inference job created by the tests, if one was created."""
        # Read-only access to the module-level dict, so no `global` is needed.
        # 'id' is only present once the create test has stored a record.
        if 'id' in inference_data:
            delete_inference_jobs([inference_data['id']])

    @pytest.mark.skip(reason="not ready")
    def test_1_xyz_checkpoint_txt2img_job_create(self):
        """Create an async txt2img job and upload its API parameter file."""
        global inference_data

        headers = {
            "x-api-key": config.api_key,
            "username": config.username,
        }

        data = {
            "inference_type": "Async",
            "task_type": InferenceType.TXT2IMG.value,
            "models": {
                "Stable-diffusion": [filename],
                "embeddings": [],
            },
            "filters": {},
        }

        resp = self.api.create_inference(headers=headers, data=data)
        assert resp.status_code == 201, resp.dumps()

        body = resp.json()
        # Store the record before the remaining assertions so that teardown
        # can still delete the job if one of them fails.
        inference_data = body['data']["inference"]

        assert body["statusCode"] == 201
        assert inference_data["type"] == InferenceType.TXT2IMG.value
        assert len(inference_data["api_params_s3_upload_url"]) > 0

        upload_with_put(inference_data["api_params_s3_upload_url"], api_params_filename)

    @pytest.mark.skip(reason="not ready")
    def test_2_xyz_checkpoint_txt2img_job_succeed(self):
        """Start the previously created job and wait (up to 2 minutes) for success.

        Raises:
            Exception: if the job reports a failed status, or if it has not
                succeeded before the deadline.
        """
        assert inference_data["type"] == InferenceType.TXT2IMG.value

        inference_id = inference_data["id"]

        headers = {
            "x-api-key": config.api_key,
            "username": config.username,
        }

        resp = self.api.start_inference_job(job_id=inference_id, headers=headers)
        assert resp.status_code == 202, resp.dumps()
        assert resp.json()['data']["inference"]["status"] == InferenceStatus.INPROGRESS.value

        # Monotonic clock: the 2-minute budget must not be affected by
        # wall-clock adjustments (NTP sync, DST) while the test is polling.
        deadline = time.monotonic() + 2 * 60

        while time.monotonic() < deadline:
            status = get_inference_job_status(
                api_instance=self.api,
                job_id=inference_id,
            )
            logger.info("xyz inference is %s", status)
            if status == InferenceStatus.SUCCEED.value:
                break
            if status == InferenceStatus.FAILED.value:
                raise Exception("Inference job failed.")
            time.sleep(5)
        else:
            # Reached only when the loop ends without `break`, i.e. the
            # deadline passed before the job succeeded.
            raise Exception("Inference timed out after 2 minutes.")