323 lines
9.7 KiB
Python
323 lines
9.7 KiB
Python
import base64
|
|
import json
|
|
import logging
|
|
import os
|
|
|
|
import boto3
|
|
from aws_lambda_powertools import Tracer
|
|
from boto3.dynamodb.conditions import Key
|
|
from botocore.exceptions import ClientError
|
|
|
|
from common.ddb_service.client import DynamoDbUtilsService
|
|
from common.excepts import ForbiddenException, UnauthorizedException, NotFoundException, BadRequestException
|
|
from common.response import unauthorized, forbidden, not_found, bad_request
|
|
from libs.data_types import PARTITION_KEYS, User, Role, Endpoint, Workflow
|
|
|
|
# Powertools tracer backing the @tracer.capture_method decorators below.
tracer = Tracer()

# Module logger; LOG_LEVEL overrides the ERROR default when it is set to a
# non-empty value.
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv('LOG_LEVEL') or logging.ERROR)

# Deployment configuration read from the environment (None when unset).
user_table = os.getenv('MULTI_USER_TABLE')
s3_bucket_name = os.getenv('S3_BUCKET_NAME')
ddb_service = DynamoDbUtilsService(logger=logger)

# Text encoding shared by the base16/base64 helpers in this module.
encode_type = "utf-8"

# AWS handles created once at import time and shared by the helpers below.
s3 = boto3.client('s3')
ddb = boto3.resource('dynamodb')
endpoint_table = ddb.Table(os.getenv('ENDPOINT_TABLE_NAME'))
dynamodb = boto3.client('dynamodb')
|
|
|
|
|
|
@tracer.capture_method
def check_file_exists(key):
    """
    Report whether an object exists in the configured S3 bucket.

    Issues a ``head_object`` request against ``s3_bucket_name``.

    :param key: The object key to look up.
    :return: True when the request succeeds, False when S3 answers with
        error code '404'.
    :raises ClientError: For any other S3 error; it is logged and re-raised.
    """
    try:
        s3.head_object(Bucket=s3_bucket_name, Key=key)
    except ClientError as e:
        # Only ClientError carries a ``response`` payload. Catching it
        # specifically keeps unrelated failures from being masked by an
        # AttributeError raised while reading ``e.response``.
        if e.response['Error']['Code'] == '404':
            return False
        logger.error(e, exc_info=True)
        raise
    return True
|
|
|
|
|
|
@tracer.capture_method
def update_table_by_pk(table: str, pk_name: str, pk_value: str, key: str, value):
    """
    Set a single attribute on an existing DynamoDB item.

    The update is conditional on ``attribute_exists(<pk_name>)``, so it fails
    instead of creating a new item when no item has the given key.

    :param table: Name of the DynamoDB table.
    :param pk_name: Name of the key attribute identifying the item.
    :param pk_value: Value of that key attribute.
    :param key: Name of the attribute to set.
    :param value: New value for the attribute.
    :raises Exception: Any error from ``update_item`` is logged and re-raised.
    """
    # The original message interpolated the builtin ``id`` rather than the key value.
    logger.info(f"Update {table} with {pk_name}: {pk_value}, key: {key}, value: {value}")
    try:
        # Reuse the module-level resource instead of building one per call.
        ddb.Table(table).update_item(
            Key={pk_name: pk_value},
            UpdateExpression="set #k = :v",
            # Placeholders keep reserved words usable as attribute names.
            ExpressionAttributeNames={'#pk_name': pk_name, '#k': key},
            ExpressionAttributeValues={':v': value},
            ConditionExpression="attribute_exists(#pk_name)",
            ReturnValues="UPDATED_NEW",
        )
    except Exception as e:
        logger.error(f"Update {table} error: {e}")
        raise
|
|
|
|
|
|
def get_endpoint_name_by_workflow_name(name: str, endpoint_type: str = 'async'):
    """
    Build the endpoint name for a workflow.

    :param name: The workflow name.
    :param endpoint_type: The endpoint type segment, 'async' by default.
    :return: A string of the form ``comfy-<endpoint_type>-<name>``.
    """
    return "comfy-{}-{}".format(endpoint_type, name)
|
|
|
|
|
|
@tracer.capture_method
def get_workflow_by_name(workflow_name: str):
    """
    Load a workflow record from the table named by the WORKFLOWS_TABLE env var.

    :param workflow_name: Value of the item's ``name`` key.
    :return: A Workflow built from the item's string attributes.
    :raises NotFoundException: When the response contains no ``Item``.
    """
    tracer.put_annotation(key="workflow_name", value=workflow_name)

    result = dynamodb.get_item(
        TableName=os.environ.get('WORKFLOWS_TABLE'),
        Key={'name': {'S': workflow_name}},
    )
    logger.info(result)
    tracer.put_metadata(key="workflow_name", value=result)

    record = result.get('Item')
    if record is None:
        raise NotFoundException(f'workflow {workflow_name} not found')

    # Every attribute is read from the low-level client's {'S': ...} wrapper.
    string_fields = (
        'name',
        'status',
        's3_location',
        'image_uri',
        'payload_json',
        'create_time',
    )
    return Workflow(**{field: record[field]['S'] for field in string_fields})
|
|
|
|
|
|
@tracer.capture_method
def get_endpoint_by_name(endpoint_name: str):
    """
    Look up an endpoint through the ``endpoint_name-startTime-index`` index.

    :param endpoint_name: The endpoint name to query for.
    :return: An Endpoint built from the first item of the query result.
    :raises NotFoundException: When the query returns no items.
    """
    tracer.put_annotation(key="endpoint_name", value=endpoint_name)

    query_args = dict(
        IndexName="endpoint_name-startTime-index",
        KeyConditionExpression=Key('endpoint_name').eq(endpoint_name),
    )
    logger.info(query_args)

    result = endpoint_table.query(**query_args)
    tracer.put_metadata(key="endpoint_name", value=result)

    matches = result.get('Items', [])
    if matches:
        return Endpoint(**matches[0])

    raise NotFoundException(f'endpoint with name {endpoint_name} not found')
|
|
|
|
|
|
def log_json(title, payload: any = None):
    """
    Log a title line and, when a truthy payload is given, the payload as JSON.

    :param title: Text logged first, followed by ': '.
    :param payload: Optional object to serialize; values json cannot encode
        are rendered with ``str``. Falsy payloads are not logged.
    """
    logger.info(f"{title}: ")
    if not payload:
        return
    logger.info(json.dumps(payload, default=str))
|
|
|
|
|
|
class KeyEncryptService:
    """Thin wrapper around the KMS ``encrypt`` and ``decrypt`` client calls."""

    def __init__(self, logging_level=logging.INFO):
        """
        Create the KMS client and configure the logger used by this service.

        :param logging_level: Level applied to the logger.
        """
        self.kms_client = boto3.client('kms')
        # NOTE(review): this is the shared 'boto3' logger, so setting its level
        # affects other users of that logger too. Kept for backward compatibility.
        self.logger = logging.getLogger('boto3')
        self.logger.setLevel(logging_level)

    def encrypt(self, key_id: str, text: str) -> bytes:
        """
        Encrypts text by using the specified key.

        :param key_id: The ARN or ID of the key to use for encryption.
        :param text: The text need to be encrypted
        :return: The encrypted version of the text, or None when KMS reports
                 a ClientError (the error message is logged).
        """
        try:
            cipher_text = self.kms_client.encrypt(
                KeyId=key_id, Plaintext=text.encode())['CiphertextBlob']
        except ClientError as err:
            self.logger.error(
                "Couldn't encrypt text. Here's why: %s", err.response['Error']['Message'])
            # Explicit so the failure result is not an accidental fall-through.
            return None

        self.logger.debug("Your ciphertext is: %s", cipher_text)
        return cipher_text

    def decrypt(self, key_id: str, cipher_text: bytes) -> bytes:
        """
        Decrypts text previously encrypted with a key.

        :param key_id: The ARN or ID of the key used to decrypt the data.
        :param cipher_text: The encrypted text to decrypt.
        :return: The decrypted bytes, or None when KMS reports a ClientError
                 (the error message is logged).
        """
        try:
            text = self.kms_client.decrypt(KeyId=key_id, CiphertextBlob=cipher_text)['Plaintext']
        except ClientError as err:
            self.logger.error("Couldn't decrypt your ciphertext. Here's why: %s",
                              err.response['Error']['Message'])
            return None

        # Log only the size: the plaintext is a secret, and decoding it for a
        # log message raised UnicodeDecodeError on non-UTF-8 payloads.
        self.logger.debug("Decrypted %d bytes of plaintext", len(text))
        return text
|
|
|
|
|
|
@tracer.capture_method
def check_user_existence(ddb_service, user_table, username):
    """
    Query the user table for ``username`` and report whether it is ABSENT.

    Note the polarity: despite the name, the result is True when the query
    returns nothing and False when a matching row is found.

    :param ddb_service: Service object providing ``query_items``.
    :param user_table: Name of the user table.
    :param username: Sort-key value to look up under the user partition.
    :return: True when no matching row is returned, otherwise False.
    """
    lookup = {
        'kind': PARTITION_KEYS.user,
        'sort_key': username,
    }
    matches = ddb_service.query_items(table=user_table, key_values=lookup)

    found = bool(matches) and len(matches) != 0
    return not found
|
|
|
|
|
|
@tracer.capture_method
def get_user_by_username(ddb_service, user_table, username):
    """
    Fetch a user record by username.

    :param ddb_service: Service object providing ``query_items`` and ``deserialize``.
    :param user_table: Name of the user table.
    :param username: Sort-key value to look up under the user partition.
    :return: A User built from the first returned row, or None when the
        query returns nothing.
    """
    lookup = {
        'kind': PARTITION_KEYS.user,
        'sort_key': username,
    }
    rows = ddb_service.query_items(table=user_table, key_values=lookup)

    if rows and len(rows) != 0:
        return User(**(ddb_service.deserialize(rows[0])))

    return None
|
|
|
|
|
|
@tracer.capture_method
def get_user_roles(ddb_service, user_table_name, username):
    """
    Return the ``roles`` attribute of the user stored under ``username``.

    :param ddb_service: Service object providing ``query_items`` and ``deserialize``.
    :param user_table_name: Name of the user table.
    :param username: Sort-key value to look up under the user partition.
    :return: The ``roles`` value of the User built from the first returned row.
    :raises Exception: When the query returns no rows.
    """
    tracer.put_annotation(key="username", value=username)

    lookup = {
        'kind': PARTITION_KEYS.user,
        'sort_key': username,
    }
    rows = ddb_service.query_items(table=user_table_name, key_values=lookup)

    found = bool(rows) and len(rows) != 0
    if not found:
        raise Exception(f'user: "{username}" not exist')

    return User(**ddb_service.deserialize(rows[0])).roles
|
|
|
|
|
|
def response_error(e):
    """
    Log an error and convert it into the matching HTTP error response.

    Mapping, checked in this order:
    UnauthorizedException -> ``unauthorized``,
    ForbiddenException -> ``forbidden``,
    NotFoundException -> ``not_found``,
    anything else -> ``bad_request``.

    :param e: The error to report; ``str(e)`` becomes the response message.
    :return: Whatever the selected response helper returns.
    """
    logger.error(e, exc_info=True)

    # Dispatch on type directly rather than re-raising the exception just to
    # catch it again; the first matching entry wins.
    responders = (
        (UnauthorizedException, unauthorized),
        (ForbiddenException, forbidden),
        (NotFoundException, not_found),
    )
    for exc_type, responder in responders:
        if isinstance(e, exc_type):
            return responder(message=str(e))

    return bad_request(message=str(e))
|
|
|
|
|
|
def get_user_name(event: any):
    """
    Extract the caller's username from the request headers.

    The ``username`` header is used when present. Otherwise the
    ``Authorization`` header is read: its ``'Bearer '`` prefix is removed and
    the remainder is base16-decoded into the username.

    :param event: The request event; must contain a ``headers`` mapping.
    :return: The username.
    :raises BadRequestException: When ``headers`` is missing or None.
    :raises UnauthorizedException: When no username can be determined, or
        the Authorization token is not valid base16 / decodable text.
    """
    # A present-but-None ``headers`` is treated like a missing one instead of
    # failing with a TypeError on the membership tests below.
    if 'headers' not in event or event['headers'] is None:
        raise BadRequestException('Not found headers in event')

    headers = event['headers']
    username = None

    if 'username' in headers:
        username = headers['username']
    elif 'Authorization' in headers:
        # todo compatibility with 1.4.0, will be removed
        authorization = headers['Authorization']
        if authorization:
            token = authorization.replace('Bearer ', '')
            try:
                username = base64.b16decode(token.encode(encode_type)).decode(encode_type)
            except ValueError as err:
                # binascii.Error and the Unicode encode/decode errors are all
                # ValueError subclasses; a malformed token is an auth failure.
                raise UnauthorizedException("Unauthorized") from err

    if not username:
        raise UnauthorizedException("Unauthorized")

    return username
|
|
|
|
|
|
@tracer.capture_method
def permissions_check(event: any, permissions: [str]):
    """
    Verify that the caller holds at least one of the requested permissions.

    The username is taken from the event headers, the user row is loaded from
    the table named by ``user_table``, and the user's roles are scanned for
    any of ``permissions``.

    :param event: The request event passed to ``get_user_name``.
    :param permissions: Permission strings; holding any one of them is enough.
    :return: The caller's username once a role grants a requested permission.
    :raises Exception: When ``user_table`` is not configured.
    :raises UnauthorizedException: When the user query returns no rows.
    :raises ForbiddenException: When no role grants a requested permission.
    """
    username = get_user_name(event)
    tracer.put_annotation(key="username", value=username)

    if not user_table:
        raise Exception("MULTI_USER_TABLE not set")

    user_rows = ddb_service.query_items(table=user_table, key_values={
        'kind': PARTITION_KEYS.user,
        'sort_key': username,
    })
    known_user = bool(user_rows) and len(user_rows) != 0
    if not known_user:
        raise UnauthorizedException("Unauthorized")

    user = User(**ddb_service.deserialize(user_rows[0]))
    logger.info(f'user: {user}')

    role_rows = ddb_service.scan(table=user_table, filters={
        'kind': PARTITION_KEYS.role,
        'sort_key': user.roles,
    })

    for role_row in role_rows:
        role = Role(**(ddb_service.deserialize(role_row)))
        logger.info(f'role: {role}')
        # The first role granting any requested permission settles the check.
        if any(wanted in role.permissions for wanted in permissions):
            return username

    raise ForbiddenException(f"User {username} has no permissions: {permissions}")
|
|
|
|
|
|
def check_user_permissions(checkpoint_owners: [str], user_roles: [str], user_name: str) -> bool:
    """
    Decide whether a user may access a resource given its owner list.

    Access is granted when the owner list is empty or None, when it contains
    the user's name or the wildcard '*', or when it contains any of the
    user's roles.

    :param checkpoint_owners: Owner entries (user names, role names or '*').
    :param user_roles: Roles held by the user.
    :param user_name: The user's name.
    :return: True when access is granted, otherwise False.
    """
    # No owners recorded: nothing restricts access.
    if not checkpoint_owners:
        return True

    if user_name in checkpoint_owners or '*' in checkpoint_owners:
        return True

    return any(role in checkpoint_owners for role in user_roles)
|
|
|
|
|
|
@tracer.capture_method
def get_permissions_by_username(ddb_service, user_table, username):
    """
    Collect the permissions granted to a user through their roles.

    Each permission string is split on ':'; the first part is used as the
    resource and the second as the action.

    :param ddb_service: Service object providing ``query_items``, ``scan``
        and ``deserialize``.
    :param user_table: Name of the table holding users and roles.
    :param username: The user whose permissions are collected.
    :return: A dict mapping each resource to the set of its actions.
    """
    role_names = get_user_roles(ddb_service, user_table, username)
    role_rows = ddb_service.scan(table=user_table, filters={
        'kind': PARTITION_KEYS.role,
        'sort_key': role_names,
    })

    granted = {}
    for role_row in role_rows:
        role = Role(**(ddb_service.deserialize(role_row)))
        for permission in role.permissions:
            parts = permission.split(':')
            # Both parts are read before the dict is touched.
            resource, action = parts[0], parts[1]
            granted.setdefault(resource, set()).add(action)

    return granted
|
|
|
|
|
|
def encode_last_key(last_evaluated_key):
    """
    Turn a pagination key into an opaque text token.

    :param last_evaluated_key: A JSON-serializable key, or a falsy value.
    :return: The base64 text of the key's JSON form, or None for a falsy key.
    """
    if last_evaluated_key:
        serialized = json.dumps(last_evaluated_key)
        token = base64.b64encode(serialized.encode(encode_type))
        return token.decode(encode_type)
    return None
|
|
|
|
|
|
def decode_last_key(last_evaluated_key):
    """
    Turn a base64 pagination token back into the key it was built from.

    :param last_evaluated_key: Base64 text of a JSON document, or a falsy value.
    :return: The parsed JSON value, or None for a falsy token.
    """
    if last_evaluated_key:
        raw = base64.b64decode(last_evaluated_key.encode(encode_type))
        return json.loads(raw.decode(encode_type))
    return None
|