stable-diffusion-aws-extension/middleware_api/lambda/common/util.py

180 lines
5.3 KiB
Python

import json
import logging
import os
from functools import reduce
from typing import Dict
import boto3
from aws_lambda_powertools import Tracer
from libs.comfy_data_types import InferenceResult
# Powertools tracer backing the @tracer.capture_method decorators in this module.
tracer = Tracer()
# Low-level S3 client (presigned URLs, get_object, put_object) and high-level
# S3 resource (used only to build s3_bucket below); both created at import time.
s3 = boto3.client('s3')
s3_resource = boto3.resource('s3')
logger = logging.getLogger(__name__)
# Level comes from the LOG_LEVEL environment variable; when it is unset or empty
# the level is ERROR, which suppresses every logger.info(...) call in this module.
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
# SNS client used by publish_msg().
sns_client = boto3.client('sns')
# Default bucket for the helpers that accept bare keys or s3://<bucket_name>/ URIs.
# None when S3_BUCKET_NAME is not set.
bucket_name = os.environ.get('S3_BUCKET_NAME')
# Bucket resource used for prefix listing in s3_scan_files_in_patch().
# NOTE(review): this receives None if S3_BUCKET_NAME is unset — confirm how
# boto3 handles that at import time.
s3_bucket = s3_resource.Bucket(bucket_name)
def get_multi_query_params(event, param_name: str, default=None):
    """Return the list of values for a multi-value query-string parameter.

    Reads ``event['multiValueQueryStringParameters'][param_name]``. Falls back
    to ``default`` when the section is absent or empty/None, when the
    parameter is not present, or when its value list is empty.
    """
    if 'multiValueQueryStringParameters' not in event:
        return default
    candidates = event['multiValueQueryStringParameters']
    if not candidates or param_name not in candidates:
        return default
    values = candidates[param_name]
    return values if len(values) > 0 else default
def get_query_param(event, param_name: str, default=None):
    """Return a single query-string parameter from an API Gateway-style event.

    Looks in ``event['queryStringParameters']``; returns ``default`` when that
    section is missing, empty/None, or does not contain ``param_name``.
    """
    params = event['queryStringParameters'] if 'queryStringParameters' in event else None
    if not params or param_name not in params:
        return default
    return params[param_name]
def query_data(data, paths):
    """Walk a nested mapping along ``paths`` and return the value found.

    Args:
        data: Root object; each step is looked up with ``.get(path)``.
        paths: Keys to follow, outermost first. An empty sequence returns
            ``data`` unchanged.

    Returns:
        The value at the end of the path.

    Raises:
        ValueError: ``"Missing a.b.c"`` (the full requested path) when any step
            is absent or falsy (``None``, ``0``, ``""``, ``{}`` ...), or when an
            intermediate value has no ``.get`` method.
    """
    value = data
    for path in paths:
        getter = getattr(value, 'get', None)
        # An intermediate that is not a mapping (e.g. a string where a dict was
        # expected) used to escape as AttributeError; report it as missing instead.
        value = getter(path) if callable(getter) else None
        if not value:
            # Falsy values are treated as missing, not only None.
            raise ValueError(f"Missing {'.'.join(str(p) for p in paths)}")
    return value
def publish_msg(topic_arn, msg, subject):
    """Publish ``msg`` (converted with ``str``) to the SNS topic ``topic_arn``.

    Uses the module-level ``sns_client``; returns nothing.
    """
    request = {
        'TopicArn': topic_arn,
        'Message': str(msg),
        'Subject': subject,
    }
    sns_client.publish(**request)
def get_s3_presign_urls(bucket_name, base_key, filenames) -> Dict[str, str]:
    """Map each filename to a presigned PUT URL under ``base_key``, valid for 7 days."""
    one_week_seconds = 7 * 24 * 3600
    return _get_s3_presign_urls(bucket_name, base_key, filenames,
                                method='put_object', expires=one_week_seconds)
def get_s3_get_presign_urls(bucket_name, base_key, filenames) -> Dict[str, str]:
    """Map each filename to a presigned GET URL under ``base_key``, valid for 1 day."""
    one_day_seconds = 24 * 3600
    return _get_s3_presign_urls(bucket_name, base_key, filenames,
                                method='get_object', expires=one_day_seconds)
def _get_s3_presign_urls(bucket_name, base_key, filenames, expires=3600, method='put_object') -> Dict[str, str]:
    """Build ``{filename: presigned_url}`` for ``<base_key>/<filename>`` objects.

    Args:
        bucket_name: Bucket the URLs are signed for.
        base_key: Key prefix; joined to each filename with '/'.
        filenames: Names to sign, one URL each (later duplicates overwrite earlier).
        expires: URL lifetime in seconds.
        method: S3 client method to sign, e.g. 'put_object' or 'get_object'.
    """
    return {
        name: s3.generate_presigned_url(
            method,
            Params={'Bucket': bucket_name, 'Key': f'{base_key}/{name}'},
            ExpiresIn=expires,
        )
        for name in filenames
    }
@tracer.capture_method
def generate_presign_url(bucket_name, key, expires=3600, method='put_object') -> str:
    """Return a single presigned URL for ``key`` in ``bucket_name``.

    Args:
        bucket_name: Target bucket (shadows the module-level default).
        key: Object key.
        expires: URL lifetime in seconds (default 3600).
        method: S3 client method to sign, e.g. 'put_object' or 'get_object'.

    Returns:
        The presigned URL string.
    """
    # generate_presigned_url returns a plain str; the annotation was previously
    # Dict[str, str], which misled callers and type checkers.
    return s3.generate_presigned_url(
        method,
        Params={'Bucket': bucket_name, 'Key': key},
        ExpiresIn=expires,
    )
@tracer.capture_method
def load_json_from_s3(key: str):
    """Download an object from the module's default bucket and parse it as JSON.

    ``key`` may be a bare object key or an ``s3://<bucket_name>/...`` URI for the
    default bucket; occurrences of that URI prefix are removed before the fetch.
    The body is decoded as UTF-8.
    """
    object_key = key.replace(f"s3://{bucket_name}/", '')
    raw_body = s3.get_object(Bucket=bucket_name, Key=object_key)['Body'].read()
    return json.loads(raw_body.decode('utf-8'))
def save_json_to_file(json_string: object, folder_path: str, file_name: str) -> str:
    """Serialize ``json_string`` to JSON and write it to ``folder_path/file_name``.

    Despite its name, ``json_string`` is any JSON-serializable value (typically
    a dict); it is passed through ``json.dumps``, so a str argument is written as
    a quoted JSON string.

    Args:
        json_string: Value to serialize.
        folder_path: Directory to write into; created if it does not exist.
        file_name: Name of the file inside ``folder_path``.

    Returns:
        The path of the written file.

    Raises:
        TypeError: If the value is not JSON-serializable. In that case any
            existing file at the target path is left untouched.
    """
    # Serialize before opening: open(..., 'w') truncates immediately, so a
    # serialization error afterwards would have wiped an existing file.
    payload = json.dumps(json_string)
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, file_name)
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(payload)
    return file_path
@tracer.capture_method
def upload_json_to_s3(bucket_name: str, file_key: str, json_data: dict):
    '''
    Serialize ``json_data`` as JSON and upload it to s3://bucket_name/file_key.

    Best-effort: any failure (serialization or S3 error) is logged at ERROR
    level with its traceback and is not re-raised; the function returns None.
    '''
    try:
        s3.put_object(Body=json.dumps(json_data), Bucket=bucket_name, Key=file_key)
    except Exception as e:
        # Previously logged at INFO, which the module's default ERROR level
        # suppressed, so failed uploads disappeared without a trace.
        logger.exception("Error uploading dictionary: %s", e)
        return
    logger.info("Dictionary uploaded to S3://%s/%s", bucket_name, file_key)
def split_s3_path(s3_path):
    """Split an S3 location into ``(bucket, key)``.

    Accepts ``s3://bucket/key`` or ``bucket/key``. Only a leading ``s3://`` is
    stripped; the key is everything after the first '/' (``""`` if none).

    >>> split_s3_path("s3://bucket/a/b.json")
    ('bucket', 'a/b.json')
    """
    scheme = "s3://"
    # str.replace would also delete "s3://" occurring inside the key itself.
    if s3_path.startswith(scheme):
        s3_path = s3_path[len(scheme):]
    bucket, _, key = s3_path.partition("/")
    return bucket, key
@tracer.capture_method
def s3_scan_files(job: InferenceResult):
    """Fill ``job.output_files`` / ``job.temp_files`` from S3 listings.

    Each list holds the object keys found under ``job.output_path`` /
    ``job.temp_path``, relative to that path. The job is mutated in place and
    also returned.
    """
    for files_attr, path_attr in (('output_files', 'output_path'), ('temp_files', 'temp_path')):
        setattr(job, files_attr, s3_scan_files_in_patch(getattr(job, path_attr)))
    return job
@tracer.capture_method
def s3_scan_files_in_patch(patch: str):
    """List object keys under ``patch`` in the default bucket, relative to it.

    Args:
        patch: A key prefix, or an ``s3://<bucket_name>/...`` URI for the
            module's default bucket (the URI part is removed).

    Returns:
        Keys with the prefix removed. An object whose key equals the prefix
        exactly (e.g. a folder placeholder) is skipped.
    """
    prefix = patch.replace(f"s3://{bucket_name}/", '')
    files = []
    for obj in s3_bucket.objects.filter(Prefix=prefix):
        # filter(Prefix=...) guarantees each key starts with prefix, so slice it
        # off; str.replace also deleted later repeats of the prefix text
        # (e.g. "out/sub/out/x.png" became "sub/x.png").
        relative_key = obj.key[len(prefix):]
        if relative_key:
            files.append(relative_key)
    return files
def generate_presigned_url_for_key(key, expiration=3600):
    """Return a presigned GET URL for ``key`` in the module's default bucket.

    ``key`` may be a bare key or an ``s3://<bucket_name>/...`` URI; occurrences
    of that URI prefix are removed. ``expiration`` is the lifetime in seconds.
    """
    object_key = key.replace(f"s3://{bucket_name}/", '')
    params = {'Bucket': bucket_name, 'Key': object_key}
    return s3.generate_presigned_url('get_object', Params=params, ExpiresIn=expiration)
@tracer.capture_method
def generate_presigned_url_for_keys(prefix, keys, expiration=3600):
    """Return presigned GET URLs for ``prefix + key`` for each key, in order.

    ``prefix`` may be an ``s3://<bucket_name>/...`` URI for the default bucket;
    it is concatenated directly with each key (no separator is inserted).
    """
    base = prefix.replace(f"s3://{bucket_name}/", '')
    return [generate_presigned_url_for_key(f"{base}{name}", expiration) for name in keys]
@tracer.capture_method
def generate_presigned_url_for_job(job):
    """Replace a job dict's file-name lists with presigned GET URLs.

    For each (path, files) pair — ('output_path', 'output_files') and
    ('temp_path', 'temp_files') — the files list is rewritten only when both
    keys are present. The dict is mutated in place and returned.
    """
    for path_field, files_field in (('output_path', 'output_files'), ('temp_path', 'temp_files')):
        if path_field not in job or files_field not in job:
            continue
        job[files_field] = generate_presigned_url_for_keys(job[path_field], job[files_field])
    return job