stable-diffusion-aws-extension/test/utils/helper.py

541 lines
17 KiB
Python
Raw Blame History

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

import base64
import decimal
import io
import json
import logging
import math
import os
import subprocess
import sys
import tarfile
import time
import uuid
from datetime import datetime, timedelta
import boto3
import requests
from botocore.exceptions import NoCredentialsError, PartialCredentialsError
import config as config
from utils.api import Api
from utils.enums import InferenceStatus, InferenceType
logger = logging.getLogger(__name__)
s3 = boto3.client('s3')
def get_parts_number(local_path: str, part_size: int = 1000 * 1024 * 1024) -> int:
    """Return how many upload parts are needed to cover a local file.

    Args:
        local_path: Path of the file that is going to be uploaded.
        part_size: Size of one part in bytes. The default (1000 MiB) is the
            chunk size read per part by ``upload_multipart_file``.

    Returns:
        The number of parts; ``0`` for an empty file.

    Raises:
        ValueError: If ``part_size`` is not positive.
        OSError: If the file cannot be inspected.
    """
    if part_size <= 0:
        raise ValueError(f"part_size must be positive, got {part_size}")
    file_size = os.stat(local_path).st_size
    # Integer ceiling division: exact for any file size, no float rounding.
    return -(-file_size // part_size)
def wget_file(local_file: str, url: str, gcr_url: str = None):
    """Download ``url`` with ``wget`` unless ``local_file`` already exists.

    wget is given only the target directory (``-P``), taken from the directory
    part of ``local_file``; the file name itself is chosen by wget.

    Args:
        local_file: Path checked for existence; its directory receives the
            download.
        url: URL to download.
        gcr_url: Alternative URL. Currently ignored because the override
            below is commented out.

    Raises:
        subprocess.CalledProcessError: If wget exits with a non-zero status.
            The captured stdout/stderr are attached to the exception.
    """
    # if gcr_url is not None and config.is_gcr:
    #     url = gcr_url
    if os.path.exists(local_file):
        return

    # dirname() is '' for a bare file name; use the current directory then.
    local_path = os.path.dirname(local_file) or '.'
    command = ['wget', '-qP', local_path, url]

    logger.info(f"Downloading {url}")
    wget_process = subprocess.run(command, capture_output=True)
    # Check the exit status before reporting success.
    if wget_process.returncode != 0:
        raise subprocess.CalledProcessError(
            wget_process.returncode,
            command,
            output=wget_process.stdout,
            stderr=wget_process.stderr,
        )
    logger.info(f"Downloaded {url}")
def create_tar(json_string: str, path: str):
    """Pack a JSON string into an in-memory tar archive with a single member.

    Args:
        json_string: Text stored, UTF-8 encoded, as the member's content.
        path: Name of the member inside the archive.

    Returns:
        The bytes of the uncompressed tar archive.
    """
    payload = json_string.encode("utf-8")
    member = tarfile.TarInfo(name=path)
    member.size = len(payload)

    with io.BytesIO() as archive:
        with tarfile.open(fileobj=archive, mode="w") as tar:
            tar.addfile(member, io.BytesIO(payload))
        # Read the buffer only after the tarfile has been closed, but while
        # the BytesIO itself is still open.
        return archive.getvalue()
def list_endpoints(api_instance):
    """Fetch the endpoint list through the API client.

    Args:
        api_instance: Client exposing ``list_endpoints(headers=...)``.

    Returns:
        The value stored under ``data -> endpoints`` in the JSON response.

    Raises:
        AssertionError: If the response status code is not 200.
    """
    auth_headers = {"x-api-key": config.api_key, "username": config.username}
    response = api_instance.list_endpoints(headers=auth_headers)
    assert response.status_code == 200, response.dumps()
    return response.json()['data']["endpoints"]
def get_endpoint_status(api_instance, endpoint_name: str):
    """Look up the status of one endpoint by its exact name.

    Args:
        api_instance: Client passed on to ``list_endpoints``.
        endpoint_name: Name to match against each entry's ``endpoint_name``.

    Returns:
        The ``endpoint_status`` of the first matching entry, or ``None`` when
        no entry has that name.
    """
    matching_statuses = (
        entry['endpoint_status']
        for entry in list_endpoints(api_instance)
        if entry['endpoint_name'] == endpoint_name
    )
    return next(matching_statuses, None)
def get_inference_job_status(api_instance, job_id):
    """Return the current status string of an inference job.

    A job whose status equals ``InferenceStatus.FAILED`` is additionally
    logged at error level together with its ``data`` payload.

    Args:
        api_instance: Client exposing ``get_inference_job(job_id, headers)``.
        job_id: Identifier of the inference job.

    Returns:
        The value stored under ``data -> status`` in the JSON response.
    """
    auth_headers = {"x-api-key": config.api_key, "username": config.username}
    response = api_instance.get_inference_job(job_id=job_id, headers=auth_headers)
    job = response.json()['data']
    status = job['status']
    if status == InferenceStatus.FAILED.value:
        logger.error(f"Failed inference: {job}")
    return status
def get_inference_image(api_instance, job_id: str, target_file: str):
    """Download the images of an inference job into ``target_file``.

    Every presigned URL is written to the same path, so after the call the
    file holds the content fetched from the last URL.

    Args:
        api_instance: Client exposing ``get_inference_job(job_id, headers)``.
        job_id: Identifier of the inference job.
        target_file: Path the image bytes are written to.

    Raises:
        Exception: If the response has no ``data`` entry, or ``data`` has no
            ``img_presigned_urls`` entry.
    """
    auth_headers = {"x-api-key": config.api_key, "username": config.username}
    job_response = api_instance.get_inference_job(job_id=job_id, headers=auth_headers)
    payload = job_response.json()

    if 'data' not in payload:
        raise Exception(f"data not found in inference job: {payload}")
    if 'img_presigned_urls' not in payload['data']:
        raise Exception(f"img_presigned_urls not found in inference job: {payload}")

    # Content comparison is switched off: nothing needs to be downloaded.
    if config.compare_content == 'false':
        logger.info(f"compare_content is false, skip comparing image {target_file}")
        return

    for img_url in payload['data']['img_presigned_urls']:
        image_response = requests.get(img_url)
        with open(f"{target_file}", "wb") as out:
            out.write(image_response.content)
        logger.info(f"Image {target_file} saved")
def get_inference_job_image(api_instance, job_id: str, target_file: str):
    """Compare an inference job's image with a stored reference file.

    Only the first presigned URL is examined, because every path through the
    loop body returns or raises.

    * If ``target_file`` does not exist yet, the image is saved there and an
      exception is raised to signal that the reference was just created.
    * If the image equals the reference, the function returns silently.
    * Otherwise the image is saved as ``<target_file>.png`` and the mismatch
      is logged. Raising on a mismatch is disabled in the source.

    Args:
        api_instance: Client exposing ``get_inference_job(job_id, headers)``.
        job_id: Identifier of the inference job.
        target_file: Path of the reference image.

    Raises:
        Exception: If the response lacks ``data`` or ``img_presigned_urls``,
            if the URL list is empty, or if the reference file had to be
            created.
    """
    auth_headers = {"x-api-key": config.api_key, "username": config.username}
    job_response = api_instance.get_inference_job(job_id=job_id, headers=auth_headers)
    payload = job_response.json()

    if 'data' not in payload:
        raise Exception(f"data not found in inference job: {payload}")
    if 'img_presigned_urls' not in payload['data']:
        raise Exception(f"img_presigned_urls not found in inference job: {payload}")

    if config.compare_content == 'false':
        logger.info(f"compare_content is false, skip comparing image {target_file}")
        return

    for img_url in payload['data']['img_presigned_urls']:
        image_bytes = requests.get(img_url).content

        if not os.path.exists(target_file):
            with open(f"{target_file}", "wb") as f:
                f.write(image_bytes)
            raise Exception(f"Image {target_file} first generated")

        # Read the reference inside a context manager so the handle is closed.
        with open(target_file, "rb") as reference:
            reference_bytes = reference.read()
        if image_bytes == reference_bytes:
            return

        # Keep the differing image next to the reference for inspection.
        with open(f"{target_file}.png", "wb") as f:
            f.write(image_bytes)
        logger.info(f"Image {target_file} not same with {target_file}.png")
        return

    # Reached only when the job lists no image URLs at all.
    raise Exception(f"Image not found in inference job: {payload}")
def delete_sagemaker_endpoint(api_instance, endpoint_name: str):
    """Request deletion of a single endpoint through the API client.

    Args:
        api_instance: Client exposing ``delete_endpoints(headers, data)``.
        endpoint_name: Name sent as the only item of ``endpoint_name_list``.

    Raises:
        AssertionError: If the response status code is not 204.
    """
    auth_headers = {"x-api-key": config.api_key, "username": config.username}
    payload = {
        "endpoint_name_list": [endpoint_name],
        "username": config.username,
    }
    response = api_instance.delete_endpoints(headers=auth_headers, data=payload)
    assert response.status_code == 204, response.dumps()
def delete_inference_jobs(inference_id_list: [str]):
    """Request deletion of the given inference jobs.

    Builds its own ``Api`` client from ``config``. The response is not
    inspected.

    Args:
        inference_id_list: Identifiers sent under ``inference_id_list``.
    """
    client = Api(config)
    payload = {"inference_id_list": inference_id_list}
    client.delete_inferences(data=payload, headers={"x-api-key": config.api_key})
def upload_with_put(s3_url, local_file):
    """Upload a local file to ``s3_url`` with a single HTTP PUT.

    Args:
        s3_url: URL that accepts the PUT request.
        local_file: Path of the file whose bytes form the request body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    with open(local_file, 'rb') as stream:
        put_response = requests.put(s3_url, data=stream)
    put_response.raise_for_status()
def upload_multipart_file(signed_urls, local_path):
    """Upload a file in consecutive chunks, one PUT per signed URL.

    Args:
        signed_urls: Iterable of URLs; the i-th URL receives the i-th chunk
            (up to 1000 MiB each) read from the file.
        local_path: Path of the file to upload.

    Returns:
        A list of ``{'ETag': ..., 'PartNumber': ...}`` dicts, part numbers
        starting at 1, in upload order.

    Raises:
        requests.HTTPError: If any PUT answers with an error status.
    """
    chunk_bytes = 1000 * 1024 * 1024
    logger.info(f"Uploading {local_path}")
    uploaded = []
    with open(local_path, "rb") as source:
        for part_number, url in enumerate(signed_urls, start=1):
            chunk = source.read(chunk_bytes)
            response = requests.put(url, data=chunk)
            response.raise_for_status()
            uploaded.append({
                'ETag': response.headers['ETag'],
                'PartNumber': part_number,
            })
            print(f'model upload part {part_number}: {response}')
    return uploaded
# s_tmax: Infinity
def parse_constant(c: str) -> float:
    """Translate a non-finite JSON constant into a finite float.

    The signature fits the ``parse_constant`` hook of ``json.loads``, which
    is invoked with ``'NaN'``, ``'Infinity'`` or ``'-Infinity'``.

    Args:
        c: The constant as it appears in the document.

    Returns:
        ``sys.float_info.max`` for ``'Infinity'``, its negation for
        ``'-Infinity'``, otherwise ``float(c)``.

    Raises:
        ValueError: If ``c`` is ``'NaN'`` or cannot be converted by ``float``.
    """
    if c == "NaN":
        raise ValueError("NaN is not valid JSON")
    if c == 'Infinity':
        return sys.float_info.max
    # Clamp the negative constant too, so it cannot come back as -inf.
    if c == '-Infinity':
        return -sys.float_info.max
    return float(c)
class DecimalEncoder(json.JSONEncoder):
    """JSON encoder that writes ``decimal.Decimal`` values as strings."""

    def default(self, obj):
        """Return ``str(obj)`` for a Decimal; defer everything else.

        Objects of any other type are handed to ``json.JSONEncoder.default``,
        which raises ``TypeError`` for unsupported types.
        """
        if not isinstance(obj, decimal.Decimal):
            return super().default(obj)
        return str(obj)
def comfy_execute_create(n, api, endpoint_name, wait_succeed=True,
                         workflow: str = './data/api_params/comfy_workflow.json'):
    """Submit a workflow for execution and optionally wait until it finishes.

    The workflow JSON file is loaded, tagged with a fresh ``prompt_id``, the
    fixed workflow name ``latency_compare_comfy`` and ``endpoint_name``, and
    sent through ``api.create_execute``. While waiting, the job is polled
    once per second for at most five minutes.

    Args:
        n: Label that only appears in log and error messages.
        api: Client exposing ``create_execute``, ``get_execute_job`` and
            ``get_execute_job_logs``.
        endpoint_name: Stored in the payload under ``endpoint_name``.
        wait_succeed: When false, return as soon as the job is accepted.
        workflow: Path of the workflow JSON file.

    Raises:
        AssertionError: If a response has an unexpected status code or
            payload.
        Exception: If the job reports the failed status or does not reach
            ``success`` within five minutes.
    """
    # Load the template and close the file right away; it is not needed
    # during the polling below.
    with open(workflow, 'r') as f:
        payload = json.load(f)

    headers = {
        "x-api-key": config.api_key,
    }
    prompt_id = str(uuid.uuid4())
    payload['prompt_id'] = prompt_id
    payload['workflow'] = 'latency_compare_comfy'
    payload['endpoint_name'] = endpoint_name

    resp = api.create_execute(headers=headers, data=payload)
    assert resp.status_code in [200, 201], resp.dumps()
    assert resp.json()['data']['prompt_id'] == prompt_id, resp.dumps()

    if not wait_succeed:
        return

    timeout = datetime.now() + timedelta(minutes=5)
    # Last status that was logged; used to log only on status changes.
    init_status = ''
    while datetime.now() < timeout:
        time.sleep(1)
        resp = api.get_execute_job(headers=headers, prompt_id=prompt_id)

        # A 404 is tolerated here and the job is simply polled again.
        if resp.status_code == 404:
            init_status = "not found"
            logger.info(f"comfy {n} {endpoint_name} {prompt_id} is {init_status}")
            continue

        assert resp.status_code == 200, resp.dumps()
        assert 'status' in resp.json()['data'], resp.dumps()
        status = resp.json()["data"]["status"]

        if init_status != status:
            logger.info(f"comfy {n} {endpoint_name} {prompt_id} is {status}")
            init_status = status

        if status == 'success':
            resp = api.get_execute_job_logs(headers=headers, prompt_id=prompt_id)
            assert resp.status_code == 200, resp.dumps()
            break

        if status == InferenceStatus.FAILED.value:
            logger.error(resp.json())
            raise Exception(f"{n} {endpoint_name} {prompt_id} failed.")
    else:
        # The loop ended without ``break``: the deadline passed.
        raise Exception(f"{n} {endpoint_name} {prompt_id} timed out after 5 minutes.")
def sd_inference_create(n, api, endpoint_name: str, workflow: str = './data/api_params/sd.json'):
    """Create, start and wait for an asynchronous txt2img inference job.

    The job is created, the parameter file is uploaded to the returned
    presigned URL, the job is started, and its status is polled every four
    seconds for at most two minutes. A failed job is logged and ends the
    wait without raising.

    Args:
        n: Label that only appears in the progress log line.
        api: Client exposing ``create_inference``, ``get_inference_job`` and
            ``start_inference_job``.
        endpoint_name: Only used in the progress log line.
        workflow: Path of the API-parameter file to upload.

    Raises:
        AssertionError: If a response has an unexpected status code or
            payload.
        Exception: If the job neither succeeds nor fails within two minutes.
    """
    auth_headers = {
        "x-api-key": config.api_key,
        "username": config.username
    }
    create_payload = {
        "inference_type": "Async",
        "task_type": InferenceType.TXT2IMG.value,
        "workflow": 'latency_compare_sd',
        "models": {
            "Stable-diffusion": [config.default_model_id],
            "embeddings": []
        },
    }

    created = api.create_inference(headers=auth_headers, data=create_payload)
    assert created.status_code == 201, created.dumps()

    inference_data = created.json()['data']["inference"]
    inference_id = inference_data["id"]
    assert created.json()["statusCode"] == 201
    assert inference_data["type"] == InferenceType.TXT2IMG.value
    assert len(inference_data["api_params_s3_upload_url"]) > 0

    upload_with_put(inference_data["api_params_s3_upload_url"], workflow)

    fetched = api.get_inference_job(headers=auth_headers, job_id=inference_data["id"])
    assert fetched.status_code == 200, fetched.dumps()

    started = api.start_inference_job(job_id=inference_id, headers=auth_headers)
    assert started.status_code == 202, started.dumps()
    assert started.json()['data']["inference"]["status"] == InferenceStatus.INPROGRESS.value

    deadline = datetime.now() + timedelta(minutes=2)
    while datetime.now() < deadline:
        status = get_inference_job_status(api_instance=api, job_id=inference_id)
        logger.info(f"sd {n} {endpoint_name} {inference_id} is {status}")
        if status == InferenceStatus.SUCCEED.value:
            break
        if status == InferenceStatus.FAILED.value:
            logger.error(inference_data)
            break
        time.sleep(4)
    else:
        raise Exception(f"Inference {inference_id} timed out after 2 minutes.")
def base64_image(image_url: str):
    """Fetch ``image_url`` and return the response body base64-encoded.

    Args:
        image_url: URL requested with HTTP GET.

    Returns:
        The base64 text (``str``) of the response body.
    """
    raw_bytes = requests.get(image_url).content
    return base64.b64encode(raw_bytes).decode('utf-8')
def sd_inference_esi(api, workflow: str = './data/api_params/extra-single-image-api-params.json', image: str = None):
    """Run an asynchronous extra-single-image inference job and wait for it.

    The job is created, the API parameters are uploaded to the returned
    presigned URL, the job is started, and its status is polled every four
    seconds for at most two minutes. A failed job is logged and ends the
    wait without raising.

    Args:
        api: Client exposing ``create_inference``, ``get_inference_job`` and
            ``start_inference_job``.
        workflow: Path of the JSON file holding the API parameters.
        image: Optional value stored under the ``image`` key of the
            parameters before they are uploaded.

    Returns:
        The id of the created inference job.

    Raises:
        AssertionError: If a response has an unexpected status code or
            payload.
        requests.HTTPError: If the parameter upload is rejected.
        Exception: If the job neither succeeds nor fails within two minutes.
    """
    headers = {
        "x-api-key": config.api_key,
        "username": config.username
    }
    create_payload = {
        "inference_type": "Async",
        "task_type": InferenceType.ESI.value,
        "workflow": 'esi',
        "models": {
            "Stable-diffusion": [config.default_model_id],
            "embeddings": []
        },
    }

    resp = api.create_inference(headers=headers, data=create_payload)
    assert resp.status_code == 201, resp.dumps()

    inference_data = resp.json()['data']["inference"]
    inference_id = inference_data["id"]
    assert resp.json()["statusCode"] == 201
    assert len(inference_data["api_params_s3_upload_url"]) > 0

    # Always parse the template, so the upload is valid JSON whether or not
    # an image override was supplied.
    with open(workflow, 'rb') as f:
        api_params = json.loads(f.read())
    if image:
        api_params['image'] = image
    response = requests.put(inference_data["api_params_s3_upload_url"], data=json.dumps(api_params))
    response.raise_for_status()

    resp = api.get_inference_job(headers=headers, job_id=inference_data["id"])
    assert resp.status_code == 200, resp.dumps()

    resp = api.start_inference_job(job_id=inference_id, headers=headers)
    assert resp.status_code == 202, resp.dumps()
    assert resp.json()['data']["inference"]["status"] == InferenceStatus.INPROGRESS.value

    timeout = datetime.now() + timedelta(minutes=2)
    while datetime.now() < timeout:
        status = get_inference_job_status(
            api_instance=api,
            job_id=inference_id
        )
        logger.info(f"sd {inference_id} is {status}")
        if status == InferenceStatus.SUCCEED.value:
            break
        if status == InferenceStatus.FAILED.value:
            logger.error(inference_data)
            break
        time.sleep(4)
    else:
        raise Exception(f"Inference {inference_id} timed out after 2 minutes.")

    return inference_id
def sd_inference_rembg(api, workflow: str = './data/api_params/rembg-api-params.json', image: str = None):
    """Run an asynchronous rembg inference job and wait for it.

    The job is created, the API parameters are uploaded to the returned
    presigned URL, the job is started, and its status is polled every four
    seconds for at most two minutes. A failed job is logged and ends the
    wait without raising.

    Args:
        api: Client exposing ``create_inference``, ``get_inference_job`` and
            ``start_inference_job``.
        workflow: Path of the JSON file holding the API parameters.
        image: Optional value stored under the ``input_image`` key of the
            parameters before they are uploaded.

    Returns:
        The id of the created inference job.

    Raises:
        AssertionError: If a response has an unexpected status code or
            payload.
        requests.HTTPError: If the parameter upload is rejected.
        Exception: If the job neither succeeds nor fails within two minutes.
    """
    headers = {
        "x-api-key": config.api_key,
        "username": config.username
    }
    create_payload = {
        "inference_type": "Async",
        "task_type": InferenceType.REMBG.value,
        "workflow": 'rembg',
        "models": {
            "Stable-diffusion": [config.default_model_id],
            "embeddings": []
        },
    }

    resp = api.create_inference(headers=headers, data=create_payload)
    assert resp.status_code == 201, resp.dumps()

    inference_data = resp.json()['data']["inference"]
    inference_id = inference_data["id"]
    assert resp.json()["statusCode"] == 201
    assert len(inference_data["api_params_s3_upload_url"]) > 0

    # Always parse the template, so the upload is valid JSON whether or not
    # an image override was supplied.
    with open(workflow, 'rb') as f:
        api_params = json.loads(f.read())
    if image:
        api_params['input_image'] = image
    response = requests.put(inference_data["api_params_s3_upload_url"], data=json.dumps(api_params))
    response.raise_for_status()

    resp = api.get_inference_job(headers=headers, job_id=inference_data["id"])
    assert resp.status_code == 200, resp.dumps()

    resp = api.start_inference_job(job_id=inference_id, headers=headers)
    assert resp.status_code == 202, resp.dumps()
    assert resp.json()['data']["inference"]["status"] == InferenceStatus.INPROGRESS.value

    timeout = datetime.now() + timedelta(minutes=2)
    while datetime.now() < timeout:
        status = get_inference_job_status(
            api_instance=api,
            job_id=inference_id
        )
        logger.info(f"sd {inference_id} is {status}")
        if status == InferenceStatus.SUCCEED.value:
            break
        if status == InferenceStatus.FAILED.value:
            logger.error(inference_data)
            break
        time.sleep(4)
    else:
        raise Exception(f"Inference {inference_id} timed out after 2 minutes.")

    return inference_id
def get_endpoint_comfy_async(api):
    """Return the first listed endpoint name starting with ``comfy-async-``."""
    return get_endpoint_by_prefix(api, prefix="comfy-async-")
def get_endpoint_comfy_real_time(api):
    """Return the first listed endpoint name starting with ``comfy-real-time-``."""
    return get_endpoint_by_prefix(api, prefix="comfy-real-time-")
def get_endpoint_sd_async(api):
    """Return the first listed endpoint name starting with ``sd-async-``."""
    return get_endpoint_by_prefix(api, prefix="sd-async-")
def get_endpoint_sd_real_time(api):
    """Return the first listed endpoint name starting with ``sd-real-time-``."""
    return get_endpoint_by_prefix(api, prefix="sd-real-time-")
def get_endpoint_by_prefix(api, prefix: str):
    """Return the name of the first listed endpoint that starts with ``prefix``.

    Args:
        api: Client passed on to ``list_endpoints``.
        prefix: Leading text the endpoint name has to start with.

    Raises:
        Exception: If no listed endpoint name starts with ``prefix``.
    """
    candidates = (entry['endpoint_name'] for entry in list_endpoints(api))
    match = next((name for name in candidates if name.startswith(prefix)), None)
    if match is None:
        raise Exception(f"{prefix}* endpoint not found")
    return match
def endpoints_wait_for_in_service(api, endpoint_name: str = None):
    """Report whether an endpoint is currently in service.

    Inspects the first listed endpoint whose name equals ``endpoint_name``,
    or simply the first listed endpoint when ``endpoint_name`` is ``None``.
    Despite the name, the function performs a single check and does not wait.

    Args:
        api: Client exposing ``list_endpoints(headers, params)``.
        endpoint_name: Optional exact endpoint name to look for.

    Returns:
        ``True`` if the inspected endpoint's status is ``InService``;
        ``False`` for any other status or when no endpoint matched.

    Raises:
        AssertionError: If the list request does not return status 200.
        Exception: If the inspected endpoint's status is ``Failed``.
    """
    auth_headers = {"x-api-key": config.api_key, "username": config.username}
    query = {"username": config.username}

    response = api.list_endpoints(headers=auth_headers, params=query)
    assert response.status_code == 200, response.dumps()

    for entry in response.json()['data']["endpoints"]:
        if endpoint_name is not None and entry["endpoint_name"] != endpoint_name:
            continue

        current_name = entry["endpoint_name"]
        status = entry["endpoint_status"]

        if status == "Failed":
            raise Exception(f"{current_name} is {status}")
        if status == "InService":
            return True

        logger.info(f"{current_name} is {status}")
        return False

    return False
def check_s3_directory(directory):
    """Fail if a top-level prefix ending with ``directory`` is still in the bucket.

    Sleeps ten seconds, then lists the top-level prefixes (``Delimiter='/'``)
    of ``config.bucket`` and looks for one whose name ends with ``directory``.

    Args:
        directory: Suffix compared against each common prefix.

    Returns:
        ``False`` when no matching prefix exists, and also when listing the
        bucket failed (the error is printed instead of raised).

    Raises:
        Exception: If a matching prefix exists.
    """
    try:
        time.sleep(10)
        paginator = s3.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=config.bucket, Delimiter='/')
        leftover = next(
            (
                prefix['Prefix']
                for page in pages
                for prefix in page.get('CommonPrefixes', [])
                if prefix['Prefix'].endswith(directory)
            ),
            None,
        )
    except (NoCredentialsError, PartialCredentialsError):
        print("Credentials not available.")
        return False
    except Exception as e:
        print(f"An error occurred: {e}")
        return False

    # Raise outside the try block so the broad handler above cannot catch it.
    if leftover is not None:
        raise Exception(f"cache *-{directory} still exists in {leftover}")
    return False