import dataclasses import os from decimal import Decimal from unittest import TestCase import requests from checkpoints.create_checkpoint import get_real_url os.environ.setdefault('AWS_PROFILE', 'aws_profile') os.environ.setdefault('S3_BUCKET_NAME', 'bucket') os.environ.setdefault('DYNAMODB_TABLE', 'ModelTable') os.environ.setdefault('MODEL_TABLE', 'ModelTable') os.environ.setdefault('TRAIN_TABLE', 'TrainingTable') os.environ.setdefault('CHECKPOINT_TABLE', 'CheckpointTable') os.environ.setdefault('SAGEMAKER_ENDPOINT_NAME', 'aigc-utils-endpoint') os.environ.setdefault('MULTI_USER_TABLE', 'MultiUserTable') @dataclasses.dataclass class MockContext: aws_request_id: str class ModelsApiTest(TestCase): def test_get_real_url(self): real_url = get_real_url( "https://civitai.com/api/download/models/275491?type=Model&format=SafeTensor&size=full&fp=fp16") assert 'civitai-delivery-worker-prod' in real_url def test_get_real_url_file(self): url = "https://aws-gcr-solutions.s3.cn-north-1.amazonaws.com.cn/stable-diffusion-aws-extension-github-mainline/models/v1-5-pruned-emaonly.safetensors" real_url = get_real_url(url) assert real_url == url def test_upload(self): from models.model_api import create_model_api resp = create_model_api({ "model_type": "dreambooth", "name": "test_upload", "filenames": [{ "filename": 'test1', "parts_number": 5 }], "params": {} }, MockContext(aws_request_id="asdfasdf")) print(resp) def upload_with_put(url): with open('test_create_model_api.py', 'rb') as file: response = requests.put(url, data=file) response.raise_for_status() upload_with_put(resp['s3PresignUrl']) def test_upload_2(self): url = "presign s3 url" def upload_with_put(url): with open('file.tar.gz', 'rb') as file: response = requests.put(url, data=file) response.raise_for_status() upload_with_put(url) def test_model_update(self): from models.model_api import update_model_job_api update_model_job_api({ 'model_id': 'asdfasdf', 'status': 'Creating', 'multi_parts_tags': {"test1": [{'ETag': '"cc95c41fa28463c8e9b88d67805f24e0"', 'PartNumber': 1}]}, }, {}) def test_process(self): data = {} # sample data from models.model_api import process_result process_result(data, {}) def test_convert(self): d = Decimal(4) from common.ddb_service.client import DynamoDbUtilsService obj = DynamoDbUtilsService._convert(d) print(obj) def test_list_all(self): from models.model_api import list_all_models_api resp = list_all_models_api({ 'queryStringParameters': { }, 'x-auth': { 'username': 'admin', 'role': '' } }, {}) print(resp) def test_s3(self): # split s3://alvindaiyan-aigc-testing-playground/models/7a77d369-142c-4091-91e1-9278566a6a4f.out from models.model_api import split_s3_path bucket, key = split_s3_path('s3://path') from models.model_api import get_object get_object(bucket=bucket, key=key) def test_list_checkpoints(self): from model_and_train.checkpoint_api import list_all_checkpoints_api resp = list_all_checkpoints_api({ 'queryStringParameters': {}, 'x-auth': { 'username': 'xman', 'role': [] } }, {}) print(resp) def test_list_train_jobs(self): from trainings.process_train_job_result import list_all_train_jobs_api resp = list_all_train_jobs_api({ 'queryStringParameters': { }, 'x-auth': { 'username': 'xman', 'role': [] } }, {}) print(resp) def test_create_update_checkpoint(self): from checkpoint_api import update_checkpoint_api # resp = create_checkpoint_api({ # "checkpoint_type": "dreambooth", # "filenames": [ # {"filename": "test1", "parts_number": 5} # ], # "params": { # "new_model_name": "test_api", # "number": 1, # "string": "abc" # } # }, MockContext(aws_request_id="asdfasdf")) # print(resp) resp = update_checkpoint_api({ "checkpoint_id": "4e5118f5-9d9a-4a7e-aace-6f5e52c4edd9", "status": "Active", 'multi_parts_tags': {"test1": [{'ETag': '"cc95c41fa28463c8e9b88d67805f24e0"', 'PartNumber': 1}]}, }, {}) print(resp) def test_update_train_job_api(self): from trainings.process_train_job_result import update_train_job_api update_train_job_api({ "train_job_id": "asdfasdf", "status": "Training" }, {}) def test_check_train_job_status(self): from trainings.process_train_job_result import check_train_job_status event = {'train_job_id': 'd0c19f0a-1c0f-4ac9-b7ea-6b0be8a889d0', 'train_job_name': 'test-new-local-2023-07-14-06-15-59-724'} check_train_job_status(event, {}) def test_scan(self): import logging from common.ddb_service.client import DynamoDbUtilsService logger = logging.getLogger('boto3') ddb_service = DynamoDbUtilsService(logger=logger) resp = ddb_service.scan(table='ModelTable', filters={ 'model_type': 'dreambooth', # 'job_status': ['Initial', 'Creating', 'Complete'] }) print(resp) def test_none(self): import logging from common.ddb_service.client import DynamoDbUtilsService logger = logging.getLogger('boto3') ddb_service = DynamoDbUtilsService(logger=logger) ddb_service.put_items(table='ModelTable', entries={ 'id': '512d5e64-021e-49f5-a313-227f842c3f93', 'name': 'testProgressBar01', 'checkpoint_id': '512d5e64-021e-49f5-a313-227f842c3f93', 'model_type': 'dreambooth', 'job_status': 'Initial', 'output_s3_location': 's3://placeholder.s3/dreambooth/model/testProgressBar01/512d5e64-021e-49f5-a313-227f842c3f93/output', 'params': {'create_model_params': {'new_model_name': 'testProgressBar01', 'ckpt_path': 'v1-5-pruned-emaonly.safetensors', 'from_hub': False, 'new_model_url': '', 'new_model_token': '', 'extract_ema': False, 'train_unfrozen': False, 'is_512': True, 'sh': None}}}) resp = ddb_service._convert({ 'params': {'test': None} }) print(resp) def test_multipart(self): import boto3 from botocore.config import Config import requests import math s3 = boto3.client('s3', config=Config(signature_version='s3v4')) bucket = 'bucketname' key = 'test_multipart/tmp_10M_file' large_file_location = '/Users/cyanda/Dev/remote/tmp_10M_file' part_size = 1 * 1024 * 1024 file_size = os.stat(large_file_location) print(file_size.st_size) parts_number = math.ceil(file_size.st_size / part_size) # parts = 5 print(parts_number) from common_tools import get_s3_multipart_signed_urls presign_url_resp = get_s3_multipart_signed_urls(bucket, key, parts_number) presign_urls = presign_url_resp['s3_signed_urls'] upload_id = presign_url_resp['UploadId'] with open(large_file_location, 'rb') as f: parts = [] try: for i, signed_url in enumerate(presign_urls): file_data = f.read(part_size) response = requests.put(signed_url, data=file_data) etag = response.headers['ETag'] parts.append({ 'ETag': etag, 'PartNumber': i + 1 }) parts.sort(key=lambda x: x['PartNumber']) response = s3.complete_multipart_upload( Bucket=bucket, Key=key, MultipartUpload={'Parts': parts}, UploadId=upload_id ) print(response) except Exception as e: print(e) finally: response = s3.abort_multipart_upload( Bucket=bucket, Key=key, UploadId=upload_id ) print(response) def test_batch_get_s3_multipart_signed_urls(self): from model_and_train.common_tools import batch_get_s3_multipart_signed_urls from model_and_train.types import MultipartFileReq resp = batch_get_s3_multipart_signed_urls( 'bucket', 'test-multipart-api', [MultipartFileReq(filename='name_not_matter', parts_number=5)] ) print(resp) def test_list_bucket_objects(self): import boto3 s3 = boto3.client('s3') bucket = 'alvindaiyan-aigc-testing-playground' key = 'Stable-diffusion/checkpoint/dytest004/8d3a46e6-756e-47a5-a138-66d66f8ffec6' response = s3.list_objects( Bucket=bucket, Prefix=key, ) print(response) for obj in response['Contents']: print(obj['Key'].replace(f'{key}/', "")) def test_timestamp(self): import datetime timestamp = datetime.datetime.now().timestamp() print(timestamp) print(type(timestamp)) def test_get_item(self): from models.model_api import ddb_service, model_table resp = ddb_service.get_item(table=model_table, key_values={ "id": "262676e1-9b57-4ff3-a876-4e1de5ff5d25" }) print(resp) def test_presign_urls(self): from common.util import get_s3_presign_urls bucket = 'alvindaiyan-aigc-testing-playground' key = 'test_upload_manual/yan' resp = get_s3_presign_urls(bucket_name=bucket, base_key=key, filenames=["test"]) print(resp) def test_s3_download(self): import boto3 s3 = boto3.client('s3') bucket = 'alvindaiyan-aigc-testing-playground' key = 'test_upload_manual/yan' s3.list_objects_v2() s3.download_file(bucket, key, 'test')