# stable-diffusion-aws-extension/middleware_api/comfy/prepare.py
# (96 lines, 3.5 KiB, Python)

import json
import logging
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from aws_lambda_powertools import Tracer
from common.ddb_service.client import DynamoDbUtilsService
from common.response import ok
from libs.comfy_data_types import ComfySyncTable
from libs.utils import response_error
# AWS Lambda Powertools tracer; used below as the @tracer.capture_lambda_handler decorator on handler.
tracer = Tracer()
logger = logging.getLogger(__name__)
# The LOG_LEVEL env var overrides the level; falls back to ERROR when it is unset or empty.
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)
# Runtime configuration read from the Lambda environment (each is None when the variable is unset).
region = os.environ.get('AWS_REGION')
bucket_name = os.environ.get('S3_BUCKET_NAME')
# NOTE(review): the variable is named inference_monitor_table but reads INSTANCE_MONITOR_TABLE — confirm intended.
inference_monitor_table = os.environ.get('INSTANCE_MONITOR_TABLE')
# Table that prepare_sagemaker_env writes ComfySyncTable records into.
sync_table = os.environ.get('SYNC_TABLE')
# Table that get_endpoint_info scans to look up endpoints by name.
endpoint_table = os.environ.get('ENDPOINT_TABLE')
config_table = os.environ.get('CONFIG_TABLE')
# Shared DynamoDB helper used for every table access in this module.
ddb_service = DynamoDbUtilsService(logger=logger)
@dataclass
class PrepareEnvEvent:
    """Request payload for the "prepare environment" API.

    ``handler`` builds this with ``PrepareEnvEvent(**json.loads(body))``, so every
    key in the JSON body must match a field name. Unknown keys, or a missing
    ``prepare_id``/``endpoint_name``, make construction raise ``TypeError``.
    """
    # Caller-chosen id for the sync job. When None, the Lambda request id is used instead.
    # NOTE(review): this field has no default, so the body must contain "prepare_id" (it may be null).
    prepare_id: Optional[str]
    # SageMaker endpoint to prepare; it must exist in the endpoint table and be InService.
    endpoint_name: str
    # The remaining fields are copied as-is into the ComfySyncTable record.
    need_reboot: bool = False
    prepare_type: Optional[str] = 'default'
    # Notice: must not start with "/".
    s3_source_path: Optional[str] = ''
    local_target_path: Optional[str] = ''
    sync_script: Optional[str] = ''
def get_endpoint_info(endpoint_name: str):
    """Look up a SageMaker endpoint record by name and require it to be InService.

    Args:
        endpoint_name: value matched against the ``endpoint_name`` attribute of
            the endpoint table.

    Returns:
        The deserialized endpoint record for the first scan match.

    Raises:
        Exception: if no record matches, if the matching record deserializes to
            an empty value, or if its ``endpoint_status`` is not ``'InService'``.
    """
    # Call scan() once and reuse the result instead of issuing the same query
    # twice (once for the emptiness check and again for indexing).
    matches = ddb_service.scan(endpoint_table, filters={'endpoint_name': endpoint_name})
    endpoint_raw = matches[0] if matches else None
    logger.debug(f'endpoint_raw is : {endpoint_raw}')
    if not endpoint_raw:
        raise Exception(f'sagemaker endpoint with name {endpoint_name} is not found')

    endpoint_info = ddb_service.deserialize(endpoint_raw)
    logger.debug(f'endpoint_info is : {endpoint_info}')
    if not endpoint_info:
        raise Exception(f'sagemaker endpoint with name {endpoint_name} is not found')

    status = endpoint_info['endpoint_status']
    if status != 'InService':
        raise Exception(f'sagemaker endpoint is not ready with status: {status}')
    return endpoint_info
def prepare_sagemaker_env(request_id: str, event: PrepareEnvEvent):
    """Record a "prepare environment" sync job for an InService SageMaker endpoint.

    Validates the endpoint named in *event*, then writes a ``ComfySyncTable``
    item (including a snapshot of the endpoint record) to the sync table.

    Args:
        request_id: fallback job id, used when ``event.prepare_id`` is None.
        event: the parsed request payload.

    Raises:
        Exception: if ``event.endpoint_name`` is None or no endpoint is found.
            Exceptions from ``get_endpoint_info`` (endpoint missing or not
            InService) and from the DynamoDB write propagate unchanged.
    """
    endpoint_name = event.endpoint_name
    if endpoint_name is None:
        raise Exception('endpoint name should not be null')

    endpoint_info = get_endpoint_info(endpoint_name)
    # Defensive: get_endpoint_info already raises rather than returning an empty record.
    if not endpoint_info:
        raise Exception(f'endpoint not found with name {endpoint_name}')

    # Capture a single timestamp so the numeric and ISO request times describe
    # the same instant (two separate now() calls could straddle a second boundary).
    now = datetime.now()
    sync_job = ComfySyncTable(
        request_id=request_id if event.prepare_id is None else event.prepare_id,
        endpoint_name=event.endpoint_name,
        endpoint_id=endpoint_info['EndpointDeploymentJobId'],
        instance_count=endpoint_info['current_instance_count'],
        prepare_type=event.prepare_type,
        need_reboot=event.need_reboot,
        s3_source_path=event.s3_source_path,
        local_target_path=event.local_target_path,
        sync_script=event.sync_script,
        endpoint_snapshot=endpoint_info,
        request_time=int(now.timestamp()),
        request_time_str=now.isoformat(),
    )
    save_sync_ddb_resp = ddb_service.put_items(sync_table, entries=sync_job.__dict__)
    logger.info(str(save_sync_ddb_resp))
@tracer.capture_lambda_handler
def handler(raw_event, ctx):
    """Lambda entry point for the "prepare environment" API.

    Parses ``raw_event['body']`` as JSON into a ``PrepareEnvEvent`` and records
    the sync job with the Lambda request id as the fallback job id.

    Returns:
        ``ok(data=<endpoint name>)`` on success. Any exception (malformed JSON,
        missing or unknown body keys, endpoint lookup or DynamoDB failures) is
        turned into an error response by ``response_error``.
    """
    try:
        logger.info(f"prepare env start... Received event: {raw_event}")
        logger.info(f"Received ctx: {ctx}")
        aws_request_id = ctx.aws_request_id
        payload = json.loads(raw_event['body'])
        prepare_event = PrepareEnvEvent(**payload)
        prepare_sagemaker_env(aws_request_id, prepare_event)
        return ok(data=prepare_event.endpoint_name)
    except Exception as err:
        return response_error(err)