stable-diffusion-aws-extension/middleware_api/lambda/common/ddb_service/client.py

316 lines
12 KiB
Python

import datetime
import enum
import logging
from decimal import Decimal
from typing import Any, List, Dict
import boto3
from botocore.exceptions import ClientError
from common.ddb_service.types_ import GetItemOutput, ScanOutput
class DynamoDbUtilsService:
def __init__(self, logging_level=logging.INFO, logger=None):
self.client = boto3.client('dynamodb')
if logger:
self.logger = logger
else:
self.logger = logging.getLogger('boto3')
self.logger.setLevel(logging_level)
def put_items(self, table: str, entries: Dict[str, Any]) -> Any:
if not table:
raise Exception('table name is required')
try:
if not entries or len(entries) == 0:
return None
ddb_data = self._serialize(entries)
resp = self.client.put_item(
TableName=table,
Item=ddb_data
)
# todo: check if failed raise an error
return resp
except Exception as e:
self.logger.error(f'table {table} put item failed -> {entries}: {e}')
raise Exception(f'table {table} put item failed -> {entries}: {e}')
def batch_put_items(self, table_items: Dict[str, List[Dict[str, Any]]]) -> Any:
try:
if not table_items or len(table_items) == 0:
return None
_items = {}
for table, items in table_items.items():
_items[table] = [{'PutRequest': {'Item': self._serialize(item)}} for item in items]
resp = self.client.batch_write_item(RequestItems=_items)
return resp
except Exception as e:
self.logger.error(f'batch put failed: {e}')
raise Exception(f'batch put failed: {e}')
def update_item(self, table: str, key: Dict[str, Any], field_name: str, value: Any):
search_keys = self._serialize(key)
value = self._convert(value)
try:
self.client.update_item(
TableName=table,
Key=search_keys,
UpdateExpression=f"set {field_name} = :r",
ExpressionAttributeValues={
':r': value
},
ReturnValues="UPDATED_NEW"
)
except ClientError as e:
self.logger.error('keys: %s -> %s: %s', key, field_name, value)
raise Exception(
f'dynamodb update failed with table {table}, key: {key}, field: {field_name}, value: {value}, error: {e}')
def get_item(self, table: str, key_values: Dict[str, Any]) -> Dict[str, Any]:
try:
search_keys = self._serialize(key_values)
resp = self.client.get_item(
TableName=table,
Key=search_keys
)
named_ = GetItemOutput(**resp)
if 'Item' not in named_:
return dict()
res = self.deserialize(named_['Item'])
return res
except ClientError as e:
self.logger.error(f'table {table} keys_values: {key_values}')
raise Exception(f'table {table} get_item failed with keys_values: {key_values}, e: {e}')
def query_latest_item(self, table: str, key_values: Dict[str, Any]) -> Dict[str, Any]:
try:
filter_expressions, expression_values = self._get_ddb_filter(key_values)
resp = self.client.query(
TableName=table,
KeyConditionExpression=filter_expressions,
ExpressionAttributeValues=expression_values,
Limit=1,
)
named_ = ScanOutput(**resp)
if 'Items' not in named_:
return dict()
res = self.deserialize(named_['Items'][0])
return res
except ClientError as e:
self.logger.error(f'table {table} keys_values: {key_values}')
raise Exception(f'table {table} get_item failed with keys_values: {key_values}, e: {e}')
def query_items(self, table: str, key_values: Dict[str, Any], filters: Dict[str, Any] = None, limit: int = None,
last_evaluated_key=None):
try:
key_expressions, expression_values = self._get_ddb_filter(key_values)
if not filters:
if limit:
if last_evaluated_key:
resp = self.client.query(
TableName=table,
KeyConditionExpression=key_expressions,
ExpressionAttributeValues=expression_values,
ExclusiveStartKey=last_evaluated_key,
Limit=limit
)
else:
resp = self.client.query(
TableName=table,
KeyConditionExpression=key_expressions,
ExpressionAttributeValues=expression_values,
Limit=limit
)
else:
scan_resp = self.client.query(
TableName=table,
KeyConditionExpression=key_expressions,
ExpressionAttributeValues=expression_values,
)
resp = scan_resp['Items']
while 'LastEvaluatedKey' in scan_resp:
scan_resp = self.client.query(
TableName=table,
KeyConditionExpression=key_expressions,
ExpressionAttributeValues=expression_values,
ExclusiveStartKey=scan_resp['LastEvaluatedKey']
)
resp.extend(scan_resp['Items'])
# scan the whole table, no LastEvaluatedKey returned
return resp
else:
filter_expressions, filter_expression_values = self._get_ddb_filter(filters=filters)
expression_values.update(filter_expression_values)
resp = self.client.query(
TableName=table,
KeyConditionExpression=key_expressions,
FilterExpression=filter_expressions,
ExpressionAttributeValues=expression_values,
)
named_ = ScanOutput(**resp)
return named_['Items'], named_['LastEvaluatedKey'] if 'LastEvaluatedKey' in named_ else None
except ClientError as e:
self.logger.error(f'table {table} keys_values: {key_values}')
raise Exception(f'table {table} get_item failed with keys_values: {key_values}, e: {e}')
def _get_ddb_filter(self, filters: Dict[str, Any]):
prepare_filter_expressions = []
prefix = ':'
expression_values = {}
for key, val in filters.items():
if isinstance(val, list):
val_keys = ''
i = 0
for v in val:
k = f'{prefix}{key}{str(i)}'
i += 1
val_keys += f'{k}, '
expression_values[k] = self._convert(v)
prepare_filter_expressions.append('{} in ({})'.format(key, val_keys[:len(val_keys) - 2]))
else:
prepare_filter_expressions.append('{} = {}'.format(key, prefix + key))
expression_values[prefix + key] = self._convert(val)
filter_expressions = ' AND '.join(prepare_filter_expressions)
return filter_expressions, expression_values
def scan(self, table: str, filters: Dict[str, Any] = None, last_evaluated_key=None, limit: int = None):
if filters is None or len(filters) == 0:
if limit:
if last_evaluated_key:
resp = self.client.scan(
TableName=table,
ExclusiveStartKey=last_evaluated_key,
Limit=limit
)
else:
resp = self.client.query(
TableName=table,
Limit=limit
)
else:
scan_resp = self.client.scan(
TableName=table,
)
resp = scan_resp['Items']
while 'LastEvaluatedKey' in scan_resp:
scan_resp = self.client.scan(
TableName=table,
ExclusiveStartKey=scan_resp['LastEvaluatedKey']
)
resp.extend(scan_resp['Items'])
return resp
else:
filter_expressions, expression_values = self._get_ddb_filter(filters)
if last_evaluated_key:
resp = self.client.scan(
TableName=table,
FilterExpression=filter_expressions,
ExpressionAttributeValues=expression_values,
ExclusiveStartKey=last_evaluated_key,
Limit=limit
)
elif limit:
resp = self.client.scan(
TableName=table,
FilterExpression=filter_expressions,
ExpressionAttributeValues=expression_values,
Limit=limit
)
else:
scan_resp = self.client.scan(
TableName=table,
FilterExpression=filter_expressions,
ExpressionAttributeValues=expression_values
)
resp = scan_resp['Items']
while 'LastEvaluatedKey' in scan_resp:
scan_resp = self.client.scan(
TableName=table,
FilterExpression=filter_expressions,
ExpressionAttributeValues=expression_values,
ExclusiveStartKey=scan_resp['LastEvaluatedKey']
)
resp.extend(scan_resp['Items'])
return resp
named_ = ScanOutput(**resp)
# FIXME: handle failures
return named_['Items'], named_['LastEvaluatedKey'] if 'LastEvaluatedKey' in named_ else None
def delete_item(self, table: str, keys: dict[str, Any]):
keys = self._serialize(keys)
self.client.delete_item(
TableName=table,
Key=keys
)
# FIXME: handle failures
def close(self):
self.client.close()
@staticmethod
def _serialize(entries: dict[str, Any], prefix: str = '') -> dict[str, Any]:
if not dict:
return {}
result = dict()
for key, val in entries.items():
resolved_val = DynamoDbUtilsService._convert(val)
if resolved_val:
result["{}{}".format(prefix, key)] = resolved_val
return result
@staticmethod
# serializer = boto3.dynamodb.types.TypeSerializer()
# low_level_copy = {k: serializer.serialize(v) for k,v in python_data.items()}
def _convert(val):
if val is None:
return None
if isinstance(val, bytes):
return {'B': val}
if isinstance(val, bool):
return {'BOOL': val}
elif isinstance(val, list):
val_arr = []
for item in val:
val_arr.append(DynamoDbUtilsService._convert(item))
return {'L': val_arr}
elif isinstance(val, float) or isinstance(val, int) or isinstance(val, Decimal):
return {'N': str(val)}
elif isinstance(val, str):
return {'S': str(val)}
elif isinstance(val, enum.Enum):
return {'S': str(val.value)}
elif isinstance(val, dict):
res = {}
for key, val in val.items():
if val is not None:
res[key] = DynamoDbUtilsService._convert(val)
return {'M': res}
elif isinstance(val, datetime.datetime):
return {'S': str(val)}
else:
raise Exception(f'unknown type {val} at type: {type(val)}')
@staticmethod
def deserialize(rows: dict[str, dict[str, Any]]) -> dict[str, Any]:
boto3.resource('dynamodb')
# To go from low-level format to python
deserializer = boto3.dynamodb.types.TypeDeserializer()
python_data = {k: deserializer.deserialize(v) for k, v in rows.items()}
return python_data