# stable-diffusion-aws-extension/test/test_05_api_checkpoint/test_03_upload_checkpoints.py
# (540 lines, 18 KiB, Python)

from __future__ import print_function
import logging
import config as config
from utils.api import Api
from utils.helper import upload_multipart_file, wget_file
logger = logging.getLogger(__name__)

# State shared between the order-dependent tests of this module: each
# checkpoint "create" test stores the new checkpoint's id and the presigned
# upload URLs for its file here, and the "update"/"list" tests that follow read them.
checkpoint_id = signed_urls = None
class TestCheckPointE2E:
    """End-to-end tests for the checkpoint API: create -> upload/activate -> list.

    The tests are order-dependent and rely on pytest running methods in
    definition order. Each ``create`` test stores the new checkpoint's id and
    its presigned upload URLs in the module-level globals ``checkpoint_id`` and
    ``signed_urls``. The ``update`` and ``list`` tests that follow it consume
    those values.
    """

    @classmethod
    def setup_class(cls):
        """Build one API client shared by every test in this class."""
        cls.api = Api(config)
        cls.api.feat_oas_schema()

    @classmethod
    def teardown_class(cls):
        """No class-level cleanup is performed."""
        pass

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _headers(with_username=True):
        """Return request headers with the API key and, optionally, the username."""
        headers = {"x-api-key": config.api_key}
        if with_username:
            headers["username"] = config.username
        return headers

    def _create_checkpoint(self, filename, checkpoint_type, parts_number):
        """Create a checkpoint for ``filename`` and remember its id and upload URLs.

        Asserts a 201 response whose checkpoint has the requested type and a
        36-character id. Then it stores that id in the global
        ``checkpoint_id`` and the presigned URLs for ``filename`` in the
        global ``signed_urls``.
        """
        global checkpoint_id, signed_urls
        data = {
            "checkpoint_type": checkpoint_type,
            "filenames": [
                {
                    "filename": filename,
                    "parts_number": parts_number,
                }
            ],
            "params": {
                "message": config.ckpt_message,
                "creator": config.username,
            },
            "source_path": "/test/test_02_api_base",
            "target_path": "/test/test_02_api_base",
        }
        resp = self.api.create_checkpoint(headers=self._headers(), data=data)
        assert resp.status_code == 201, resp.dumps()
        body = resp.json()
        assert body["statusCode"] == 201
        checkpoint = body["data"]["checkpoint"]
        assert checkpoint["type"] == checkpoint_type
        assert len(checkpoint["id"]) == 36
        checkpoint_id = checkpoint["id"]
        signed_urls = body["data"]["s3PresignUrl"][filename]

    def _upload_and_activate(self, filename, checkpoint_type, urls, with_username=True):
        """Fetch ``filename`` locally, upload it via ``signed_urls``, then mark the checkpoint Active.

        ``urls`` are forwarded positionally to ``wget_file`` after the local
        path. ``with_username`` controls whether the update request carries
        the ``username`` header.
        """
        # The local model directory is named after the checkpoint type
        # (Stable-diffusion / Lora / ControlNet).
        local_path = f"/tmp/test/models/{checkpoint_type}/{filename}"
        wget_file(local_path, *urls)
        multiparts_tags = upload_multipart_file(signed_urls, local_path)
        data = {
            "status": "Active",
            "multi_parts_tags": {filename: multiparts_tags},
        }
        resp = self.api.update_checkpoint(
            checkpoint_id=checkpoint_id,
            headers=self._headers(with_username),
            data=data,
        )
        assert resp.status_code == 200, resp.dumps()
        body = resp.json()
        assert body["statusCode"] == 200
        assert body["data"]["checkpoint"]["type"] == checkpoint_type

    def _assert_checkpoint_listed(self):
        """Assert the most recently created checkpoint appears in the user's checkpoint list."""
        resp = self.api.list_checkpoints(
            headers=self._headers(),
            params={"username": config.username},
        )
        assert resp.status_code == 200, resp.dumps()
        listed_ids = [checkpoint["id"] for checkpoint in resp.json()["data"]["checkpoints"]]
        assert checkpoint_id in listed_ids, f"{checkpoint_id} not in {listed_ids}"

    # -------------------------------------------------------------------- tests

    def test_0_clean_all_checkpoints(self):
        """Delete every checkpoint returned by the list endpoint."""
        headers = self._headers()
        resp = self.api.list_checkpoints(headers=headers)
        # Fail with a clear message instead of a KeyError if listing fails.
        assert resp.status_code == 200, resp.dumps()
        id_list = [checkpoint["id"] for checkpoint in resp.json()["data"]["checkpoints"]]
        if id_list:
            resp = self.api.delete_checkpoints(
                headers=headers,
                data={"checkpoint_id_list": id_list},
            )
            assert resp.status_code == 204, resp.dumps()

    def test_1_create_checkpoint_v15(self):
        self._create_checkpoint(
            "v1-5-pruned-emaonly.safetensors", "Stable-diffusion", parts_number=5
        )

    def test_2_update_checkpoint_v15_with_bad_params(self):
        """Updating with an empty ``name`` (and no username header) must be rejected."""
        data = {
            "status": "Active",
            "name": "",
        }
        resp = self.api.update_checkpoint(
            checkpoint_id=checkpoint_id,
            headers=self._headers(with_username=False),
            data=data,
        )
        # NOTE(review): previously this test asserted nothing and could never
        # fail. The exact error code is not visible here; assumed to be a
        # 4xx client error — confirm against the API.
        assert 400 <= resp.status_code < 500, resp.dumps()

    def test_3_update_checkpoint_v15(self):
        # NOTE(review): unlike the other update tests, this request has always
        # been sent without the "username" header; preserved as-is.
        self._upload_and_activate(
            "v1-5-pruned-emaonly.safetensors",
            "Stable-diffusion",
            urls=(
                'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors',
                'https://aws-gcr-solutions.s3.cn-north-1.amazonaws.com.cn/stable-diffusion-aws-extension-github-mainline/models/v1-5-pruned-emaonly.safetensors',
            ),
            with_username=False,
        )

    def test_4_list_checkpoints_v15_check(self):
        self._assert_checkpoint_listed()

    def test_8_create_checkpoint_lora_nendoroid(self):
        self._create_checkpoint("nendoroid_xl_v7.safetensors", "Lora", parts_number=1)

    def test_9_update_checkpoint_lora_nendoroid(self):
        self._upload_and_activate(
            "nendoroid_xl_v7.safetensors",
            "Lora",
            urls=(
                'https://aws-gcr-solutions.s3.amazonaws.com/stable-diffusion-aws-extension-github-mainline/models/nendoroid_xl_v7.safetensors',
                'https://aws-gcr-solutions.s3.cn-north-1.amazonaws.com.cn/stable-diffusion-aws-extension-github-mainline/models/nendoroid_xl_v7.safetensors',
            ),
        )

    def test_10_list_checkpoint_lora_nendoroid_check(self):
        self._assert_checkpoint_listed()

    def test_8_create_checkpoint_lora_lcm_1_5(self):
        self._create_checkpoint("lcm_lora_1_5.safetensors", "Lora", parts_number=1)

    def test_9_update_checkpoint_lora_lcm_1_5(self):
        self._upload_and_activate(
            "lcm_lora_1_5.safetensors",
            "Lora",
            urls=(
                'https://aws-gcr-solutions-us-east-1.s3.us-east-1.amazonaws.com/extension-for-stable-diffusion-on-aws/models/Lora/lcm_lora_1_5.safetensors',
            ),
        )

    def test_10_list_checkpoint_lora_lcm_1_5_check(self):
        self._assert_checkpoint_listed()

    def test_8_create_checkpoint_lora_lcm_xl(self):
        self._create_checkpoint("lcm_lora_xl.safetensors", "Lora", parts_number=1)

    def test_9_update_checkpoint_lora_lcm_xl(self):
        self._upload_and_activate(
            "lcm_lora_xl.safetensors",
            "Lora",
            urls=(
                'https://aws-gcr-solutions-us-east-1.s3.us-east-1.amazonaws.com/extension-for-stable-diffusion-on-aws/models/Lora/lcm_lora_xl.safetensors',
            ),
        )

    def test_10_list_checkpoint_lora_lcm_xl_check(self):
        self._assert_checkpoint_listed()

    def test_11_create_checkpoint_canny(self):
        self._create_checkpoint("control_v11p_sd15_canny.pth", "ControlNet", parts_number=2)

    def test_12_update_checkpoint_canny(self):
        self._upload_and_activate(
            "control_v11p_sd15_canny.pth",
            "ControlNet",
            urls=(
                'https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth',
                'https://aws-gcr-solutions.s3.cn-north-1.amazonaws.com.cn/stable-diffusion-aws-extension-github-mainline/models/control_v11p_sd15_canny.pth',
            ),
        )

    def test_13_list_checkpoints_canny_check(self):
        self._assert_checkpoint_listed()

    def test_14_create_checkpoint_openpose(self):
        self._create_checkpoint("control_v11p_sd15_openpose.pth", "ControlNet", parts_number=2)

    def test_15_update_checkpoint_openpose(self):
        self._upload_and_activate(
            "control_v11p_sd15_openpose.pth",
            "ControlNet",
            urls=(
                'https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_openpose.pth',
                'https://aws-gcr-solutions.s3.cn-north-1.amazonaws.com.cn/stable-diffusion-aws-extension-github-mainline/models/control_v11p_sd15_openpose.pth',
            ),
        )

    def test_16_list_checkpoints_openpose_check(self):
        self._assert_checkpoint_listed()