# stable-diffusion-aws-extension/middleware_api/libs/common_tools.py

import concurrent.futures
import decimal
import json
import logging
import math
import threading
import urllib.request
from datetime import datetime
from datetime import timedelta
from typing import Any, Dict, List

import boto3

from libs.data_types import MultipartFileReq, CheckPoint
PART_SIZE = 500 * 1024 * 1024
def batch_get_s3_multipart_signed_urls(bucket_name, base_key, filenames: List[MultipartFileReq]) -> Dict[str, Any]:
    """Create multipart presigned-URL bundles for a batch of files.

    Args:
        bucket_name: target S3 bucket.
        base_key: key prefix under which every file is stored.
        filenames: list of dicts unpackable into ``MultipartFileReq``
            (each carries ``filename`` and ``parts_number``).

    Returns:
        Mapping of filename -> the signed-url info dict produced by
        ``get_s3_multipart_signed_urls`` (urls, upload_id, bucket, key).
    """
    # Fix: the annotation was the list literal `[MultipartFileReq]`,
    # which is not a valid type hint; use typing.List instead.
    presign_url_map = {}
    for item in filenames:
        file_req = MultipartFileReq(**item)
        key = f'{base_key}/{file_req.filename}'
        signed_urls = get_s3_multipart_signed_urls(
            bucket_name, key, parts_number=file_req.parts_number)
        presign_url_map[file_req.filename] = signed_urls
    return presign_url_map
def get_s3_multipart_signed_urls(bucket_name, key, parts_number) -> Any:
    """Start a multipart upload and presign one ``upload_part`` URL per part.

    Args:
        bucket_name: target S3 bucket.
        key: destination object key.
        parts_number: number of parts the client will upload.

    Returns:
        dict with 's3_signed_urls' (one URL per part, in part order),
        'upload_id', 'bucket' and 'key'.
    """
    s3 = boto3.client('s3')
    response = s3.create_multipart_upload(
        Bucket=bucket_name,
        Key=key,
        # NOTE: 'Expires' here only sets the object's HTTP Expires header;
        # it does not control how long the upload or the URLs stay valid.
        Expires=datetime.now() + timedelta(seconds=3600 * 24 * 7)
    )
    upload_id = response['UploadId']
    presign_urls = []
    for part_number in range(1, parts_number + 1):
        presign_url = s3.generate_presigned_url(
            ClientMethod='upload_part',
            Params={
                'Bucket': bucket_name,
                'Key': key,
                'UploadId': upload_id,
                'PartNumber': part_number
            },
            # Bug fix: without ExpiresIn the URLs expired after the default
            # 3600 s, while the code above clearly intends a 7-day window.
            ExpiresIn=3600 * 24 * 7
        )
        presign_urls.append(presign_url)
    return {
        's3_signed_urls': presign_urls,
        'upload_id': upload_id,
        'bucket': bucket_name,
        'key': key,
    }
def get_base_model_s3_key(_type: str, name: str, request_id: str) -> str:
    """Build the S3 key prefix for a model of the given type and request."""
    return '/'.join([_type, 'model', name, request_id])
def get_base_checkpoint_s3_key(_type: str, name: str, request_id: str) -> str:
    """Build the S3 key prefix for a checkpoint of the given type and request."""
    return '/'.join([_type, 'checkpoint', name, request_id])
def complete_multipart_upload(ckpt: CheckPoint, filename_etag):
    """Finish or abort every multipart upload recorded on a checkpoint.

    For each file listed in ``ckpt.params['multipart_upload']``: when the
    caller supplied part ETags for that file, the parts are sorted by
    PartNumber and the upload is completed; otherwise the pending upload
    is aborted. Checkpoints without multipart state are a no-op.
    """
    s3 = boto3.client('s3')
    if 'multipart_upload' not in ckpt.params:
        return
    for filename, upload_info in ckpt.params['multipart_upload'].items():
        # todo: can add s3 MD5 check here to see if file is upload properly
        if filename in filename_etag:
            # S3 requires the parts list to be ordered by PartNumber.
            filename_etag[filename].sort(key=lambda part: part['PartNumber'])
            response = s3.complete_multipart_upload(
                Bucket=upload_info['bucket'],
                Key=upload_info['key'],
                MultipartUpload={'Parts': filename_etag[filename]},
                UploadId=upload_info['upload_id']
            )
            print(f'complete upload multipart response {response}')
        else:
            response = s3.abort_multipart_upload(
                Bucket=upload_info['bucket'],
                Key=upload_info['key'],
                UploadId=upload_info['upload_id']
            )
            print(f'abort upload multipart response {response}')
def split_s3_path(s3_path):
    """Split an ``s3://bucket/key/path`` URI into ``(bucket, key)``.

    The key may be empty when the URI names only a bucket.
    """
    without_scheme = s3_path.replace("s3://", "")
    bucket, _, key = without_scheme.partition("/")
    return bucket, key
def upload_part_file(s3, bucket, key, part_number, upload_id, part_data):
    """Upload one part of an S3 multipart upload and return its part record.

    Args:
        s3: boto3 S3 client (anything exposing ``upload_part``).
        bucket: destination bucket.
        key: destination object key.
        part_number: 1-based part index.
        upload_id: UploadId returned by ``create_multipart_upload``.
        part_data: raw bytes for this part.

    Returns:
        ``{'PartNumber': part_number, 'ETag': etag}`` suitable for
        ``complete_multipart_upload``.

    Raises:
        Re-raises any upload error after logging it. (Bug fix: the
        original returned ``ETag: None`` on failure, which let callers
        attempt to complete the upload with an invalid parts list.)
    """
    try:
        response = s3.upload_part(
            Bucket=bucket,
            Key=key,
            PartNumber=part_number,
            UploadId=upload_id,
            Body=part_data
        )
    except Exception as e:
        logging.error(f"Upload of part {part_number} failed: {str(e)}")
        raise
    return {'PartNumber': part_number, 'ETag': response['ETag']}
def multipart_upload_from_url(url, bucket_name, s3_key):
    """Stream a remote file into S3 as a multipart upload.

    Args:
        url (str): model source file url, e.g.
            https://civitai.com/api/download/models/xxxx or
            https://huggingface.co/stabilityai/stable-diffusion-xxxx/resolve/main/xxxx.safetensors
        bucket_name (str): bucket name
        s3_key (str): s3 key

    Returns:
        dict with:
            upload_id: s3 uploadId
            key: s3 key
            bucket: bucket name
        or None when any step fails (the error is logged).
    """
    s3 = boto3.client('s3')
    logging.info(f"start multipart_upload_from_url:{url}, {s3_key}")
    lock = threading.Lock()
    try:
        with urllib.request.urlopen(url) as response:
            # Get the total file size from the Content-Length header.
            # NOTE(review): int(None) raises if the server omits
            # Content-Length — that lands in the except below.
            total_size = int(response.info().get('Content-Length'))
            part_count = math.ceil(total_size / PART_SIZE)
            upload_id = s3.create_multipart_upload(Bucket=bucket_name, Key=s3_key,
                                                   Expires=datetime.now() + timedelta(seconds=3600 * 24 * 7))[
                'UploadId']
            logging.info(
                f"multipart_upload_from_url: total_size:{total_size}, part_count:{part_count} upload_id:{upload_id}")
            parts = []
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = []
                for part_number in range(1, part_count + 1):
                    # NOTE(review): the read runs here in the submitting
                    # thread, sequentially, so the lock guards nothing
                    # concurrent — only the uploads run in worker threads.
                    # Each part (up to PART_SIZE bytes) is buffered in
                    # memory until its upload finishes.
                    with lock:
                        start = (part_number - 1) * PART_SIZE
                        end = min(part_number * PART_SIZE, total_size)
                        part_data = response.read(end - start)
                        futures.append(
                            executor.submit(upload_part_file, s3, bucket_name, s3_key, part_number, upload_id,
                                            part_data))
                # Collect part records as uploads finish (any order).
                for future in concurrent.futures.as_completed(futures):
                    parts.append(future.result())
            # S3 requires the parts list ordered by PartNumber.
            parts.sort(key=lambda part: part['PartNumber'])
            # Complete the multipart upload.
            s3.complete_multipart_upload(
                Bucket=bucket_name,
                Key=s3_key,
                UploadId=upload_id,
                MultipartUpload={'Parts': parts}
            )
            logging.info("Multipart upload completed!")
            return {
                "uploadId": upload_id,
                "key": s3_key,
                "bucket": bucket_name,
            }
    except Exception as e:
        # Best-effort: failures are logged and reported as None, not raised.
        logging.error(f"Multipart upload failed: {str(e)}")
        return None
class DecimalEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``decimal.Decimal`` values as strings."""

    def default(self, obj):
        # Decimal is not natively JSON-serializable; emit its string form.
        # Everything else falls through to the base class (which raises
        # TypeError for unsupported types, as usual).
        if not isinstance(obj, decimal.Decimal):
            return super().default(obj)
        return str(obj)