diff --git a/.gitignore b/.gitignore index 58b47d73..3a1804d6 100644 --- a/.gitignore +++ b/.gitignore @@ -48,5 +48,5 @@ test/aigc_webui_inference_images/.env test/**/.env *.iml .DS_Store -/workshop/ComfyUI/ +/ComfyUI/ /.env diff --git a/Dockerfile.comfy b/Dockerfile.comfy new file mode 100755 index 00000000..b7b56b75 --- /dev/null +++ b/Dockerfile.comfy @@ -0,0 +1,10 @@ +ARG AWS_REGION +FROM 366590864501.dkr.ecr.$AWS_REGION.amazonaws.com/esd-inference:dev + +# TODO BYOC +#RUN apt-get update -y && \ +# apt-get install ffmpeg -y && \ +# rm -rf /var/lib/apt/lists/* \ + +COPY build_scripts/inference/start.sh / +RUN chmod +x /start.sh diff --git a/build_scripts/comfy/comfy_proxy.py b/build_scripts/comfy/comfy_proxy.py new file mode 100755 index 00000000..e4a85361 --- /dev/null +++ b/build_scripts/comfy/comfy_proxy.py @@ -0,0 +1,1486 @@ +import concurrent.futures +import signal +import threading + +import boto3 +import requests +from aiohttp import web + +import folder_paths +import server +from execution import PromptExecutor +import execution +import comfy + +from watchdog.observers import Observer +from watchdog.events import FileSystemEventHandler +import subprocess +from dotenv import load_dotenv + +import fcntl +import hashlib + +import base64 +import datetime +import json +import logging +import os +import sys +import tarfile +import time +import uuid +import gc +from dataclasses import dataclass +from typing import Optional + + +from boto3.dynamodb.conditions import Key + + +DISABLE_AWS_PROXY = 'DISABLE_AWS_PROXY' + +sync_msg_list = [] + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +is_on_sagemaker = os.getenv('ON_SAGEMAKER') == 'true' +is_on_ec2 = os.getenv('ON_EC2') == 'true' + +if is_on_ec2: + env_path = '/etc/environment' + + if 'ENV_FILE_PATH' in os.environ and os.environ.get('ENV_FILE_PATH'): + env_path = os.environ.get('ENV_FILE_PATH') + + load_dotenv('/etc/environment') + logger.info(f"env_path{env_path}") + + env_keys = ['ENV_FILE_PATH', 'COMFY_INPUT_PATH', 'COMFY_MODEL_PATH', 'COMFY_NODE_PATH', 'COMFY_API_URL', + 'COMFY_API_TOKEN', 'COMFY_ENDPOINT', 'COMFY_NEED_SYNC', 'COMFY_NEED_PREPARE', 'COMFY_BUCKET_NAME', + 'MAX_WAIT_TIME', 'MSG_MAX_WAIT_TIME', 'THREAD_MAX_WAIT_TIME', DISABLE_AWS_PROXY, 'DISABLE_AUTO_SYNC'] + + for item in os.environ.keys(): + if item in env_keys: + logger.info(f'evn key: {item} {os.environ.get(item)}') + + DIR3 = "input" + DIR1 = "models" + DIR2 = "custom_nodes" + + if 'COMFY_INPUT_PATH' in os.environ and os.environ.get('COMFY_INPUT_PATH'): + DIR3 = os.environ.get('COMFY_INPUT_PATH') + if 'COMFY_MODEL_PATH' in os.environ and os.environ.get('COMFY_MODEL_PATH'): + DIR1 = os.environ.get('COMFY_MODEL_PATH') + if 'COMFY_NODE_PATH' in os.environ and os.environ.get('COMFY_NODE_PATH'): + DIR2 = os.environ.get('COMFY_NODE_PATH') + + + api_url = os.environ.get('COMFY_API_URL') + api_token = os.environ.get('COMFY_API_TOKEN') + comfy_endpoint = os.environ.get('COMFY_ENDPOINT', 'comfy-real-time-comfy') + comfy_need_sync = os.environ.get('COMFY_NEED_SYNC', True) + comfy_need_prepare = os.environ.get('COMFY_NEED_PREPARE', False) + bucket_name = os.environ.get('COMFY_BUCKET_NAME') + thread_max_wait_time = os.environ.get('THREAD_MAX_WAIT_TIME', 60) + max_wait_time = os.environ.get('MAX_WAIT_TIME', 86400) + msg_max_wait_time = os.environ.get('MSG_MAX_WAIT_TIME', 86400) + is_master_process = os.getenv('MASTER_PROCESS') == 'true' + no_need_sync_files = ['.autosave', '.cache', '.autosave1', '~', '.swp'] + + need_resend_msg_result = [] + PREPARE_ID = 'default' + # additional + PREPARE_MODE = 'additional' + + if not api_url: + raise ValueError("API_URL environment variables must be set.") + + if not api_token: + raise ValueError("API_TOKEN environment variables must be set.") + + if not comfy_endpoint: + raise ValueError("COMFY_ENDPOINT environment variables must be set.") + + headers = {"x-api-key": api_token, "Content-Type": "application/json"} + + + def save_images_locally(response_json, local_folder): + try: + data = response_json.get("data", {}) + prompt_id = data.get("prompt_id") + image_video_data = data.get("image_video_data", {}) + + if not prompt_id or not image_video_data: + logger.info("Missing prompt_id or image_video_data in the response.") + return + + folder_path = os.path.join(local_folder, prompt_id) + os.makedirs(folder_path, exist_ok=True) + + for image_name, image_url in image_video_data.items(): + image_response = requests.get(image_url) + if image_response.status_code == 200: + image_path = os.path.join(folder_path, image_name) + with open(image_path, "wb") as image_file: + image_file.write(image_response.content) + logger.info(f"Image '{image_name}' saved to {image_path}") + else: + logger.info( + f"Failed to download image '{image_name}' from {image_url}. Status code: {image_response.status_code}") + + except Exception as e: + logger.info(f"Error saving images locally: {e}") + + + def calculate_file_hash(file_path): + # 创建一个哈希对象 + hasher = hashlib.sha256() + # 打开文件并逐块更新哈希对象 + with open(file_path, 'rb') as file: + buffer = file.read(65536) # 64KB 的缓冲区大小 + while len(buffer) > 0: + hasher.update(buffer) + buffer = file.read(65536) + # 返回哈希值的十六进制表示 + return hasher.hexdigest() + + + def save_files(prefix, execute, key, target_dir, need_prefix): + if key in execute['data']: + temp_files = execute['data'][key] + for url in temp_files: + loca_file = get_file_name(url) + response = requests.get(url) + # if target_dir not exists, create it + if not os.path.exists(target_dir): + os.makedirs(target_dir) + logger.info(f"Saving file {loca_file} to {target_dir}") + if loca_file.endswith("output_images_will_be_put_here"): + continue + if need_prefix: + with open(f"./{target_dir}/{prefix}_{loca_file}", 'wb') as f: + f.write(response.content) + # current override exist + with open(f"./{target_dir}/{loca_file}", 'wb') as f: + f.write(response.content) + else: + with open(f"./{target_dir}/{loca_file}", 'wb') as f: + f.write(response.content) + + + def get_file_name(url: str): + file_name = url.split('/')[-1] + file_name = file_name.split('?')[0] + return file_name + + + def send_service_msg(server_use, msg): + event = msg.get('event') + data = msg.get('data') + sid = msg.get('sid') if 'sid' in msg else None + server_use.send_sync(event, data, sid) + + + def handle_sync_messages(server_use, msg_array): + already_synced = False + global sync_msg_list + for msg in msg_array: + for item_msg in msg: + event = item_msg.get('event') + data = item_msg.get('data') + sid = item_msg.get('sid') if 'sid' in item_msg else None + if data in sync_msg_list: + continue + sync_msg_list.append(data) + if event == 'finish': + already_synced = True + elif event == 'executed': + global need_resend_msg_result + need_resend_msg_result.append(msg) + server_use.send_sync(event, data, sid) + + return already_synced + + + def execute_proxy(func): + def wrapper(*args, **kwargs): + if 'True' == os.environ.get(DISABLE_AWS_PROXY): + logger.info("disabled aws proxy, use local") + return func(*args, **kwargs) + logger.info(f"enable aws proxy, use aws {comfy_endpoint}") + executor = args[0] + server_use = executor.server + prompt = args[1] + prompt_id = args[2] + extra_data = args[3] + + payload = { + "number": str(server.PromptServer.instance.number), + "prompt": prompt, + "prompt_id": prompt_id, + "extra_data": extra_data, + "endpoint_name": comfy_endpoint, + "need_prepare": comfy_need_prepare, + "need_sync": comfy_need_sync, + "multi_async": False + } + + def send_post_request(url, params): + logger.debug(f"sending post request {url} , params {params}") + get_response = requests.post(url, json=params, headers=headers) + return get_response + + def send_get_request(url): + get_response = requests.get(url, headers=headers) + return get_response + + def check_if_sync_is_already(url): + get_response = send_get_request(url) + prepare_response = get_response.json() + if (prepare_response['statusCode'] == 200 and 'data' in prepare_response and prepare_response['data'] + and prepare_response['data']['prepareSuccess']): + logger.info(f"sync available") + return True + else: + logger.info(f"no sync available for {url} response {prepare_response}") + return False + + def send_error_msg(executor, prompt_id, msg): + mes = { + "prompt_id": prompt_id, + "node_id": "", + "node_type": "on cloud", + "executed": [], + "exception_message": msg, + "exception_type": "", + "traceback": [], + "current_inputs": "", + "current_outputs": "", + } + executor.add_message("execution_error", mes, broadcast=True) + + logger.debug(f"payload is: {payload}") + is_synced = check_if_sync_is_already(f"{api_url}/prepare/{comfy_endpoint}") + if not is_synced: + logger.debug(f"is_synced is {is_synced} stop cloud prompt") + send_error_msg(executor, prompt_id, "Your local environment has not compleated to synchronized on cloud already. Please wait for a moment or click the 'Synchronize' button .") + return + + with concurrent.futures.ThreadPoolExecutor() as executorThread: + execute_future = executorThread.submit(send_post_request, f"{api_url}/executes", payload) + + save_already = False + if comfy_need_sync: + msg_future = executorThread.submit(send_get_request, + f"{api_url}/sync/{prompt_id}") + done, _ = concurrent.futures.wait([execute_future, msg_future], + return_when=concurrent.futures.ALL_COMPLETED) + already_synced = False + global sync_msg_list + sync_msg_list = [] + for future in done: + if future == msg_future: + msg_response = future.result() + logger.info(f"get syc msg: {msg_response.json()}") + if msg_response.status_code == 200: + if 'data' not in msg_response.json() or not msg_response.json().get("data"): + logger.error("there is no response from sync msg by thread ") + time.sleep(1) + else: + logger.debug(msg_response.json()) + already_synced = handle_sync_messages(server_use, msg_response.json().get("data")) + elif future == execute_future: + execute_resp = future.result() + logger.info(f"get execute status: {execute_resp.status_code}") + if execute_resp.status_code == 200 or execute_resp.status_code == 201 or execute_resp.status_code == 202: + i = thread_max_wait_time + while i > 0: + images_response = send_get_request(f"{api_url}/executes/{prompt_id}") + response = images_response.json() + logger.info(f"get execute images: {images_response.status_code}") + if images_response.status_code == 404: + logger.info("no images found already ,waiting sagemaker thread result .....") + time.sleep(3) + i = i - 2 + elif response['data']['status'] == 'failed': + logger.error(f"there is no response on sagemaker from execute thread result !!!!!!!! ") + # send_error_msg(executor, prompt_id, + # f"There may be some errors when valid and execute the prompt on the cloud. Please check the SageMaker logs. error info: {response['data']['message']}") + # no need to send msg anymore + already_synced = True + break + elif response['data']['status'] != 'Completed' and response['data']['status'] != 'success': + logger.info(f"no images found already ,waiting sagemaker thread result, current status is {response['data']['status']}") + time.sleep(2) + i = i - 1 + elif 'data' not in response or not response['data'] or 'status' not in response['data'] or not response['data']['status']: + logger.error(f"there is no response from execute thread result !!!!!!!! {response}") + # no need to send msg anymore + already_synced = True + # send_error_msg(executor, prompt_id,"There may be some errors when executing the prompt on cloud. No images or videos generated.") + break + else: + if ('temp_files' in images_response.json()['data'] and len( + images_response.json()['data']['temp_files']) > 0) or (( + 'output_files' in images_response.json()['data'] and len( + images_response.json()['data']['output_files']) > 0)): + save_files(prompt_id, images_response.json(), 'temp_files', 'temp', False) + save_files(prompt_id, images_response.json(), 'output_files', 'output', True) + else: + send_error_msg(executor, prompt_id, + "There may be some errors when executing the prompt on the cloud. Please check the SageMaker logs.") + # no need to send msg anymore + already_synced = True + logger.debug(images_response.json()) + save_already = True + break + else: + logger.error(f"get execute error: {execute_resp}") + # send_error_msg(executor, prompt_id, "Please valid your prompt and try again.") + # send_error_msg(executor, prompt_id, + # f"There may be some errors when valid and execute the prompt on the cloud. Please check the SageMaker logs. error info: {response['data']['message']}") + # no need to send msg anymore + already_synced = True + break + logger.debug(execute_resp.json()) + + m = msg_max_wait_time + while not already_synced: + msg_response = send_get_request(f"{api_url}/sync/{prompt_id}") + # logger.info(msg_response.json()) + if msg_response.status_code == 200: + if m <= 0: + logger.error("there is no response from sync msg by timeout") + already_synced = True + elif 'data' not in msg_response.json() or not msg_response.json().get("data"): + logger.error("there is no response from sync msg") + time.sleep(1) + m = m - 1 + else: + logger.debug(msg_response.json()) + already_synced = handle_sync_messages(server_use, msg_response.json().get("data")) + logger.info(f"already_synced is :{already_synced}") + time.sleep(1) + m = m - 1 + logger.info(f"check if images are already synced {save_already}") + + if not save_already: + logger.info("check if images are not already synced, please wait") + execute_resp = execute_future.result() + logger.debug(f"execute result :{execute_resp.json()}") + if execute_resp.status_code == 200 or execute_resp.status_code == 201 or execute_resp.status_code == 202: + i = max_wait_time + while i > 0: + images_response = send_get_request(f"{api_url}/executes/{prompt_id}") + response = images_response.json() + logger.debug(response) + if images_response.status_code == 404: + logger.info(f"{i} no images found already ,waiting sagemaker result .....") + i = i - 2 + time.sleep(3) + elif response['data']['status'] == 'failed': + logger.error( + f"there is no response on sagemaker from execute result !!!!!!!! ") + if 'message' in response['data'] and response['data']['message']: + send_error_msg(executor, prompt_id, + f"There may be some errors when valid or execute the prompt on the cloud. Please check the SageMaker logs. errors: {response['data']['message']}") + break + else: + logger.error(f"valid error on sagemaker :{response['data']}") + send_error_msg(executor, prompt_id, + f"There may be some errors when valid or execute the prompt on the cloud. errors") + break + elif response['data']['status'] != 'Completed' and response['data']['status'] != 'success': + logger.info(f"{i} images not already ,waiting sagemaker result .....{response['data']['status'] }") + i = i - 1 + time.sleep(3) + elif 'data' not in response or not response['data'] or 'status' not in response['data'] or not response['data']['status']: + logger.info(f"{i} there is no response from sync executes {response}") + send_error_msg(executor, prompt_id, f"There may be some errors when executing the prompt on the cloud. No images or videos generated. {response['message']}") + break + elif response['data']['status'] == 'Completed' or response['data']['status'] == 'success': + if ('temp_files' in images_response.json()['data'] and len(images_response.json()['data']['temp_files']) > 0) or (('output_files' in images_response.json()['data'] and len(images_response.json()['data']['output_files']) > 0)): + save_files(prompt_id, images_response.json(), 'temp_files', 'temp', False) + save_files(prompt_id, images_response.json(), 'output_files', 'output', True) + break + else: + send_error_msg(executor, prompt_id, + "There may be some errors when executing the prompt on the cloud. Please check the SageMaker logs.") + break + else: + # logger.info( + # f"{i} images not already other,waiting sagemaker result .....{response}") + # i = i - 1 + # time.sleep(3) + send_error_msg(executor, prompt_id, + "You have some errors when execute prompt on cloud . Please check your sagemaker logs.") + break + else: + logger.error(f"get execute error: {execute_resp}") + send_error_msg(executor, prompt_id, "Please valid your prompt and try again.") + logger.info("execute finished") + executorThread.shutdown() + + return wrapper + + + PromptExecutor.execute = execute_proxy(PromptExecutor.execute) + + + def send_sync_proxy(func): + def wrapper(*args, **kwargs): + logger.info(f"Sending sync request----- {args}") + return func(*args, **kwargs) + return wrapper + + + server.PromptServer.send_sync = send_sync_proxy(server.PromptServer.send_sync) + + + def compress_and_upload(folder_path, prepare_version): + for subdir in next(os.walk(folder_path))[1]: + subdir_path = os.path.join(folder_path, subdir) + tar_filename = f"{subdir}.tar.gz" + logger.info(f"Compressing the {tar_filename}") + + # 创建 tar 压缩文件 + with tarfile.open(tar_filename, "w:gz") as tar: + tar.add(subdir_path, arcname=os.path.basename(subdir_path)) + s5cmd_syn_node_command = f's5cmd --log=error cp {tar_filename} "s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/custom_nodes/"' + logger.info(s5cmd_syn_node_command) + os.system(s5cmd_syn_node_command) + logger.info(f"rm {tar_filename}") + os.remove(tar_filename) + + # for root, dirs, files in os.walk(folder_path): + # for directory in dirs: + # dir_path = os.path.join(root, directory) + # logger.info(f"Compressing the {dir_path}") + # tar_filename = f"{directory}.tar.gz" + # tar_filepath = os.path.join(root, tar_filename) + # with tarfile.open(tar_filepath, "w:gz") as tar: + # tar.add(dir_path, arcname=os.path.basename(dir_path)) + # s5cmd_syn_node_command = f's5cmd --log=error cp {tar_filepath} "s3://{bucket_name}/comfy/{comfy_endpoint}/{timestamp}/custom_nodes/"' + # logger.info(s5cmd_syn_node_command) + # os.system(s5cmd_syn_node_command) + # logger.info(f"rm {tar_filepath}") + # os.remove(tar_filepath) + + + def sync_default_files(): + try: + timestamp = str(int(time.time() * 1000)) + prepare_version = PREPARE_ID if PREPARE_MODE == 'additional' else timestamp + need_prepare = True + prepare_type = 'default' + need_reboot = True + logger.info(f" sync custom nodes files") + # s5cmd_syn_node_command = f's5cmd --log=error sync --delete=true --exclude="*comfy_local_proxy.py" {DIR2}/ "s3://{bucket_name}/comfy/{comfy_endpoint}/{timestamp}/custom_nodes/"' + # logger.info(f"sync custom_nodes files start {s5cmd_syn_node_command}") + # os.system(s5cmd_syn_node_command) + compress_and_upload(f"{DIR2}", prepare_version) + logger.info(f" sync input files") + s5cmd_syn_input_command = f's5cmd --log=error sync --delete=true {DIR3}/ "s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/input/"' + logger.info(f"sync input files start {s5cmd_syn_input_command}") + os.system(s5cmd_syn_input_command) + logger.info(f" sync models files") + s5cmd_syn_model_command = f's5cmd --log=error sync --delete=true {DIR1}/ "s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/models/"' + logger.info(f"sync models files start {s5cmd_syn_model_command}") + os.system(s5cmd_syn_model_command) + logger.info(f"Files changed in:: {need_prepare} {DIR2} {DIR1} {DIR3}") + url = api_url + "prepare" + logger.info(f"URL:{url}") + data = {"endpoint_name": comfy_endpoint, "need_reboot": need_reboot, "prepare_id": prepare_version, + "prepare_type": prepare_type} + logger.info(f"prepare params Data: {json.dumps(data, indent=4)}") + result = subprocess.run(["curl", "--location", "--request", "POST", url, "--header", + f"x-api-key: {api_token}", "--data-raw", json.dumps(data)], + capture_output=True, text=True) + logger.info(result.stdout) + return result.stdout + except Exception as e: + logger.info(f"sync_files error {e}") + return None + + + def sync_files(filepath, is_folder, is_auto): + try: + directory = os.path.dirname(filepath) + logger.info(f"Directory changed in: {directory} {filepath}") + if not directory: + logger.info("root path no need to sync files by duplicate opt") + return None + logger.info(f"Files changed in: {filepath}") + timestamp = str(int(time.time() * 1000)) + need_prepare = False + prepare_type = 'default' + need_reboot = False + for ignore_item in no_need_sync_files: + if filepath.endswith(ignore_item): + logger.info(f"no need to sync files by ignore files {filepath} ends by {ignore_item}") + return None + prepare_version = PREPARE_ID if PREPARE_MODE == 'additional' else timestamp + if (str(directory).endswith(f"{DIR2}" if DIR2.startswith("/") else f"/{DIR2}") + or str(filepath) == DIR2 or str(filepath) == f'./{DIR2}' or f"{DIR2}/" in filepath): + logger.info(f" sync custom nodes files: {filepath}") + s5cmd_syn_node_command = f's5cmd --log=error sync --delete=true --exclude="*comfy_local_proxy.py" {DIR2}/ "s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/custom_nodes/"' + # s5cmd_syn_node_command = f'aws s3 sync {DIR2}/ "s3://{bucket_name}/comfy/{comfy_endpoint}/{timestamp}/custom_nodes/"' + # s5cmd_syn_node_command = f's5cmd sync {DIR2}/* "s3://{bucket_name}/comfy/{comfy_endpoint}/{timestamp}/custom_nodes/"' + + # custom_node文件夹有变化 稍后再同步 + if is_auto and not is_folder_unlocked(directory): + logger.info("sync custom_nodes files is changing ,waiting.... ") + return None + logger.info("sync custom_nodes files start") + logger.info(s5cmd_syn_node_command) + os.system(s5cmd_syn_node_command) + need_prepare = True + need_reboot = True + prepare_type = 'nodes' + elif (str(directory).endswith(f"{DIR3}" if DIR3.startswith("/") else f"/{DIR3}") + or str(filepath) == DIR3 or str(filepath) == f'./{DIR3}' or f"{DIR3}/" in filepath): + logger.info(f" sync input files: {filepath}") + s5cmd_syn_input_command = f's5cmd --log=error sync --delete=true {DIR3}/ "s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/input/"' + + # 判断文件写完后再同步 + if is_auto: + if bool(is_folder): + can_sync = is_folder_unlocked(filepath) + else: + can_sync = is_file_unlocked(filepath) + if not can_sync: + logger.info("sync input files is changing ,waiting.... ") + return None + logger.info("sync input files start") + logger.info(s5cmd_syn_input_command) + os.system(s5cmd_syn_input_command) + need_prepare = True + prepare_type = 'inputs' + elif (str(directory).endswith(f"{DIR1}" if DIR1.startswith("/") else f"/{DIR1}") + or str(filepath) == DIR1 or str(filepath) == f'./{DIR1}' or f"{DIR1}/" in filepath): + logger.info(f" sync models files: {filepath}") + s5cmd_syn_model_command = f's5cmd --log=error sync --delete=true {DIR1}/ "s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/models/"' + + # 判断文件写完后再同步 + if is_auto: + if bool(is_folder): + can_sync = is_folder_unlocked(filepath) + else: + can_sync = is_file_unlocked(filepath) + # logger.info(f'is folder {directory} {is_folder} can_sync {can_sync}') + if not can_sync: + logger.info("sync input models is changing ,waiting.... ") + return None + + logger.info("sync models files start") + logger.info(s5cmd_syn_model_command) + os.system(s5cmd_syn_model_command) + need_prepare = True + prepare_type = 'models' + logger.info(f"Files changed in:: {need_prepare} {str(directory)} {DIR2} {DIR1} {DIR3}") + if need_prepare: + url = api_url + "prepare" + logger.info(f"URL:{url}") + data = {"endpoint_name": comfy_endpoint, "need_reboot": need_reboot, "prepare_id": prepare_version, + "prepare_type": prepare_type} + logger.info(f"prepare params Data: {json.dumps(data, indent=4)}") + result = subprocess.run(["curl", "--location", "--request", "POST", url, "--header", + f"x-api-key: {api_token}", "--data-raw", json.dumps(data)], + capture_output=True, text=True) + logger.info(result.stdout) + return result.stdout + return None + except Exception as e: + logger.info(f"sync_files error {e}") + return None + + + def is_folder_unlocked(directory): + # logger.info("check if folder ") + event_handler = MyHandlerWithCheck() + observer = Observer() + observer.schedule(event_handler, directory, recursive=True) + observer.start() + time.sleep(1) + result = False + try: + if event_handler.file_changed: + logger.info(f"folder {directory} is still changing..") + event_handler.file_changed = False + time.sleep(1) + if event_handler.file_changed: + logger.info(f"folder {directory} is still still changing..") + else: + logger.info(f"folder {directory} changing stopped") + result = True + else: + logger.info(f"folder {directory} not stopped") + result = True + except (KeyboardInterrupt, Exception) as e: + logger.info(f"folder {directory} changed exception {e}") + observer.stop() + return result + + + def is_file_unlocked(file_path): + # logger.info("check if file ") + try: + initial_size = os.path.getsize(file_path) + initial_mtime = os.path.getmtime(file_path) + time.sleep(1) + + current_size = os.path.getsize(file_path) + current_mtime = os.path.getmtime(file_path) + if current_size != initial_size or current_mtime != initial_mtime: + logger.info(f"unlock file error {file_path} is changing") + return False + + with open(file_path, 'r') as f: + fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + return True + except (IOError, OSError, Exception) as e: + logger.info(f"unlock file error {file_path} is writing") + logger.error(e) + return False + + + class MyHandlerWithCheck(FileSystemEventHandler): + def __init__(self): + self.file_changed = False + + def on_modified(self, event): + logger.info(f"custom_node folder is changing {event.src_path}") + self.file_changed = True + + def on_deleted(self, event): + logger.info(f"custom_node folder is changing {event.src_path}") + self.file_changed = True + + def on_created(self, event): + logger.info(f"custom_node folder is changing {event.src_path}") + self.file_changed = True + + + class MyHandlerWithSync(FileSystemEventHandler): + def on_modified(self, event): + logger.info(f"{datetime.datetime.now()} files modified ,start to sync {event}") + sync_files(event.src_path, event.is_directory, True) + + def on_created(self, event): + logger.info(f"{datetime.datetime.now()} files added ,start to sync {event}") + sync_files(event.src_path, event.is_directory, True) + + def on_deleted(self, event): + logger.info(f"{datetime.datetime.now()} files deleted ,start to sync {event}") + sync_files(event.src_path, event.is_directory, True) + + + stop_event = threading.Event() + + + def check_and_sync(): + logger.info("check_and_sync start") + event_handler = MyHandlerWithSync() + observer = Observer() + try: + observer.schedule(event_handler, DIR1, recursive=True) + observer.schedule(event_handler, DIR2, recursive=True) + observer.schedule(event_handler, DIR3, recursive=True) + observer.start() + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("sync Shutting down please restart ComfyUI") + observer.stop() + observer.join() + + + def signal_handler(sig, frame): + logger.info("Received termination signal. Exiting...") + stop_event.set() + + + if os.environ.get('DISABLE_AUTO_SYNC') == 'false': + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + check_sync_thread = threading.Thread(target=check_and_sync) + check_sync_thread.start() + + + @server.PromptServer.instance.routes.get("/reboot") + async def restart(self): + logger.info(f"start to reboot {self}") + try: + subprocess.run(["sudo", "reboot"]) + except Exception as e: + logger.info(f"error reboot {e}") + pass + return os.execv(sys.executable, [sys.executable] + sys.argv) + + + @server.PromptServer.instance.routes.get("/check_prepare") + async def check_prepare(self): + logger.info(f"start to check_prepare {self}") + try: + get_response = requests.get(f"{api_url}/prepare/{comfy_endpoint}", headers=headers) + response = get_response.json() + logger.info(f"check sync response is {response}") + if get_response.status_code == 200 and response['data']['prepareSuccess']: + return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True})) + else: + logger.info(f"check sync response is {response} {response['data']['prepareSuccess']}") + return web.Response(status=500, content_type='application/json', body=json.dumps({"result": False})) + except Exception as e: + logger.info(f"error restart {e}") + pass + return os.execv(sys.executable, [sys.executable] + sys.argv) + + + @server.PromptServer.instance.routes.get("/gc") + async def gc(self): + logger.info(f"start to gc {self}") + try: + logger.info(f"gc start: {time.time()}") + server_instance = server.PromptServer.instance + e = execution.PromptExecutor(server_instance) + e.reset() + comfy.model_management.cleanup_models() + gc.collect() + comfy.model_management.soft_empty_cache() + gc_triggered = True + logger.info(f"gc end: {time.time()}") + except Exception as e: + logger.info(f"error restart {e}") + pass + return os.execv(sys.executable, [sys.executable] + sys.argv) + + + @server.PromptServer.instance.routes.get("/restart") + async def restart(self): + logger.info(f"start to restart {self}") + try: + sys.stdout.close_log() + except Exception as e: + logger.info(f"error restart {e}") + pass + return os.execv(sys.executable, [sys.executable] + sys.argv) + + + @server.PromptServer.instance.routes.get("/sync_env") + async def sync_env(request): + logger.info(f"start to sync_env {request}") + try: + result = sync_default_files() + logger.debug(f"sync result is :{result}") + return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True})) + except Exception as e: + logger.info(f"error sync_env {e}") + pass + return web.Response(status=500, content_type='application/json', body=json.dumps({"result": False})) + + + @server.PromptServer.instance.routes.post("/change_env") + async def change_env(request): + logger.info(f"start to change_env {request}") + json_data = await request.json() + if DISABLE_AWS_PROXY in json_data and json_data[DISABLE_AWS_PROXY] is not None: + logger.info(f"origin evn key DISABLE_AWS_PROXY is :{os.environ.get(DISABLE_AWS_PROXY)} {str(json_data[DISABLE_AWS_PROXY])}") + os.environ[DISABLE_AWS_PROXY] = str(json_data[DISABLE_AWS_PROXY]) + logger.info(f"now evn key DISABLE_AWS_PROXY is :{os.environ.get(DISABLE_AWS_PROXY)}") + return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True})) + + + @server.PromptServer.instance.routes.get("/get_env") + async def get_env(request): + env = os.environ.get(DISABLE_AWS_PROXY, 'False') + return web.Response(status=200, content_type='application/json', body=json.dumps({"env": env})) + + + @server.PromptServer.instance.routes.post("/workflows") + async def release_workflow(request): + if not is_master_process: + return web.Response(status=200, content_type='application/json', + body=json.dumps({"result": False, "message": "only master can release workflow"})) + + logger.info(f"start to release workflow {request}") + try: + json_data = await request.json() + if 'name' not in json_data or not json_data['name']: + raise ValueError("name is required") + + workflow_name = json_data['name'] + payload_json = '' + + if 'payload_json' in json_data: + payload_json = json_data['payload_json'] + + if check_file_exists(f"comfy/workflows/{workflow_name}/lock"): + return web.Response(status=200, content_type='application/json', + body=json.dumps({"result": False, "message": "workflow already exists"})) + + start_time = time.time() + + s5cmd_syn_model_command = (f's5cmd sync ' + f'--delete=true ' + f'--exclude="*.log" ' + f'--exclude="*__pycache__*" ' + f'--exclude="*.cache*" ' + f'"/home/ubuntu/ComfyUI/*" ' + f'"s3://{bucket_name}/comfy/workflows/{workflow_name}/"') + logger.info(f"sync models files start {s5cmd_syn_model_command}") + os.system(s5cmd_syn_model_command) + + end_time = time.time() + cost_time = end_time - start_time + data = { + "payload_json": payload_json, + "image_uri": os.getenv('IMAGE_HASH'), + "name": workflow_name, + } + get_response = requests.post(f"{api_url}/workflows", headers=headers, data=json.dumps(data)) + response = get_response.json() + logger.info(f"release workflow response is {response}") + + if get_response.status_code == 200: + os.system(f'echo "lock" > lock && s5cmd sync lock s3://{bucket_name}/comfy/workflows/{workflow_name}/lock') + + return web.Response(status=200, content_type='application/json', + body=json.dumps({"result": True, "message": "success", "cost_time": cost_time})) + except Exception as e: + return web.Response(status=500, content_type='application/json', + body=json.dumps({"result": False, "message": e})) + + + def check_file_exists(key): + try: + s3 = boto3.client('s3') + s3.head_object(Bucket=bucket_name, Key=key) + return True + except Exception as e: + logger.error(e, exc_info=True) + if e.response['Error']['Code'] == '404': + return False + else: + raise e + + + def restore_commands(): + subprocess.run(["sleep", "5"]) + os.system("rm -rf /home/ubuntu/ComfyUI") + subprocess.run(["pkill", "-f", "python3"]) + + + # RestoreEC2EnvironmentToDefault + @server.PromptServer.instance.routes.post("/restore") + async def release_rebuild_workflow(request): + if not is_master_process: + return web.Response(status=200, content_type='application/json', + body=json.dumps({"result": False, "message": "only master can restore comfy"})) + + logger.info(f"start to restore EC2 {request}") + + try: + thread = threading.Thread(target=restore_commands) + thread.start() + return web.Response(status=200, content_type='application/json', + body=json.dumps({"result": True, "message": "comfy will be restored in 5 seconds"})) + except Exception as e: + return web.Response(status=500, content_type='application/json', + body=json.dumps({"result": False, "message": e})) + + +if is_on_sagemaker: + + global need_sync + global prompt_id + global executing + executing = False + + global reboot + reboot = False + + global last_call_time + last_call_time = None + global gc_triggered + gc_triggered = False + + REGION = os.environ.get('AWS_REGION') + BUCKET = os.environ.get('S3_BUCKET_NAME') + QUEUE_URL = os.environ.get('COMFY_QUEUE_URL') + + GEN_INSTANCE_ID = os.environ.get('ENDPOINT_INSTANCE_ID') if 'ENDPOINT_INSTANCE_ID' in os.environ and os.environ.get('ENDPOINT_INSTANCE_ID') else str(uuid.uuid4()) + ENDPOINT_NAME = os.environ.get('ENDPOINT_NAME') + ENDPOINT_ID = os.environ.get('ENDPOINT_ID') + + INSTANCE_MONITOR_TABLE_NAME = os.environ.get('COMFY_INSTANCE_MONITOR_TABLE') + SYNC_TABLE_NAME = os.environ.get('COMFY_SYNC_TABLE') + + dynamodb = boto3.resource('dynamodb', region_name=REGION) + sync_table = dynamodb.Table(SYNC_TABLE_NAME) + instance_monitor_table = dynamodb.Table(INSTANCE_MONITOR_TABLE_NAME) + + logger = logging.getLogger(__name__) + logger.setLevel(os.environ.get('LOG_LEVEL') or logging.INFO) + + ROOT_PATH = '/home/ubuntu/ComfyUI' + sqs_client = boto3.client('sqs', region_name=REGION) + + GC_WAIT_TIME = 1800 + + + def print_env(): + for key, value in os.environ.items(): + logger.info(f"{key}: {value}") + + + @dataclass + class ComfyResponse: + statusCode: int + message: str + body: Optional[dict] + + + def ok(body: dict): + return web.Response(status=200, content_type='application/json', body=json.dumps(body)) + + + def error(body: dict): + # TODO 500 -》200 because of need resp anyway not exception + return web.Response(status=200, content_type='application/json', body=json.dumps(body)) + + + def sen_sqs_msg(message_body, prompt_id_key): + response = sqs_client.send_message( + QueueUrl=QUEUE_URL, + MessageBody=json.dumps(message_body), + MessageGroupId=prompt_id_key + ) + message_id = response['MessageId'] + return message_id + + + def sen_finish_sqs_msg(prompt_id_key): + global need_sync + # logger.info(f"sen_finish_sqs_msg start... {need_sync},{prompt_id_key}") + if need_sync and QUEUE_URL and REGION: + message_body = {'prompt_id': prompt_id_key, 'event': 'finish', 'data': {"node": None, "prompt_id": prompt_id_key}, + 'sid': None} + message_id = sen_sqs_msg(message_body, prompt_id_key) + logger.info(f"finish message sent {message_id}") + + + async def prepare_comfy_env(sync_item: dict): + try: + request_id = sync_item['request_id'] + logger.info(f"prepare_environment start sync_item:{sync_item}") + prepare_type = sync_item['prepare_type'] + rlt = True + if prepare_type in ['default', 'models']: + sync_models_rlt = sync_s3_files_or_folders_to_local(f'{request_id}/models/*', f'{ROOT_PATH}/models', False) + if not sync_models_rlt: + rlt = False + if prepare_type in ['default', 'inputs']: + sync_inputs_rlt = sync_s3_files_or_folders_to_local(f'{request_id}/input/*', f'{ROOT_PATH}/input', False) + if not sync_inputs_rlt: + rlt = False + if prepare_type in ['default', 'nodes']: + sync_nodes_rlt = sync_s3_files_or_folders_to_local(f'{request_id}/custom_nodes/*', + f'{ROOT_PATH}/custom_nodes', True) + if not sync_nodes_rlt: + rlt = False + if prepare_type == 'custom': + sync_source_path = sync_item['s3_source_path'] + local_target_path = sync_item['local_target_path'] + if not sync_source_path or not local_target_path: + logger.info("s3_source_path and local_target_path should not be empty") + else: + sync_rlt = sync_s3_files_or_folders_to_local(sync_source_path, + f'{ROOT_PATH}/{local_target_path}', False) + if not sync_rlt: + rlt = False + elif prepare_type == 'other': + sync_script = sync_item['sync_script'] + logger.info("sync_script") + # sync_script.startswith('s5cmd') 不允许 + try: + if sync_script and (sync_script.startswith("python3 -m pip") or sync_script.startswith("python -m pip") + or sync_script.startswith("pip install") or sync_script.startswith("apt") + or sync_script.startswith("os.environ") or sync_script.startswith("ls") + or sync_script.startswith("env") or sync_script.startswith("source") + or sync_script.startswith("curl") or sync_script.startswith("wget") + or sync_script.startswith("print") or sync_script.startswith("cat") + or sync_script.startswith("sudo chmod") or sync_script.startswith("chmod") + or sync_script.startswith("/home/ubuntu/ComfyUI/venv/bin/python")): + os.system(sync_script) + elif sync_script and (sync_script.startswith("export ") and len(sync_script.split(" ")) > 2): + sync_script_key = sync_script.split(" ")[1] + sync_script_value = sync_script.split(" ")[2] + os.environ[sync_script_key] = sync_script_value + logger.info(os.environ.get(sync_script_key)) + except Exception as e: + logger.error(f"Exception while execute sync_scripts : {sync_script}") + rlt = False + need_reboot = True if ('need_reboot' in sync_item and sync_item['need_reboot'] + and str(sync_item['need_reboot']).lower() == 'true')else False + global reboot + reboot = need_reboot + if need_reboot: + os.environ['NEED_REBOOT'] = 'true' + else: + os.environ['NEED_REBOOT'] = 'false' + logger.info("prepare_environment end") + os.environ['LAST_SYNC_REQUEST_ID'] = sync_item['request_id'] + os.environ['LAST_SYNC_REQUEST_TIME'] = str(sync_item['request_time']) + return rlt + except Exception as e: + return False + + + def sync_s3_files_or_folders_to_local(s3_path, local_path, need_un_tar): + logger.info("sync_s3_models_or_inputs_to_local start") + # s5cmd_command = f'{ROOT_PATH}/tools/s5cmd sync "s3://{bucket_name}/{s3_path}/*" "{local_path}/"' + if need_un_tar: + s5cmd_command = f's5cmd sync "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" "{local_path}/"' + else: + s5cmd_command = f's5cmd sync --delete=true "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" "{local_path}/"' + # s5cmd_command = f's5cmd sync --delete=true "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" "{local_path}/"' + # s5cmd_command = f's5cmd sync "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" "{local_path}/"' + try: + logger.info(s5cmd_command) + os.system(s5cmd_command) + logger.info(f'Files copied from "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" to "{local_path}/"') + if need_un_tar: + for filename in os.listdir(local_path): + if filename.endswith(".tar.gz"): + tar_filepath = os.path.join(local_path, filename) + # extract_path = os.path.splitext(os.path.splitext(tar_filepath)[0])[0] + # os.makedirs(extract_path, exist_ok=True) + # logger.info(f'Extracting extract_path is {extract_path}') + + with tarfile.open(tar_filepath, "r:gz") as tar: + for member in tar.getmembers(): + tar.extract(member, path=local_path) + os.remove(tar_filepath) + logger.info(f'File {tar_filepath} extracted and removed') + return True + except Exception as e: + logger.info(f"Error executing s5cmd command: {e}") + return False + + + def sync_local_outputs_to_s3(s3_path, local_path): + logger.info("sync_local_outputs_to_s3 start") + s5cmd_command = f's5cmd sync "{local_path}/*" "s3://{BUCKET}/comfy/{s3_path}/" ' + try: + logger.info(s5cmd_command) + os.system(s5cmd_command) + logger.info(f'Files copied local to "s3://{BUCKET}/comfy/{s3_path}/" to "{local_path}/"') + clean_cmd = f'rm -rf {local_path}' + os.system(clean_cmd) + logger.info(f'Files removed from local {local_path}') + except Exception as e: + logger.info(f"Error executing s5cmd command: {e}") + + + def sync_local_outputs_to_base64(local_path): + logger.info("sync_local_outputs_to_base64 start") + try: + result = {} + for root, dirs, files in os.walk(local_path): + for file in files: + file_path = os.path.join(root, file) + with open(file_path, "rb") as f: + file_content = f.read() + base64_content = base64.b64encode(file_content).decode('utf-8') + result[file] = base64_content + clean_cmd = f'rm -rf {local_path}' + os.system(clean_cmd) + logger.info(f'Files removed from local {local_path}') + return result + except Exception as e: + logger.info(f"Error executing s5cmd command: {e}") + return {} + + + @server.PromptServer.instance.routes.post("/execute_proxy") + async def execute_proxy(request): + json_data = await request.json() + if 'out_path' in json_data and json_data['out_path'] is not None: + out_path = json_data['out_path'] + else: + out_path = None + logger.info(f"invocations start json_data:{json_data}") + global need_sync + need_sync = json_data["need_sync"] + global prompt_id + prompt_id = json_data["prompt_id"] + try: + global executing + if executing is True: + resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail", + "message": "the environment is not ready valid[0] is false, need to resync"} + sen_finish_sqs_msg(prompt_id) + return error(resp) + executing = True + logger.info( + f'bucket_name: {BUCKET}, region: {REGION}') + if ('need_prepare' in json_data and json_data['need_prepare'] + and 'prepare_props' in json_data and json_data['prepare_props']): + sync_already = await prepare_comfy_env(json_data['prepare_props']) + if not sync_already: + resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail", + "message": "the environment is not ready with sync"} + executing = False + sen_finish_sqs_msg(prompt_id) + return error(resp) + server_instance = server.PromptServer.instance + if "number" in json_data: + number = float(json_data['number']) + server_instance.number = number + else: + number = server_instance.number + if "front" in json_data: + if json_data['front']: + number = -number + server_instance.number += 1 + valid = execution.validate_prompt(json_data['prompt']) + logger.info(f"Validating prompt result is {valid}") + if not valid[0]: + resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail", + "message": "the environment is not ready valid[0] is false, need to resync"} + executing = False + response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} + sen_finish_sqs_msg(prompt_id) + return error(resp) + # if len(valid) == 4 and len(valid[3]) > 0: + # logger.info(f"Validating prompt error there is something error because of :valid: {valid}") + # resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail", + # "message": f"the valid is error, need to resync or check the workflow :{valid}"} + # executing = False + # return error(resp) + extra_data = {} + client_id = '' + if "extra_data" in json_data: + extra_data = json_data["extra_data"] + if 'client_id' in extra_data and extra_data['client_id']: + client_id = extra_data['client_id'] + if "client_id" in json_data and json_data["client_id"]: + extra_data["client_id"] = json_data["client_id"] + client_id = json_data["client_id"] + + server_instance.client_id = client_id + + prompt_id = json_data['prompt_id'] + server_instance.last_prompt_id = prompt_id + e = execution.PromptExecutor(server_instance) + outputs_to_execute = valid[2] + e.execute(json_data['prompt'], prompt_id, extra_data, outputs_to_execute) + + s3_out_path = f'output/{prompt_id}/{out_path}' if out_path is not None else f'output/{prompt_id}' + s3_temp_path = f'temp/{prompt_id}/{out_path}' if out_path is not None else f'temp/{prompt_id}' + local_out_path = f'{ROOT_PATH}/output/{out_path}' if out_path is not None else f'{ROOT_PATH}/output' + local_temp_path = f'{ROOT_PATH}/temp/{out_path}' if out_path is not None else f'{ROOT_PATH}/temp' + + logger.info(f"s3_out_path is {s3_out_path} and s3_temp_path is {s3_temp_path} and local_out_path is {local_out_path} and local_temp_path is {local_temp_path}") + + sync_local_outputs_to_s3(s3_out_path, local_out_path) + sync_local_outputs_to_s3(s3_temp_path, local_temp_path) + + response_body = { + "prompt_id": prompt_id, + "instance_id": GEN_INSTANCE_ID, + "status": "success", + "output_path": f's3://{BUCKET}/comfy/{s3_out_path}', + "temp_path": f's3://{BUCKET}/comfy/{s3_temp_path}', + } + sen_finish_sqs_msg(prompt_id) + logger.info(f"execute inference response is {response_body}") + executing = False + return ok(response_body) + except Exception as ecp: + logger.info(f"exception occurred {ecp}") + resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail", + "message": f"exception occurred {ecp}"} + executing = False + return error(resp) + finally: + logger.info(f"gc check: {time.time()}") + try: + global last_call_time, gc_triggered + gc_triggered = False + if last_call_time is None: + logger.info(f"gc check last time is NONE") + last_call_time = time.time() + else: + if time.time() - last_call_time > GC_WAIT_TIME: + if not gc_triggered: + logger.info(f"gc start: {time.time()} - {last_call_time}") + e.reset() + comfy.model_management.cleanup_models() + gc.collect() + comfy.model_management.soft_empty_cache() + gc_triggered = True + logger.info(f"gc end: {time.time()} - {last_call_time}") + last_call_time = time.time() + else: + last_call_time = time.time() + logger.info(f"gc check end: {time.time()}") + except Exception as e: + logger.info(f"gc error: {e}") + + + def get_last_ddb_sync_record(): + sync_response = sync_table.query( + KeyConditionExpression=Key('endpoint_name').eq(ENDPOINT_NAME), + Limit=1, + ScanIndexForward=False + ) + latest_sync_record = sync_response['Items'][0] if ('Items' in sync_response + and len(sync_response['Items']) > 0) else None + if latest_sync_record: + logger.info(f"latest_sync_record is:{latest_sync_record}") + return latest_sync_record + + logger.info("no latest_sync_record found") + return None + + + def get_latest_ddb_instance_monitor_record(): + key_condition_expression = ('endpoint_name = :endpoint_name_val ' + 'AND gen_instance_id = :gen_instance_id_val') + expression_attribute_values = { + ':endpoint_name_val': ENDPOINT_NAME, + ':gen_instance_id_val': GEN_INSTANCE_ID + } + instance_monitor_response = instance_monitor_table.query( + KeyConditionExpression=key_condition_expression, + ExpressionAttributeValues=expression_attribute_values + ) + instance_monitor_record = instance_monitor_response['Items'][0] \ + if ('Items' in instance_monitor_response and len(instance_monitor_response['Items']) > 0) else None + + if instance_monitor_record: + logger.info(f"instance_monitor_record is {instance_monitor_record}") + return instance_monitor_record + + logger.info("no instance_monitor_record found") + return None + + + def save_sync_instance_monitor(last_sync_request_id: str, sync_status: str): + item = { + 'endpoint_id': ENDPOINT_ID, + 'endpoint_name': ENDPOINT_NAME, + 'gen_instance_id': GEN_INSTANCE_ID, + 'sync_status': sync_status, + 'last_sync_request_id': last_sync_request_id, + 'last_sync_time': datetime.datetime.now().isoformat(), + 'sync_list': [], + 'create_time': datetime.datetime.now().isoformat(), + 'last_heartbeat_time': datetime.datetime.now().isoformat() + } + save_resp = instance_monitor_table.put_item(Item=item) + logger.info(f"save instance item {save_resp}") + return save_resp + + + def update_sync_instance_monitor(instance_monitor_record): + # 更新记录 + update_expression = ("SET sync_status = :new_sync_status, last_sync_request_id = :sync_request_id, " + "sync_list = :sync_list, last_sync_time = :sync_time, last_heartbeat_time = :heartbeat_time") + expression_attribute_values = { + ":new_sync_status": instance_monitor_record['sync_status'], + ":sync_request_id": instance_monitor_record['last_sync_request_id'], + ":sync_list": instance_monitor_record['sync_list'], + ":sync_time": datetime.datetime.now().isoformat(), + ":heartbeat_time": datetime.datetime.now().isoformat(), + } + + response = instance_monitor_table.update_item( + Key={'endpoint_name': ENDPOINT_NAME, + 'gen_instance_id': GEN_INSTANCE_ID}, + UpdateExpression=update_expression, + ExpressionAttributeValues=expression_attribute_values + ) + logger.info(f"update_sync_instance_monitor :{response}") + return response + + + def sync_instance_monitor_status(need_save: bool): + try: + logger.info(f"sync_instance_monitor_status {datetime.datetime.now()}") + if need_save: + save_sync_instance_monitor('', 'init') + else: + update_expression = ("SET last_heartbeat_time = :heartbeat_time") + expression_attribute_values = { + ":heartbeat_time": datetime.datetime.now().isoformat(), + } + instance_monitor_table.update_item( + Key={'endpoint_name': ENDPOINT_NAME, + 'gen_instance_id': GEN_INSTANCE_ID}, + UpdateExpression=update_expression, + ExpressionAttributeValues=expression_attribute_values + ) + except Exception as e: + logger.info(f"sync_instance_monitor_status error :{e}") + + + @server.PromptServer.instance.routes.post("/reboot") + async def restart(self): + logger.debug(f"start to reboot!!!!!!!! {self}") + global executing + if executing is True: + logger.info(f"other inference doing cannot reboot!!!!!!!!") + return ok({"message": "other inference doing cannot reboot"}) + need_reboot = os.environ.get('NEED_REBOOT') + if need_reboot and need_reboot.lower() != 'true': + logger.info("no need to reboot by os") + return ok({"message": "no need to reboot by os"}) + global reboot + if reboot is False: + logger.info("no need to reboot by global constant") + return ok({"message": "no need to reboot by constant"}) + + logger.debug("rebooting !!!!!!!!") + try: + sys.stdout.close_log() + except Exception as e: + logger.info(f"error reboot!!!!!!!! {e}") + pass + return os.execv(sys.executable, [sys.executable] + sys.argv) + + + # must be sync invoke and use the env to check + @server.PromptServer.instance.routes.post("/sync_instance") + async def sync_instance(request): + if not BUCKET: + logger.error("No bucket provided ,wait and try again") + resp = {"status": "success", "message": "syncing"} + return ok(resp) + + if 'ALREADY_SYNC' in os.environ and os.environ.get('ALREADY_SYNC').lower() == 'false': + resp = {"status": "success", "message": "syncing"} + logger.error("other process doing ,wait and try again") + return ok(resp) + + os.environ['ALREADY_SYNC'] = 'false' + logger.info(f"sync_instance start !! {datetime.datetime.now().isoformat()} {request}") + try: + last_sync_record = get_last_ddb_sync_record() + if not last_sync_record: + logger.info("no last sync record found do not need sync") + sync_instance_monitor_status(True) + resp = {"status": "success", "message": "no sync"} + os.environ['ALREADY_SYNC'] = 'true' + return ok(resp) + + if ('request_id' in last_sync_record and last_sync_record['request_id'] + and os.environ.get('LAST_SYNC_REQUEST_ID') + and os.environ.get('LAST_SYNC_REQUEST_ID') == last_sync_record['request_id'] + and os.environ.get('LAST_SYNC_REQUEST_TIME') + and os.environ.get('LAST_SYNC_REQUEST_TIME') == str(last_sync_record['request_time'])): + logger.info("last sync record already sync by os check") + sync_instance_monitor_status(False) + resp = {"status": "success", "message": "no sync env"} + os.environ['ALREADY_SYNC'] = 'true' + return ok(resp) + + instance_monitor_record = get_latest_ddb_instance_monitor_record() + if not instance_monitor_record: + sync_already = await prepare_comfy_env(last_sync_record) + if sync_already: + logger.info("should init prepare instance_monitor_record") + sync_status = 'success' if sync_already else 'failed' + save_sync_instance_monitor(last_sync_record['request_id'], sync_status) + else: + sync_instance_monitor_status(False) + else: + if ('last_sync_request_id' in instance_monitor_record + and instance_monitor_record['last_sync_request_id'] + and instance_monitor_record['last_sync_request_id'] == last_sync_record['request_id'] + and instance_monitor_record['sync_status'] + and instance_monitor_record['sync_status'] == 'success' + and os.environ.get('LAST_SYNC_REQUEST_TIME') + and os.environ.get('LAST_SYNC_REQUEST_TIME') == str(last_sync_record['request_time'])): + logger.info("last sync record already sync") + sync_instance_monitor_status(False) + resp = {"status": "success", "message": "no sync ddb"} + os.environ['ALREADY_SYNC'] = 'true' + return ok(resp) + + sync_already = await prepare_comfy_env(last_sync_record) + instance_monitor_record['sync_status'] = 'success' if sync_already else 'failed' + instance_monitor_record['last_sync_request_id'] = last_sync_record['request_id'] + sync_list = instance_monitor_record['sync_list'] if ('sync_list' in instance_monitor_record + and instance_monitor_record['sync_list']) else [] + sync_list.append(last_sync_record['request_id']) + + instance_monitor_record['sync_list'] = sync_list + logger.info("should update prepare instance_monitor_record") + update_sync_instance_monitor(instance_monitor_record) + os.environ['ALREADY_SYNC'] = 'true' + resp = {"status": "success", "message": "sync"} + return ok(resp) + except Exception as e: + logger.info("exception occurred", e) + os.environ['ALREADY_SYNC'] = 'true' + resp = {"status": "fail", "message": "sync"} + return error(resp) + + + def validate_prompt_proxy(func): + def wrapper(*args, **kwargs): + # 在这里添加您的代理逻辑 + logger.info("validate_prompt_proxy start...") + # 调用原始函数 + result = func(*args, **kwargs) + # 在这里添加执行后的操作 + logger.info("validate_prompt_proxy end...") + return result + + return wrapper + + + execution.validate_prompt = validate_prompt_proxy(execution.validate_prompt) + + + def send_sync_proxy(func): + def wrapper(*args, **kwargs): + logger.debug(f"Sending sync request!!!!!!! {args}") + global need_sync + global prompt_id + logger.info(f"send_sync_proxy start... {need_sync},{prompt_id} {args}") + func(*args, **kwargs) + if need_sync and QUEUE_URL and REGION: + logger.debug(f"send_sync_proxy params... {QUEUE_URL},{REGION},{need_sync},{prompt_id}") + event = args[1] + data = args[2] + sid = args[3] if len(args) == 4 else None + message_body = {'prompt_id': prompt_id, 'event': event, 'data': data, 'sid': sid} + message_id = sen_sqs_msg(message_body, prompt_id) + logger.info(f'send_sync_proxy message_id :{message_id} message_body: {message_body}') + logger.debug(f"send_sync_proxy end...") + + return wrapper + + + server.PromptServer.send_sync = send_sync_proxy(server.PromptServer.send_sync) + + + def get_save_imge_path_proxy(func): + def wrapper(*args, **kwargs): + logger.info(f"get_save_imge_path_proxy args : {args} kwargs : {kwargs}") + full_output_folder, filename, counter, subfolder, filename_prefix = func(*args, **kwargs) + global prompt_id + filename_prefix_new = filename_prefix + "_" + str(prompt_id) + logger.info(f"get_save_imge_path_proxy filename_prefix new : {filename_prefix_new}") + return full_output_folder, filename, counter, subfolder, filename_prefix_new + + return wrapper + + + folder_paths.get_save_image_path = get_save_imge_path_proxy(folder_paths.get_save_image_path) diff --git a/build_scripts/inference/serve b/build_scripts/inference/serve index 84293c89..dfbbe4e0 100755 Binary files a/build_scripts/inference/serve and b/build_scripts/inference/serve differ diff --git a/build_scripts/inference/start.sh b/build_scripts/inference/start.sh index d930a03d..57d5ff82 100644 --- a/build_scripts/inference/start.sh +++ b/build_scripts/inference/start.sh @@ -280,7 +280,6 @@ comfy_launch(){ chmod -R +x venv/bin rm -rf /home/ubuntu/ComfyUI/custom_nodes/ComfyUI-AWS-Extension - rm /home/ubuntu/ComfyUI/custom_nodes/comfy_local_proxy.py source venv/bin/activate python /metrics.py & @@ -331,29 +330,12 @@ comfy_launch_from_public_s3(){ # -------------------- startup -------------------- -# if pipeline finished, it will be executed -#if [[ $IMAGE_URL == *"dev"* ]]; then -# download_conda -# if [ "$SERVICE_TYPE" == "sd" ]; then -# sd_install_build -# /serve trim_sd.sh -# sd_cache_endpoint -# sd_launch -# exit 1 -# else -# comfy_install_build -# /serve trim_comfy -# comfy_cache_endpoint -# comfy_launch -# exit 1 -# fi -#fi - ec2_start_process(){ set -euxo pipefail echo "---------------------------------------------------------------------------------" export LD_LIBRARY_PATH=$LD_PRELOAD download_conda + init_port=8187 for i in $(seq 1 "$PROCESS_NUMBER"); do init_port=$((init_port + 1)) @@ -381,7 +363,38 @@ ec2_start_process(){ done } -if [ -n "$COMFY_EC2" ]; then +if [ -n "$WORKFLOW_NAME" ]; then + start_at=$(date +%s) + s5cmd --log=error sync "s3://$S3_BUCKET_NAME/comfy/workflows/$WORKFLOW_NAME/*" "/home/ubuntu/ComfyUI/" + end_at=$(date +%s) + export DOWNLOAD_FILE_SECONDS=$((end_at-start_at)) + echo "download file: $DOWNLOAD_FILE_SECONDS seconds" + + cd "/home/ubuntu/ComfyUI" || exit 1 + + rm -rf web/extensions/ComfyLiterals + + chmod -R 777 "/home/ubuntu/ComfyUI" + chmod -R +x venv + source venv/bin/activate + + # on EC2 + if [ -n "$ON_EC2" ]; then + ec2_start_process + exit 1 + fi + + if [ -n "$ON_SAGEMAKER" ]; then + python3 serve.py + exit 1 + fi + + # on SageMaker + python3 serve.py + exit 1 +fi + +if [ -n "$ON_EC2" ]; then set -euxo pipefail cd /home/ubuntu || exit 1 @@ -409,8 +422,6 @@ if [ -n "$COMFY_EC2" ]; then export DECOMPRESS_SECONDS=$((end_at-start_at)) echo "decompress file: $DECOMPRESS_SECONDS seconds" - ls -la - rm ./ComfyUI/custom_nodes/comfy_sagemaker_proxy.py cd /home/ubuntu/ComfyUI || exit 1 @@ -439,28 +450,6 @@ if [ -n "$COMFY_EC2" ]; then exit 1 fi -if [ -n "$APP_SOURCE" ]; then - if [ -n "$APP_CWD" ]; then - start_at=$(date +%s) - s5cmd --log=error sync "s3://$S3_BUCKET_NAME/${APP_SOURCE}*" "$APP_CWD" - end_at=$(date +%s) - export DOWNLOAD_FILE_SECONDS=$((end_at-start_at)) - echo "download file: $DOWNLOAD_FILE_SECONDS seconds" - - cd "$APP_CWD" || exit 1 - - rm -rf web/extensions/ComfyLiterals - - chmod -R 777 "$APP_CWD" - chmod -R +x venv - - source venv/bin/activate - - python3 serve.py - - exit 1 - fi -fi if [ -f "/initiated_lock" ]; then echo "already initiated, start service directly..." diff --git a/build_scripts/install_comfy.sh b/build_scripts/install_comfy.sh index 809fd75a..a3a886b9 100755 --- a/build_scripts/install_comfy.sh +++ b/build_scripts/install_comfy.sh @@ -22,8 +22,7 @@ if [ -n "$ESD_COMMIT_ID" ]; then fi cp stable-diffusion-aws-extension/build_scripts/comfy/serve.py ComfyUI/ -cp stable-diffusion-aws-extension/build_scripts/comfy/comfy_sagemaker_proxy.py ComfyUI/custom_nodes/ -cp stable-diffusion-aws-extension/build_scripts/comfy/comfy_local_proxy.py ComfyUI/custom_nodes/ +cp stable-diffusion-aws-extension/build_scripts/comfy/comfy_proxy.py ComfyUI/custom_nodes/ cp -R stable-diffusion-aws-extension/build_scripts/comfy/ComfyUI-AWS-Extension ComfyUI/custom_nodes/ComfyUI-AWS-Extension rm -rf stable-diffusion-aws-extension @@ -40,7 +39,7 @@ git clone https://github.com/AIGODLIKE/AIGODLIKE-ComfyUI-Translation.git custom_ git clone https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite.git custom_nodes/ComfyUI-VideoHelperSuite git clone https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved.git custom_nodes/ComfyUI-AnimateDiff-Evolved -if [ "$ON_DOCKER" == "true" ]; then +if [ "$ON_SAGEMAKER" == "true" ]; then python3 -m venv venv source venv/bin/activate pip install --upgrade pip diff --git a/build_scripts/install_sd.sh b/build_scripts/install_sd.sh index 2d5a42f9..c0b0aa7e 100755 --- a/build_scripts/install_sd.sh +++ b/build_scripts/install_sd.sh @@ -32,7 +32,7 @@ if [ -n "$ESD_COMMIT_ID" ]; then fi # remove unused files for docker layer reuse -if [ "$ON_DOCKER" == "true" ]; then +if [ "$ON_SAGEMAKER" == "true" ]; then rm -rf docs rm -rf infrastructure rm -rf middleware_api diff --git a/workshop/comfy_start.sh b/comfy_start.sh similarity index 69% rename from workshop/comfy_start.sh rename to comfy_start.sh index 76ffa7df..78811baa 100755 --- a/workshop/comfy_start.sh +++ b/comfy_start.sh @@ -2,25 +2,26 @@ set -euxo pipefail -source /etc/environment +if [ -f "/etc/environment" ]; then + source /etc/environment +fi -export CONTAINER_NAME='comfy_ec2' +export CONTAINER_NAME='esd_comfy' export ACCOUNT_ID=$(aws sts get-caller-identity --query "Account" --output text) export AWS_REGION=$(aws configure get region) -repository_name="comfy-ec2" -image="$ACCOUNT_ID.dkr.ecr.$AWS_REGION.amazonaws.com/$repository_name:latest" +image="$ACCOUNT_ID.dkr.ecr.$AWS_REGION.amazonaws.com/$CONTAINER_NAME:latest" docker stop "$CONTAINER_NAME" || true docker rm "$CONTAINER_NAME" || true # Check if the repository already exists -if aws ecr describe-repositories --region "$AWS_REGION" --repository-names "$repository_name" >/dev/null 2>&1; then - echo "ECR repository '$repository_name' already exists." +if aws ecr describe-repositories --region "$AWS_REGION" --repository-names "$CONTAINER_NAME" >/dev/null 2>&1; then + echo "ECR repository '$CONTAINER_NAME' already exists." else - echo "ECR repository '$repository_name' does not exist. Creating..." - aws ecr create-repository --repository-name --region "$AWS_REGION" "$repository_name" - echo "ECR repository '$repository_name' created successfully." + echo "ECR repository '$CONTAINER_NAME' does not exist. Creating..." + aws ecr create-repository --repository-name --region "$AWS_REGION" "$CONTAINER_NAME" | jq . + echo "ECR repository '$CONTAINER_NAME' created successfully." fi aws ecr get-login-password --region "$AWS_REGION" | docker login --username AWS --password-stdin "366590864501.dkr.ecr.$AWS_REGION.amazonaws.com" @@ -32,7 +33,7 @@ docker build -f Dockerfile.comfy \ image_hash=$(docker inspect "$image" | jq -r ".[0].Id") image_hash=${image_hash:7} -release_image="$ACCOUNT_ID.dkr.ecr.$AWS_REGION.amazonaws.com/$repository_name:$image_hash" +release_image="$ACCOUNT_ID.dkr.ecr.$AWS_REGION.amazonaws.com/$CONTAINER_NAME:$image_hash" docker tag "$image" "$release_image" aws ecr get-login-password --region "$AWS_REGION" | docker login --username AWS --password-stdin "$ACCOUNT_ID.dkr.ecr.$AWS_REGION.amazonaws.com" @@ -46,13 +47,16 @@ echo "Starting container..." local_volume="./ComfyUI" # local vol can be replace with your local directory +# -v ./build_scripts/comfy/comfy_local_proxy.py:/home/ubuntu/ComfyUI/custom_nodes/comfy_local_proxy.py \ + docker run -v ~/.aws:/root/.aws \ - -v $local_volume:/home/ubuntu/ComfyUI \ + -v "$local_volume":/home/ubuntu/ComfyUI \ + -v ./build_scripts/inference/start.sh:/start.sh \ --gpus all \ -e "IMAGE_HASH=$release_image" \ -e "ESD_VERSION=$ESD_VERSION" \ -e "SERVICE_TYPE=comfy" \ - -e "COMFY_EC2=true" \ + -e "ON_EC2=true" \ -e "S3_BUCKET_NAME=$COMFY_BUCKET_NAME" \ -e "AWS_REGION=$AWS_REGION" \ -e "AWS_DEFAULT_REGION=$AWS_REGION" \ @@ -61,6 +65,7 @@ docker run -v ~/.aws:/root/.aws \ -e "COMFY_ENDPOINT=$COMFY_ENDPOINT" \ -e "COMFY_BUCKET_NAME=$COMFY_BUCKET_NAME" \ -e "PROCESS_NUMBER=$PROCESS_NUMBER" \ + -e "WORKFLOW_NAME=$WORKFLOW_NAME" \ --name "$CONTAINER_NAME" \ -p 8188-8288:8188-8288 \ "$image" diff --git a/infrastructure/src/api/workflows/delete-workflows.ts b/infrastructure/src/api/workflows/delete-workflows.ts new file mode 100644 index 00000000..3b6b3598 --- /dev/null +++ b/infrastructure/src/api/workflows/delete-workflows.ts @@ -0,0 +1,170 @@ +import {PythonFunction} from '@aws-cdk/aws-lambda-python-alpha'; +import {Aws, aws_lambda, Duration} from 'aws-cdk-lib'; +import {JsonSchemaType, JsonSchemaVersion, LambdaIntegration, Model, Resource} from 'aws-cdk-lib/aws-apigateway'; +import {Table} from 'aws-cdk-lib/aws-dynamodb'; +import {Effect, PolicyStatement, Role, ServicePrincipal} from 'aws-cdk-lib/aws-iam'; +import {Architecture, LayerVersion, Runtime} from 'aws-cdk-lib/aws-lambda'; +import {Construct} from 'constructs'; +import {ApiModels} from '../../shared/models'; +import {SCHEMA_WORKFLOW_NAME} from '../../shared/schema'; +import {ApiValidators} from '../../shared/validator'; + +export interface DeleteWorkflowsApiProps { + router: Resource; + httpMethod: string; + workflowsTable: Table; + multiUserTable: Table; + commonLayer: LayerVersion; +} + +export class DeleteWorkflowsApi { + private readonly router: Resource; + private readonly httpMethod: string; + private readonly scope: Construct; + private readonly workflowsTable: Table; + private readonly multiUserTable: Table; + private readonly layer: LayerVersion; + private readonly baseId: string; + + constructor(scope: Construct, id: string, props: DeleteWorkflowsApiProps) { + this.scope = scope; + this.baseId = id; + this.router = props.router; + this.httpMethod = props.httpMethod; + this.workflowsTable = props.workflowsTable; + this.multiUserTable = props.multiUserTable; + this.layer = props.commonLayer; + + const lambdaFunction = this.apiLambda(); + + const lambdaIntegration = new LambdaIntegration( + lambdaFunction, + { + proxy: true, + }, + ); + + this.router.addMethod(this.httpMethod, lambdaIntegration, { + apiKeyRequired: true, + requestValidator: ApiValidators.bodyValidator, + requestModels: { + 'application/json': this.createRequestBodyModel(), + }, + operationName: 'DeleteWorkflows', + methodResponses: [ + ApiModels.methodResponses204(), + ApiModels.methodResponses400(), + ApiModels.methodResponses401(), + ApiModels.methodResponses403(), + ], + }); + } + + private iamRole(): Role { + + const newRole = new Role(this.scope, `${this.baseId}-role`, { + assumedBy: new ServicePrincipal('lambda.amazonaws.com'), + }); + + + newRole.addToPolicy(new PolicyStatement({ + actions: [ + 'dynamodb:Query', + 'dynamodb:GetItem', + 'dynamodb:PutItem', + 'dynamodb:DeleteItem', + 'dynamodb:UpdateItem', + 'dynamodb:Describe*', + 'dynamodb:List*', + ], + resources: [ + this.workflowsTable.tableArn, + `${this.workflowsTable.tableArn}/*`, + this.multiUserTable.tableArn, + ], + })); + + newRole.addToPolicy(new PolicyStatement({ + actions: [ + 's3:Get*', + 's3:List*', + 's3:PutObject', + 's3:GetObject', + 's3:DeleteObject', + ], + resources: [ + '*', + ], + })); + + newRole.addToPolicy(new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'cloudwatch:DeleteAlarms', + 'cloudwatch:DescribeAlarms', + 'cloudwatch:DeleteDashboards', + ], + resources: [ + '*', + ], + })); + + newRole.addToPolicy(new PolicyStatement({ + effect: Effect.ALLOW, + actions: [ + 'logs:CreateLogGroup', + 'logs:CreateLogStream', + 'logs:PutLogEvents', + 'logs:DeleteLogGroup', + ], + resources: [`arn:${Aws.PARTITION}:logs:${Aws.REGION}:${Aws.ACCOUNT_ID}:log-group:*:*`], + })); + + return newRole; + } + + private createRequestBodyModel(): Model { + return new Model(this.scope, `${this.baseId}-model`, { + restApi: this.router.api, + modelName: this.baseId, + description: `Request Model ${this.baseId}`, + schema: { + schema: JsonSchemaVersion.DRAFT7, + title: this.baseId, + type: JsonSchemaType.OBJECT, + properties: { + workflow_name_list: { + type: JsonSchemaType.ARRAY, + items: SCHEMA_WORKFLOW_NAME, + minItems: 1, + maxItems: 10, + }, + }, + required: [ + 'workflow_name_list', + ], + }, + contentType: 'application/json', + }); + } + + private apiLambda() { + return new PythonFunction(this.scope, `${this.baseId}-lambda`, { + entry: '../middleware_api/workflows', + architecture: Architecture.X86_64, + runtime: Runtime.PYTHON_3_10, + index: 'delete_workflows.py', + handler: 'handler', + timeout: Duration.seconds(900), + role: this.iamRole(), + memorySize: 2048, + tracing: aws_lambda.Tracing.ACTIVE, + layers: [this.layer], + environment:{ + WORKFLOWS_TABLE: this.workflowsTable.tableName, + } + }); + } + + +} diff --git a/infrastructure/src/shared/workflow.ts b/infrastructure/src/shared/workflow.ts index 3a920c5f..12947ff2 100644 --- a/infrastructure/src/shared/workflow.ts +++ b/infrastructure/src/shared/workflow.ts @@ -6,6 +6,7 @@ import {Construct} from 'constructs'; import {ResourceProvider} from './resource-provider'; import {CreateWorkflowApi} from "../api/workflows/create-workflow"; import {ListWorkflowsApi} from "../api/workflows/list-workflows"; +import {DeleteWorkflowsApi} from "../api/workflows/delete-workflows"; export interface WorkflowProps extends StackProps { routers: { [key: string]: Resource }; @@ -45,6 +46,17 @@ export class Workflow { }, ); + new DeleteWorkflowsApi( + scope, 'DeleteWorkflows', + { + workflowsTable: props.workflowsTable, + commonLayer: props.commonLayer, + multiUserTable: props.multiUserTable, + httpMethod: 'DELETE', + router: props.routers.workflows, + }, + ); + } } diff --git a/middleware_api/endpoints/create_endpoint.py b/middleware_api/endpoints/create_endpoint.py index a2e44f7c..c346057a 100644 --- a/middleware_api/endpoints/create_endpoint.py +++ b/middleware_api/endpoints/create_endpoint.py @@ -245,13 +245,13 @@ def _create_sagemaker_model(name, model_data_url, endpoint_name, endpoint_id, ev 'ESD_VERSION': esd_version, 'ESD_COMMIT_ID': esd_commit_id, 'SERVICE_TYPE': event.service_type, - 'ON_DOCKER': 'true', + 'ON_SAGEMAKER': 'true', 'AWS_REGION': aws_region, 'AWS_DEFAULT_REGION': aws_region, } if event.workflow: - environment['APP_SOURCE'] = event.workflow.s3_location + environment['WORKFLOW_NAME'] = event.workflow.name environment['APP_CWD'] = '/home/ubuntu/ComfyUI' primary_container = { diff --git a/middleware_api/service/oas.py b/middleware_api/service/oas.py index fee5856c..0b8145cd 100644 --- a/middleware_api/service/oas.py +++ b/middleware_api/service/oas.py @@ -485,6 +485,11 @@ operations = { tags=["Workflows"], description="List Workflows with Parameters", ), + "DeleteWorkflows": APISchema( + summary="Delete Workflows", + tags=["Workflows"], + description="Delete specify Workflows", + ), } diff --git a/middleware_api/workflows/delete_workflows.py b/middleware_api/workflows/delete_workflows.py new file mode 100644 index 00000000..00a4dbdd --- /dev/null +++ b/middleware_api/workflows/delete_workflows.py @@ -0,0 +1,47 @@ +import json +import logging +import os +from dataclasses import dataclass + +import boto3 +from aws_lambda_powertools import Tracer + +from common.ddb_service.client import DynamoDbUtilsService +from common.response import no_content +from libs.utils import response_error + +tracer = Tracer() +workflows_table = os.environ.get('WORKFLOWS_TABLE') + +logger = logging.getLogger(__name__) +logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR) + +ddb_service = DynamoDbUtilsService(logger=logger) +esd_version = os.environ.get("ESD_VERSION") +s3_resource = boto3.resource('s3') +bucket_name = os.environ.get('S3_BUCKET_NAME') +s3_bucket = s3_resource.Bucket(bucket_name) + + +@dataclass +class DeleteWorkflowsEvent: + workflow_name_list: [str] + + +@tracer.capture_lambda_handler +def handler(raw_event, ctx): + try: + logger.info(json.dumps(raw_event)) + + event = DeleteWorkflowsEvent(**json.loads(raw_event['body'])) + + for name in event.workflow_name_list: + s3_bucket.objects.filter(Prefix=f"comfy/workflows/{name}/").delete() + ddb_service.delete_item( + table=workflows_table, + keys={'name': name}, + ) + + return no_content(message="Workflows Deleted") + except Exception as e: + return response_error(e) diff --git a/scripts/api.py b/scripts/api.py index bca65644..a18389b0 100644 --- a/scripts/api.py +++ b/scripts/api.py @@ -526,8 +526,8 @@ try: import modules.script_callbacks as script_callbacks script_callbacks.on_app_started(sagemaker_api) - on_docker = os.environ.get('ON_DOCKER', "false") - if on_docker == "true": + on_sagemaker = os.environ.get('ON_SAGEMAKER', "false") + if on_sagemaker == "true": from modules import shared shared.opts.data.update(control_net_max_models_num=10) script_callbacks.on_app_started(move_model_to_tmp) diff --git a/scripts/main.py b/scripts/main.py index a7e6ad1b..a47addc6 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -855,8 +855,8 @@ class SageMakerUI(scripts.Script): return sagemaker_inputs_components def before_process(self, p, *args): - on_docker = os.environ.get('ON_DOCKER', "false") - if on_docker == "true": + on_sagemaker = os.environ.get('ON_SAGEMAKER', "false") + if on_sagemaker == "true": return # check if endpoint is InService @@ -1055,7 +1055,7 @@ def fetch_user_data(): time.sleep(30) -if os.environ.get('ON_DOCKER', "false") != "true": +if os.environ.get('ON_SAGEMAKER', "false") != "true": from aws_extension.auth_service.simple_cloud_auth import cloud_auth_manager if cloud_auth_manager.enableAuth: cmd_opts.gradio_auth = cloud_auth_manager.create_config() diff --git a/test/test_02_api_base/test_12_workflows.py b/test/test_02_api_base/test_12_workflows.py index b6e4fbb8..7ebb26bc 100644 --- a/test/test_02_api_base/test_12_workflows.py +++ b/test/test_02_api_base/test_12_workflows.py @@ -21,26 +21,45 @@ class TestComfyWorkflowApiBase: resp = self.api.create_workflow() assert resp.status_code == 403, resp.dumps() - def test_2_list_workflows_without_key(self): - resp = self.api.list_workflows() - assert resp.status_code == 403, resp.dumps() - - def test_3_create_workflow_with_bad_key(self): + def test_2_create_workflow_with_bad_key(self): headers = {'x-api-key': "bad_key"} resp = self.api.create_workflow(headers) assert resp.status_code == 403, resp.dumps() - def test_4_list_workflows_with_bad_key(self): - headers = {'x-api-key': "bad_key"} - resp = self.api.list_workflows(headers) - assert resp.status_code == 403, resp.dumps() - - def test_5_create_workflow_with_bad_request(self): + def test_3_create_workflow_with_bad_request(self): headers = {'x-api-key': config.api_key} resp = self.api.create_workflow(headers) assert resp.status_code == 400, resp.dumps() - def test_6_list_executes_with_ok(self): + def test_4_list_workflows_without_key(self): + resp = self.api.list_workflows() + assert resp.status_code == 403, resp.dumps() + + def test_5_list_workflows_with_bad_key(self): + headers = {'x-api-key': "bad_key"} + resp = self.api.list_workflows(headers) + assert resp.status_code == 403, resp.dumps() + + def test_6_list_workflows_with_ok(self): headers = {'x-api-key': config.api_key} resp = self.api.list_workflows(headers) assert resp.status_code == 200, resp.dumps() + + def test_7_delete_workflows_without_key(self): + resp = self.api.delete_workflows() + assert resp.status_code == 403, resp.dumps() + + def test_8_delete_workflows_with_bad_key(self): + headers = {'x-api-key': "bad_key"} + resp = self.api.delete_workflows(headers) + assert resp.status_code == 403, resp.dumps() + + def test_9_delete_workflows_with_ok(self): + headers = {'x-api-key': config.api_key} + data = { + "workflow_name_list": [ + "workflow_name" + ], + } + resp = self.api.delete_workflows(headers=headers, data=data) + assert resp.status_code == 204, resp.dumps() diff --git a/test/utils/api.py b/test/utils/api.py index f98baa6e..32b6bd36 100644 --- a/test/utils/api.py +++ b/test/utils/api.py @@ -92,6 +92,15 @@ class Api: data=data ) + def delete_workflows(self, headers=None, data=None): + return self.req( + "DELETE", + "workflows", + headers=headers, + operation_id='DeleteWorkflows', + data=data + ) + def delete_users(self, headers=None, data=None): return self.req( "DELETE", diff --git a/workshop/Dockerfile.comfy b/workshop/Dockerfile.comfy deleted file mode 100755 index f0bf8d0f..00000000 --- a/workshop/Dockerfile.comfy +++ /dev/null @@ -1,7 +0,0 @@ -ARG AWS_REGION -FROM 366590864501.dkr.ecr.$AWS_REGION.amazonaws.com/esd-inference:dev - -# TODO BYOC -RUN apt-get update -y && \ - apt-get install ffmpeg -y && \ - rm -rf /var/lib/apt/lists/* diff --git a/workshop/comfy.yaml b/workshop/comfy.yaml index 25afd67a..21ddc795 100644 --- a/workshop/comfy.yaml +++ b/workshop/comfy.yaml @@ -56,6 +56,9 @@ Parameters: - latest - dev Default: latest + WorkflowName: + Description: Bind Workflow Name + Type: String Mappings: RegionToAmiId: @@ -218,6 +221,7 @@ Resources: echo "export AWS_REGION=${AWS::Region}" >> /etc/environment echo "export PROCESS_NUMBER=${ProcessNumber}" >> /etc/environment echo "export ESD_VERSION=${EsdVersion}" >> /etc/environment + echo "export WORKFLOW_NAME=${WorkflowName}" >> /etc/environment source /etc/environment @@ -243,7 +247,7 @@ Resources: StartLimitIntervalSec=0 [Service] - WorkingDirectory=/root/stable-diffusion-aws-extension/workshop/ + WorkingDirectory=/root/stable-diffusion-aws-extension/ ExecStart=bash comfy_start.sh Type=simple Restart=always