102 lines
3.2 KiB
Python
102 lines
3.2 KiB
Python
import json
|
||
import decimal
|
||
from typing import Dict, Any
|
||
|
||
import boto3
|
||
from datetime import datetime
|
||
from datetime import timedelta
|
||
|
||
from _types import MultipartFileReq, CheckPoint
|
||
|
||
|
||
def batch_get_s3_multipart_signed_urls(bucket_name, base_key, filenames: [MultipartFileReq]) -> Dict[str, Any]:
|
||
presign_url_map = {}
|
||
for f in filenames:
|
||
file = MultipartFileReq(**f)
|
||
key = f'{base_key}/{file.filename}'
|
||
signed_urls = get_s3_multipart_signed_urls(bucket_name, key, parts_number=file.parts_number)
|
||
presign_url_map[file.filename] = signed_urls
|
||
|
||
return presign_url_map
|
||
|
||
|
||
def get_s3_multipart_signed_urls(bucket_name, key, parts_number) -> Any:
|
||
s3 = boto3.client('s3')
|
||
response = s3.create_multipart_upload(
|
||
Bucket=bucket_name,
|
||
Key=key,
|
||
Expires=datetime.now() + timedelta(seconds=3600 * 24 * 7)
|
||
)
|
||
|
||
upload_id = response['UploadId']
|
||
|
||
presign_urls = []
|
||
|
||
for i in range(1, parts_number+1):
|
||
presign_url = s3.generate_presigned_url(
|
||
ClientMethod='upload_part',
|
||
Params={
|
||
'Bucket': bucket_name,
|
||
'Key': key,
|
||
'UploadId': upload_id,
|
||
'PartNumber': i
|
||
}
|
||
)
|
||
presign_urls.append(presign_url)
|
||
return {
|
||
's3_signed_urls': presign_urls,
|
||
'upload_id': upload_id,
|
||
'bucket': bucket_name,
|
||
'key': key,
|
||
}
|
||
|
||
|
||
def get_base_model_s3_key(_type: str, name: str, request_id: str) -> str:
|
||
return f'{_type}/model/{name}/{request_id}'
|
||
|
||
|
||
def get_base_checkpoint_s3_key(_type: str, name: str, request_id: str) -> str:
|
||
return f'{_type}/checkpoint/{name}/{request_id}'
|
||
|
||
|
||
def complete_multipart_upload(ckpt: CheckPoint, filename_etag):
|
||
s3 = boto3.client('s3')
|
||
if 'multipart_upload' in ckpt.params:
|
||
multipart = ckpt.params['multipart_upload']
|
||
for filename, val in multipart.items():
|
||
# todo: can add s3 MD5 check here to see if file is upload properly
|
||
if filename in filename_etag:
|
||
filename_etag[filename].sort(key=lambda x: x['PartNumber'])
|
||
response = s3.complete_multipart_upload(
|
||
Bucket=val['bucket'],
|
||
Key=val['key'],
|
||
MultipartUpload={'Parts': filename_etag[filename]},
|
||
UploadId=val['upload_id']
|
||
)
|
||
print(f'complete upload multipart response {response}')
|
||
response = s3.abort_multipart_upload(
|
||
Bucket=val['bucket'],
|
||
Key=val['key'],
|
||
UploadId=val['upload_id']
|
||
)
|
||
print(f'abort upload multipart response {response}')
|
||
|
||
|
||
def split_s3_path(s3_path):
|
||
path_parts = s3_path.replace("s3://", "").split("/")
|
||
bucket = path_parts.pop(0)
|
||
key = "/".join(path_parts)
|
||
return bucket, key
|
||
|
||
|
||
class DecimalEncoder(json.JSONEncoder):
|
||
def default(self, obj):
|
||
# if passed in object is instance of Decimal
|
||
# convert it to a string
|
||
if isinstance(obj, decimal.Decimal):
|
||
return str(obj)
|
||
|
||
#️ otherwise use the default behavior
|
||
return json.JSONEncoder.default(self, obj)
|
||
|