540 lines
18 KiB
Python
540 lines
18 KiB
Python
from __future__ import print_function
|
|
|
|
import logging
|
|
|
|
import config as config
|
|
from utils.api import Api
|
|
from utils.helper import upload_multipart_file, wget_file
|
|
|
|
# Logger named after this module.
logger = logging.getLogger(name=__name__)

# State handed between the order-dependent tests in this module: each
# create test stores the new checkpoint id and the presigned upload URLs
# ("s3PresignUrl" entry) for its file here, and the following update/list
# tests read them back.
checkpoint_id, signed_urls = None, None
|
|
|
|
|
|
class TestCheckPointE2E:
    """End-to-end tests for the checkpoint create / upload / list API.

    The tests are order-dependent: pytest runs them in definition order (not
    by their numeric prefixes, which repeat), and they pass state to each
    other through the module-level ``checkpoint_id`` and ``signed_urls``
    globals. Each model goes through three steps:

    1. *create* registers the checkpoint and stores its id and presigned
       multipart-upload URLs;
    2. *update* downloads the model file, uploads it through those URLs and
       marks the checkpoint ``Active``;
    3. *list* checks that the checkpoint id is returned for the user.
    """

    # Source and target path sent with every checkpoint creation request.
    _TEST_PATH = "/test/test_02_api_base"

    @classmethod
    def setup_class(cls):
        """Create the shared ``Api`` client and call ``feat_oas_schema()`` on it."""
        cls.api = Api(config)
        cls.api.feat_oas_schema()

    @classmethod
    def teardown_class(cls):
        """No class-level cleanup.

        Leftover checkpoints are deleted by ``test_0_clean_all_checkpoints``
        at the start of the next run.
        """

    # -- helpers -----------------------------------------------------------

    @staticmethod
    def _headers(with_username=True):
        """Return request headers with the API key and, optionally, the username."""
        headers = {"x-api-key": config.api_key}
        if with_username:
            headers["username"] = config.username
        return headers

    def _create_checkpoint(self, checkpoint_type, filename, parts_number):
        """Register a checkpoint for ``filename`` and remember it for later tests.

        Asserts the 201 response shape. Then stores the new checkpoint id in
        the module-level ``checkpoint_id`` and the presigned upload URLs for
        ``filename`` in ``signed_urls``.
        """
        global checkpoint_id, signed_urls

        data = {
            "checkpoint_type": checkpoint_type,
            "filenames": [
                {"filename": filename, "parts_number": parts_number},
            ],
            "params": {
                "message": config.ckpt_message,
                "creator": config.username,
            },
            "source_path": self._TEST_PATH,
            "target_path": self._TEST_PATH,
        }
        resp = self.api.create_checkpoint(headers=self._headers(), data=data)

        assert resp.status_code == 201, resp.dumps()
        body = resp.json()
        assert body["statusCode"] == 201
        checkpoint = body["data"]["checkpoint"]
        assert checkpoint["type"] == checkpoint_type
        # Only the id length is pinned (36 chars, presumably a UUID string).
        assert len(checkpoint["id"]) == 36

        checkpoint_id = checkpoint["id"]
        signed_urls = body["data"]["s3PresignUrl"][filename]

    def _upload_and_activate(self, checkpoint_type, filename, urls, with_username=True):
        """Download ``filename``, upload it to the current checkpoint and activate it.

        ``urls`` are forwarded positionally to ``wget_file`` after the local
        path (the extra URLs look like fallback mirrors). The upload uses the
        ``signed_urls``, and the update targets the ``checkpoint_id``. Both
        were stored by the preceding create test.
        """
        # The local directory name matches the checkpoint type
        # (Stable-diffusion / Lora / ControlNet).
        local_path = f"/tmp/test/models/{checkpoint_type}/{filename}"
        wget_file(local_path, *urls)
        multiparts_tags = upload_multipart_file(signed_urls, local_path)

        data = {
            "status": "Active",
            "multi_parts_tags": {filename: multiparts_tags},
        }
        resp = self.api.update_checkpoint(
            checkpoint_id=checkpoint_id,
            headers=self._headers(with_username=with_username),
            data=data,
        )

        assert resp.status_code == 200, resp.dumps()
        body = resp.json()
        assert body["statusCode"] == 200
        assert body["data"]["checkpoint"]["type"] == checkpoint_type

    def _assert_checkpoint_listed(self):
        """Assert that the current ``checkpoint_id`` is listed for the configured user."""
        resp = self.api.list_checkpoints(
            headers=self._headers(),
            params={"username": config.username},
        )
        assert resp.status_code == 200, resp.dumps()
        listed_ids = [item["id"] for item in resp.json()["data"]["checkpoints"]]
        assert checkpoint_id in listed_ids

    # -- tests (definition order matters) ----------------------------------

    def test_0_clean_all_checkpoints(self):
        """Delete every checkpoint returned by the list endpoint so later tests start clean.

        NOTE(review): only the checkpoints returned by a single list call are
        deleted; confirm the endpoint is not paginated.
        """
        headers = self._headers()

        resp = self.api.list_checkpoints(headers=headers)
        # Fail with the response dump rather than a KeyError on an error body.
        assert resp.status_code == 200, resp.dumps()
        id_list = [checkpoint["id"] for checkpoint in resp.json()["data"]["checkpoints"]]

        if id_list:
            resp = self.api.delete_checkpoints(
                headers=headers,
                data={"checkpoint_id_list": id_list},
            )
            assert resp.status_code == 204, resp.dumps()

    def test_1_create_checkpoint_v15(self):
        self._create_checkpoint(
            "Stable-diffusion",
            "v1-5-pruned-emaonly.safetensors",
            parts_number=5,
        )

    def test_2_update_checkpoint_v15_with_bad_params(self):
        """An update with an empty ``name`` and no username header must be rejected."""
        data = {
            "status": "Active",
            "name": "",
        }
        resp = self.api.update_checkpoint(
            checkpoint_id=checkpoint_id,
            headers=self._headers(with_username=False),
            data=data,
        )
        # Without this check, the outcome would be ignored and the test could
        # never fail. NOTE(review): the exact client-error code (400 vs
        # 401/403) is not pinned here; tighten once the API contract is confirmed.
        assert 400 <= resp.status_code < 500, resp.dumps()

    def test_3_update_checkpoint_v15(self):
        # This update is exercised with the API key alone (no username header).
        self._upload_and_activate(
            "Stable-diffusion",
            "v1-5-pruned-emaonly.safetensors",
            (
                'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors',
                'https://aws-gcr-solutions.s3.cn-north-1.amazonaws.com.cn/stable-diffusion-aws-extension-github-mainline/models/v1-5-pruned-emaonly.safetensors',
            ),
            with_username=False,
        )

    def test_4_list_checkpoints_v15_check(self):
        self._assert_checkpoint_listed()

    def test_8_create_checkpoint_lora_nendoroid(self):
        self._create_checkpoint("Lora", "nendoroid_xl_v7.safetensors", parts_number=1)

    def test_9_update_checkpoint_lora_nendoroid(self):
        self._upload_and_activate(
            "Lora",
            "nendoroid_xl_v7.safetensors",
            (
                'https://aws-gcr-solutions.s3.amazonaws.com/stable-diffusion-aws-extension-github-mainline/models/nendoroid_xl_v7.safetensors',
                'https://aws-gcr-solutions.s3.cn-north-1.amazonaws.com.cn/stable-diffusion-aws-extension-github-mainline/models/nendoroid_xl_v7.safetensors',
            ),
        )

    def test_10_list_checkpoint_lora_nendoroid_check(self):
        self._assert_checkpoint_listed()

    def test_8_create_checkpoint_lora_lcm_1_5(self):
        self._create_checkpoint("Lora", "lcm_lora_1_5.safetensors", parts_number=1)

    def test_9_update_checkpoint_lora_lcm_1_5(self):
        self._upload_and_activate(
            "Lora",
            "lcm_lora_1_5.safetensors",
            (
                'https://aws-gcr-solutions-us-east-1.s3.us-east-1.amazonaws.com/extension-for-stable-diffusion-on-aws/models/Lora/lcm_lora_1_5.safetensors',
            ),
        )

    def test_10_list_checkpoint_lora_lcm_1_5_check(self):
        self._assert_checkpoint_listed()

    def test_8_create_checkpoint_lora_lcm_xl(self):
        self._create_checkpoint("Lora", "lcm_lora_xl.safetensors", parts_number=1)

    def test_9_update_checkpoint_lora_lcm_xl(self):
        self._upload_and_activate(
            "Lora",
            "lcm_lora_xl.safetensors",
            (
                'https://aws-gcr-solutions-us-east-1.s3.us-east-1.amazonaws.com/extension-for-stable-diffusion-on-aws/models/Lora/lcm_lora_xl.safetensors',
            ),
        )

    def test_10_list_checkpoint_lora_lcm_xl_check(self):
        self._assert_checkpoint_listed()

    def test_11_create_checkpoint_canny(self):
        self._create_checkpoint("ControlNet", "control_v11p_sd15_canny.pth", parts_number=2)

    def test_12_update_checkpoint_canny(self):
        self._upload_and_activate(
            "ControlNet",
            "control_v11p_sd15_canny.pth",
            (
                'https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth',
                'https://aws-gcr-solutions.s3.cn-north-1.amazonaws.com.cn/stable-diffusion-aws-extension-github-mainline/models/control_v11p_sd15_canny.pth',
            ),
        )

    def test_13_list_checkpoints_canny_check(self):
        self._assert_checkpoint_listed()

    def test_14_create_checkpoint_openpose(self):
        self._create_checkpoint("ControlNet", "control_v11p_sd15_openpose.pth", parts_number=2)

    def test_15_update_checkpoint_openpose(self):
        self._upload_and_activate(
            "ControlNet",
            "control_v11p_sd15_openpose.pth",
            (
                'https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_openpose.pth',
                'https://aws-gcr-solutions.s3.cn-north-1.amazonaws.com.cn/stable-diffusion-aws-extension-github-mainline/models/control_v11p_sd15_openpose.pth',
            ),
        )

    def test_16_list_checkpoints_openpose_check(self):
        self._assert_checkpoint_listed()
|