218 lines
7.6 KiB
Python
218 lines
7.6 KiB
Python
import dataclasses
|
|
import os
|
|
from datetime import datetime
|
|
from unittest import TestCase
|
|
|
|
# Default configuration for the handlers under test. Each entry is applied
# with ``setdefault`` so values already exported in the shell take precedence.
_DEFAULT_ENV = {
    'AWS_PROFILE': 'env',
    'S3_BUCKET': 'your-bucket',
    'DATASET_ITEM_TABLE': 'DatasetItemTable',
    'DATASET_INFO_TABLE': 'DatasetInfoTable',
    'MULTI_USER_TABLE': 'MultiUserTable',
    'TRAIN_TABLE': 'TrainingTable',
    'CHECKPOINT_TABLE': 'CheckpointTable',
    'SAGEMAKER_ENDPOINT_NAME': 'aigc-utils-endpoint',
    'DDB_ENDPOINT_DEPLOYMENT_TABLE_NAME': 'SDEndpointDeploymentJobTable',
    'INFERENCE_JOB_TABLE': 'SDInferenceJobTable',
}
for _env_name, _env_value in _DEFAULT_ENV.items():
    os.environ.setdefault(_env_name, _env_value)
del _env_name, _env_value
|
|
|
|
|
|
@dataclasses.dataclass
class MockContext:
    """Minimal stand-in for the Lambda context passed to the handlers.

    Only ``aws_request_id`` is modelled. The tests below pass a fresh
    timestamp string here and later look the created job up by that same id.
    """

    # Required; there is deliberately no default so every test supplies one.
    aws_request_id: str = dataclasses.field()
|
|
|
|
|
|
class InferenceApiTest(TestCase):
|
|
|
|
def test_get_checkpoint_by_name(self):
|
|
from inferences.inference_api import _get_checkpoint_by_name
|
|
ckpt = _get_checkpoint_by_name('v1-5-pruned-emaonly.safetensors', 'Stable-diffusion')
|
|
assert ckpt is not None
|
|
|
|
def test_prepare_inference(self):
|
|
from inferences.inference_api import prepare_inference
|
|
event = {
|
|
'user_id': 'admin',
|
|
'task_type': 'txt2img',
|
|
'models': {
|
|
'Stable-diffusion': ['v1-5-pruned-emaonly.safetensors'],
|
|
'VAE': ['vae-ft-mse-840000-ema-pruned.ckpt'], 'embeddings': []
|
|
},
|
|
'filters': {'createAt': 1696657891.055418, 'creator': 'sd-webui'}
|
|
}
|
|
|
|
_id = str(datetime.now().timestamp())
|
|
resp = prepare_inference(event, MockContext(aws_request_id=_id))
|
|
print(resp)
|
|
assert resp['statusCode'] == 200
|
|
# get the inference job from ddb by job id
|
|
|
|
from inferences.inference_api import inference_table_name, ddb_service
|
|
from inference_v2.types import InferenceJob
|
|
inference_raw = ddb_service.get_item(inference_table_name, {
|
|
'InferenceJobId': _id
|
|
})
|
|
inference_job = InferenceJob(**inference_raw)
|
|
models = {
|
|
"space_free_size": 4e10,
|
|
**inference_job.params['used_models'],
|
|
}
|
|
print(models)
|
|
|
|
def upload_with_put(url):
|
|
with open('api_param.json', 'rb') as file:
|
|
import requests
|
|
response = requests.put(url, data=file)
|
|
response.raise_for_status()
|
|
|
|
upload_with_put(resp['inference']['api_params_s3_upload_url'])
|
|
from inferences.inference_api import run_inference
|
|
resp = run_inference({
|
|
'pathStringParameters': {
|
|
'inference_id': _id
|
|
}
|
|
}, {})
|
|
print(resp)
|
|
|
|
def test_prepare_inference_img2img(self):
|
|
from inferences.inference_api import prepare_inference
|
|
event = {
|
|
'user_id': 'yuxiaox',
|
|
'task_type': 'txt2img',
|
|
'models': {
|
|
'Stable-diffusion': ['AnythingV5Ink_ink.safetensors'],
|
|
'embeddings': []},
|
|
'filters': {'createAt': 1695784940.13923, 'creator': 'sd-webui'}
|
|
}
|
|
_id = str(datetime.now().timestamp())
|
|
resp = prepare_inference(event, MockContext(aws_request_id=_id))
|
|
print(resp)
|
|
assert resp['statusCode'] == 200
|
|
# get the inference job from ddb by job id
|
|
|
|
from inferences.inference_api import inference_table_name, ddb_service
|
|
from inference_v2.types import InferenceJob
|
|
inference_raw = ddb_service.get_item(inference_table_name, {
|
|
'InferenceJobId': _id
|
|
})
|
|
inference_job = InferenceJob(**inference_raw)
|
|
models = {
|
|
"space_free_size": 4e10,
|
|
**inference_job.params['used_models'],
|
|
}
|
|
print(models)
|
|
|
|
def upload_with_put(url):
|
|
with open(
|
|
'/Users/cyanda/Dev/python-projects/stable-diffusion-webui/extensions/stable-diffusion-aws-extension/playground_NO_COMMIT/api_img2img_param.json',
|
|
'rb') as file:
|
|
import requests
|
|
response = requests.put(url, data=file)
|
|
response.raise_for_status()
|
|
|
|
upload_with_put(resp['inference']['api_params_s3_upload_url'])
|
|
from inferences.inference_api import run_inference
|
|
resp = run_inference({
|
|
'pathStringParameters': {
|
|
'inference_id': _id
|
|
}
|
|
}, {})
|
|
print(resp)
|
|
|
|
def test_run_infer(self):
|
|
from inferences.inference_api import run_inference
|
|
resp = run_inference({
|
|
'pathStringParameters': {
|
|
'inference_id': '2f5a14ba-44c1-438a-b369-ae1102b2dcab'
|
|
}
|
|
}, {})
|
|
print(resp)
|
|
|
|
def test_upload_infer(self):
|
|
def upload_with_put(url):
|
|
with open('api_param.json', 'rb') as file:
|
|
import requests
|
|
response = requests.put(url, data=file)
|
|
response.raise_for_status()
|
|
|
|
s3_presigned_url = 'https://presigned_s3_url'
|
|
upload_with_put(s3_presigned_url)
|
|
|
|
def test_split(self):
|
|
arg = {
|
|
'model': 'control_v11p_sd15_canny [d14c016b]'
|
|
}
|
|
model_parts = arg['model'].split()
|
|
print(' '.join(model_parts[:-1]))
|
|
|
|
def test_list_all_sagemaker_endpoints(self):
|
|
from inference_v2.sagemaker_endpoint_api import list_all_sagemaker_endpoints
|
|
resp = list_all_sagemaker_endpoints({
|
|
'queryStringParameters':
|
|
{
|
|
'username': 'spiderman'
|
|
},
|
|
'x-auth': {'username': 'spiderman', 'role': ''}}, {})
|
|
|
|
print(resp)
|
|
|
|
def test_list_all_inference_jobs(self):
|
|
from inferences.inference_api import list_all_inference_jobs
|
|
resp = list_all_inference_jobs({
|
|
'queryStringParameters': {
|
|
'username': 'mickey'
|
|
}
|
|
}, {})
|
|
|
|
print(resp)
|
|
|
|
def test_generate_extra_single(self):
|
|
self._do_generate_extra('extra-single-image', 'payload_extra_single.json')
|
|
|
|
def test_generate_extra_batch(self):
|
|
self._do_generate_extra('extra-batch-images', 'payload_extra_batch.json')
|
|
|
|
def test_generate_rembg(self):
|
|
self._do_generate_extra('rembg', 'payload_rembg.json')
|
|
|
|
def _do_generate_extra(self, _task_type, payload_url):
|
|
from inferences.inference_api import prepare_inference, run_inference
|
|
|
|
event = {
|
|
'user_id': 'admin',
|
|
'task_type': _task_type,
|
|
'models': {},
|
|
'filters': {'createAt': datetime.now().timestamp(), 'creator': 'sd-webui'}
|
|
}
|
|
resp = prepare_inference(event, MockContext(aws_request_id=f'{datetime.now().timestamp()}'))
|
|
print(resp)
|
|
assert resp['statusCode'] == 200
|
|
|
|
def upload_with_put(url, filename):
|
|
with open(filename, 'rb') as file:
|
|
import requests
|
|
response = requests.put(url, data=file)
|
|
response.raise_for_status()
|
|
|
|
upload_with_put(resp['inference']['api_params_s3_upload_url'], payload_url)
|
|
|
|
resp = run_inference({
|
|
'pathStringParameters': {
|
|
'inference_id': resp['inference']['id']
|
|
}
|
|
}, {})
|
|
print(resp)
|
|
assert resp['statusCode'] == 200
|
|
|
|
print(f"result s3 location: {resp['inference']['output_path']}")
|
|
|
|
def test_delete_endpoint(self):
|
|
from inference_v2.sagemaker_endpoint_api import delete_sagemaker_endpoints
|
|
resp = delete_sagemaker_endpoints({
|
|
"delete_endpoint_list": [
|
|
"infer-endpoint-dc-endpoint"
|
|
],
|
|
"username": "admin",
|
|
}, {})
|
|
print(resp)
|