stable-diffusion-aws-extension/build_scripts/comfy/comfy_proxy.py

1674 lines
80 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import concurrent.futures
import signal
import threading
import boto3
import requests
from aiohttp import web
import folder_paths
import server
from execution import PromptExecutor
import execution
import comfy
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
import subprocess
from dotenv import load_dotenv
import fcntl
import hashlib
import base64
import datetime
import json
import logging
import os
import sys
import tarfile
import time
import uuid
import gc
from dataclasses import dataclass
from typing import Optional
from boto3.dynamodb.conditions import Key
# Environment variable (also the /change_env JSON key) toggling the AWS proxy;
# when its value is the string 'True', prompts are executed locally.
DISABLE_AWS_PROXY = 'DISABLE_AWS_PROXY'
# 'data' payloads of sync messages already relayed to the UI for the current
# prompt; used to de-duplicate in handle_sync_messages.
sync_msg_list = []
# UI client_id -> release/workflow name chosen through /map_release.
client_release_map = {}

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Deployment-target flags; each is True only for the exact string 'true'.
is_on_sagemaker = os.environ.get('ON_SAGEMAKER') == 'true'
is_on_ec2 = os.environ.get('ON_EC2') == 'true'
if is_on_ec2:
    # Environment file to load: ENV_FILE_PATH when set and non-empty,
    # otherwise /etc/environment.
    env_path = os.environ.get('ENV_FILE_PATH') or '/etc/environment'
    # Load the resolved path so the ENV_FILE_PATH override actually takes effect.
    load_dotenv(env_path)
    logger.info(f"env_path{env_path}")
# Configuration variables whose values are echoed to the log at import time.
env_keys = ['ENV_FILE_PATH', 'COMFY_INPUT_PATH', 'COMFY_MODEL_PATH', 'COMFY_NODE_PATH', 'COMFY_API_URL',
            'COMFY_API_TOKEN', 'COMFY_ENDPOINT', 'COMFY_NEED_SYNC', 'COMFY_NEED_PREPARE', 'COMFY_BUCKET_NAME',
            'MAX_WAIT_TIME', 'MSG_MAX_WAIT_TIME', 'THREAD_MAX_WAIT_TIME', DISABLE_AWS_PROXY, 'DISABLE_AUTO_SYNC']
# Credentials must never reach the log in clear text; their values are masked.
_SECRET_ENV_KEYS = {'COMFY_API_TOKEN'}
for item in os.environ.keys():
    if item in env_keys:
        _env_value = os.environ.get(item)
        if item in _SECRET_ENV_KEYS and _env_value:
            _env_value = '******'
        logger.info(f'evn key {item} {_env_value}')
# Local folder roots mirrored to S3; a non-empty environment variable
# overrides each default.
DIR1 = os.environ.get('COMFY_MODEL_PATH') or "models"
DIR2 = os.environ.get('COMFY_NODE_PATH') or "custom_nodes"
DIR3 = os.environ.get('COMFY_INPUT_PATH') or "input"
def _env_int(name, default):
    """Return environment variable *name* parsed as an int, or *default*.

    os.environ values are strings, while execute_proxy compares the wait times
    with ``> 0`` and decrements them by integers. An unparsed value would
    therefore raise TypeError whenever the variable is set. Unset or blank
    values fall back to *default*; non-integer values also fall back, with a
    warning.
    """
    raw_value = os.environ.get(name)
    if raw_value is None or raw_value.strip() == '':
        return default
    try:
        return int(raw_value)
    except ValueError:
        logger.warning(f"{name}={raw_value!r} is not an integer, using default {default}")
        return default


api_url = os.environ.get('COMFY_API_URL')
api_token = os.environ.get('COMFY_API_TOKEN')
comfy_endpoint = os.environ.get('COMFY_ENDPOINT')
# NOTE(review): when set, these two arrive as strings ('false' is truthy) and
# are forwarded as-is in the /executes payload. They are left unparsed because
# the type the backend expects is not visible here.
comfy_need_sync = os.environ.get('COMFY_NEED_SYNC', True)
comfy_need_prepare = os.environ.get('COMFY_NEED_PREPARE', False)
bucket_name = os.environ.get('COMFY_BUCKET_NAME')
# Polling budgets (loop counters) used while waiting for cloud results/messages.
thread_max_wait_time = _env_int('THREAD_MAX_WAIT_TIME', 60)
max_wait_time = _env_int('MAX_WAIT_TIME', 86400)
msg_max_wait_time = _env_int('MSG_MAX_WAIT_TIME', 86400)
is_master_process = os.getenv('MASTER_PROCESS') == 'true'
program_name = os.getenv('PROGRAM_NAME')
# Path suffixes that never trigger an S3 sync (e.g. autosave/swap files).
no_need_sync_files = ['.autosave', '.cache', '.autosave1', '~', '.swp']
need_resend_msg_result = []
# Version prefix used for S3 paths and prepare_id in 'additional' mode;
# any other PREPARE_MODE uses a millisecond timestamp instead.
PREPARE_ID = 'default'
PREPARE_MODE = 'additional'
if not api_url:
    raise ValueError("API_URL environment variables must be set.")
if not api_token:
    raise ValueError("API_TOKEN environment variables must be set.")
if not comfy_endpoint:
    raise ValueError("COMFY_ENDPOINT environment variables must be set.")
headers = {"x-api-key": api_token, "Content-Type": "application/json"}
def save_images_locally(response_json, local_folder):
    """Download every image/video URL in *response_json* into a per-prompt folder.

    ``response_json['data']`` should provide 'prompt_id' and an
    'image_video_data' mapping of file name -> URL. Files are written to
    ``<local_folder>/<prompt_id>/<name>``. Every failure is logged, never raised.
    """
    try:
        payload = response_json.get("data", {})
        prompt_id = payload.get("prompt_id")
        media = payload.get("image_video_data", {})
        if not (prompt_id and media):
            logger.info("Missing prompt_id or image_video_data in the response.")
            return
        target_folder = os.path.join(local_folder, prompt_id)
        os.makedirs(target_folder, exist_ok=True)
        for name, url in media.items():
            download = requests.get(url)
            if download.status_code != 200:
                logger.info(
                    f"Failed to download image '{name}' from {url}. Status code: {download.status_code}")
                continue
            target_path = os.path.join(target_folder, name)
            with open(target_path, "wb") as out:
                out.write(download.content)
            logger.info(f"Image '{name}' saved to {target_path}")
    except Exception as e:
        logger.info(f"Error saving images locally: {e}")
def calculate_file_hash(file_path):
    """Return the SHA-256 hex digest of the file at *file_path*.

    The file is streamed in 64 KiB chunks, so it is never read into memory
    all at once.
    """
    chunk_size = 65536  # 64 KiB read buffer
    digest = hashlib.sha256()
    with open(file_path, 'rb') as stream:
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
def save_files(prefix, execute, key, target_dir, need_prefix):
    """Download every URL listed under ``execute['data'][key]`` into *target_dir*.

    Args:
        prefix: Used for the extra ``<prefix>_<name>`` copy when *need_prefix*.
        execute: Parsed /executes response; ``execute['data'][key]`` is
            expected to be a list of download URLs.
        key: Which list to read, e.g. 'temp_files' or 'output_files'.
        target_dir: Destination directory, created on the first saved file.
        need_prefix: If true, each file is written both as ``<prefix>_<name>``
            and as ``<name>``, overwriting any existing file of that name.

    The 'output_images_will_be_put_here' placeholder is skipped without being
    downloaded. Responses other than HTTP 200 are logged rather than written
    to disk as if they were the file.
    """
    data = execute['data']
    if key not in data:
        return
    for url in data[key]:
        file_name = get_file_name(url)
        if file_name.endswith("output_images_will_be_put_here"):
            continue
        response = requests.get(url)
        if response.status_code != 200:
            logger.error(f"Failed to download {file_name} from {url}. Status code: {response.status_code}")
            continue
        os.makedirs(target_dir, exist_ok=True)
        logger.info(f"Saving file {file_name} to {target_dir}")
        # os.path.join (rather than "./{dir}/{name}") keeps an absolute
        # target_dir absolute, matching the directory makedirs created.
        names = [f"{prefix}_{file_name}", file_name] if need_prefix else [file_name]
        for name in names:
            with open(os.path.join(target_dir, name), 'wb') as f:
                f.write(response.content)
def get_file_name(url: str):
    """Return the last path segment of *url*, with any query string removed."""
    last_segment = url.rsplit('/', 1)[-1]
    name, _sep, _query = last_segment.partition('?')
    return name
def send_service_msg(server_use, msg):
    """Forward one message dict to the UI through ``server_use.send_sync``.

    *msg* supplies 'event', 'data' and an optional 'sid'; a missing key is
    passed as None.
    """
    server_use.send_sync(msg.get('event'), msg.get('data'), msg.get('sid'))
def handle_sync_messages(server_use, msg_array):
    """Relay not-yet-seen sync messages to the UI.

    *msg_array* is iterated as a sequence of batches. Each batch is a sequence
    of message dicts with 'event', 'data' and an optional 'sid'. A message
    whose 'data' is already in ``sync_msg_list`` is skipped; otherwise its data
    is recorded there and the message is sent via ``server_use.send_sync``.
    Every forwarded 'executed' message appends its whole enclosing batch to
    ``need_resend_msg_result``.

    Returns True if a 'finish' event was forwarded during this call.
    """
    global sync_msg_list
    global need_resend_msg_result
    finished = False
    for batch in msg_array:
        for entry in batch:
            event_name = entry.get('event')
            payload = entry.get('data')
            sid = entry.get('sid') if 'sid' in entry else None
            if payload in sync_msg_list:
                continue
            sync_msg_list.append(payload)
            if event_name == 'finish':
                finished = True
            elif event_name == 'executed':
                need_resend_msg_result.append(batch)
            server_use.send_sync(event_name, payload, sid)
    return finished
def execute_proxy(func):
    """Wrap PromptExecutor.execute so prompts are executed on the AWS endpoint.

    If the DISABLE_AWS_PROXY environment variable equals the string 'True',
    the wrapped *func* runs locally instead. Otherwise the prompt is POSTed to
    ``{api_url}/executes``. Sync messages from ``{api_url}/sync/{prompt_id}``
    are relayed to the UI, and the resulting temp/output files are downloaded
    into the local 'temp' and 'output' folders via save_files.
    """
    def wrapper(*args, **kwargs):
        def send_error_msg(executor, prompt_id, msg):
            # Report *msg* to the UI as an "execution_error" event.
            mes = {
                "prompt_id": prompt_id,
                "node_id": "",
                "node_type": "on cloud",
                "executed": [],
                "exception_message": msg,
                "exception_type": "",
                "traceback": [],
                "current_inputs": "",
                "current_outputs": "",
            }
            executor.add_message("execution_error", mes, broadcast=True)
        if 'True' == os.environ.get(DISABLE_AWS_PROXY):
            logger.info("disabled aws proxy, use local")
            return func(*args, **kwargs)
        logger.info(f"enable aws proxy, use aws {comfy_endpoint}")
        # Positional arguments of PromptExecutor.execute, presumably
        # (self, prompt, prompt_id, extra_data, ...) — TODO confirm against the
        # ComfyUI version in use.
        executor = args[0]
        server_use = executor.server
        prompt = args[1]
        prompt_id = args[2]
        extra_data = args[3]
        client_id = extra_data['client_id'] if 'client_id' in extra_data else None
        if not client_id:
            send_error_msg(executor, prompt_id, f"Something went wrong when execute,please check your client_id and try again")
            return web.Response()
        global client_release_map
        # Release selected by this client via /map_release; only the master
        # process may execute without one.
        workflow_name = client_release_map.get(client_id)
        if not workflow_name and not is_master_process:
            send_error_msg(executor, prompt_id, f"Please choose a release env before you execute prompt")
            return web.Response()
        payload = {
            "number": str(server.PromptServer.instance.number),
            "prompt": prompt,
            "prompt_id": prompt_id,
            "extra_data": extra_data,
            "endpoint_name": comfy_endpoint,
            "need_prepare": comfy_need_prepare,
            "need_sync": comfy_need_sync,
            "multi_async": False,
            "workflow_name": workflow_name,
        }
        def send_post_request(url, params):
            logger.debug(f"sending post request {url} , params {params}")
            get_response = requests.post(url, json=params, headers=headers)
            return get_response
        def send_get_request(url):
            get_response = requests.get(url, headers=headers)
            return get_response
        def check_if_sync_is_already(url):
            # True only if the prepare API body has statusCode 200 and a truthy
            # data.prepareSuccess.
            get_response = send_get_request(url)
            prepare_response = get_response.json()
            if (prepare_response['statusCode'] == 200 and 'data' in prepare_response and prepare_response['data']
                    and prepare_response['data']['prepareSuccess']):
                logger.info(f"sync available")
                return True
            else:
                logger.info(f"no sync available for {url} response {prepare_response}")
                return False
        logger.debug(f"payload is: {payload}")
        # Refuse to run until the local environment has been prepared on the cloud.
        is_synced = check_if_sync_is_already(f"{api_url}/prepare/{comfy_endpoint}")
        if not is_synced:
            logger.debug(f"is_synced is {is_synced} stop cloud prompt")
            send_error_msg(executor, prompt_id, "Your local environment has not compleated to synchronized on cloud already. Please wait for a moment or click the 'Synchronize' button .")
            return
        with concurrent.futures.ThreadPoolExecutor() as executorThread:
            execute_future = executorThread.submit(send_post_request, f"{api_url}/executes", payload)
            # Becomes True once the first polling loop has downloaded the results.
            save_already = False
            if comfy_need_sync:
                # Fetch sync messages concurrently with the execute request.
                msg_future = executorThread.submit(send_get_request,
                                                   f"{api_url}/sync/{prompt_id}")
                done, _ = concurrent.futures.wait([execute_future, msg_future],
                                                  return_when=concurrent.futures.ALL_COMPLETED)
                # True once a 'finish' message was relayed or message polling should stop.
                already_synced = False
                global sync_msg_list
                # Fresh de-duplication list for this prompt.
                sync_msg_list = []
                for future in done:
                    if future == msg_future:
                        msg_response = future.result()
                        logger.info(f"get syc msg: {msg_response.json()}")
                        if msg_response.status_code == 200:
                            if 'data' not in msg_response.json() or not msg_response.json().get("data"):
                                logger.error("there is no response from sync msg by thread ")
                                time.sleep(1)
                            else:
                                logger.debug(msg_response.json())
                                already_synced = handle_sync_messages(server_use, msg_response.json().get("data"))
                    elif future == execute_future:
                        execute_resp = future.result()
                        logger.info(f"get execute status: {execute_resp.status_code}")
                        if execute_resp.status_code == 200 or execute_resp.status_code == 201 or execute_resp.status_code == 202:
                            # NOTE(review): thread_max_wait_time comes from os.environ.get and is a
                            # str whenever THREAD_MAX_WAIT_TIME is set; `i > 0` would then raise TypeError.
                            i = thread_max_wait_time
                            while i > 0:
                                images_response = send_get_request(f"{api_url}/executes/{prompt_id}")
                                response = images_response.json()
                                logger.info(f"get execute images: {images_response.status_code}")
                                if images_response.status_code == 404:
                                    logger.info("no images found already ,waiting sagemaker thread result .....")
                                    time.sleep(3)
                                    i = i - 2
                                elif response['data']['status'] == 'failed':
                                    logger.error(f"there is no response on sagemaker from execute thread result !!!!!!!! ")
                                    # send_error_msg(executor, prompt_id,
                                    #                f"There may be some errors when valid and execute the prompt on the cloud. Please check the SageMaker logs. error info: {response['data']['message']}")
                                    # no need to send msg anymore
                                    already_synced = True
                                    break
                                elif response['data']['status'] != 'Completed' and response['data']['status'] != 'success':
                                    logger.info(f"no images found already ,waiting sagemaker thread result, current status is {response['data']['status']}")
                                    time.sleep(2)
                                    i = i - 1
                                elif 'data' not in response or not response['data'] or 'status' not in response['data'] or not response['data']['status']:
                                    # NOTE(review): the branches above already index
                                    # response['data']['status'], so a missing 'data' or
                                    # 'status' raises before this guard is reached.
                                    logger.error(f"there is no response from execute thread result !!!!!!!! {response}")
                                    # no need to send msg anymore
                                    already_synced = True
                                    # send_error_msg(executor, prompt_id,"There may be some errors when executing the prompt on cloud. No images or videos generated.")
                                    break
                                else:
                                    # Finished on the cloud: download temp/output files if any were produced.
                                    if ('temp_files' in images_response.json()['data'] and len(
                                            images_response.json()['data']['temp_files']) > 0) or ((
                                            'output_files' in images_response.json()['data'] and len(
                                            images_response.json()['data']['output_files']) > 0)):
                                        save_files(prompt_id, images_response.json(), 'temp_files', 'temp', False)
                                        save_files(prompt_id, images_response.json(), 'output_files', 'output', True)
                                    else:
                                        send_error_msg(executor, prompt_id,
                                                       "There may be some errors when executing the prompt on the cloud. Please check the SageMaker logs.")
                                        # no need to send msg anymore
                                        already_synced = True
                                    logger.debug(images_response.json())
                                    save_already = True
                                    break
                        else:
                            logger.error(f"get execute error: {execute_resp}")
                            # send_error_msg(executor, prompt_id, "Please valid your prompt and try again.")
                            # send_error_msg(executor, prompt_id,
                            #                f"There may be some errors when valid and execute the prompt on the cloud. Please check the SageMaker logs. error info: {response['data']['message']}")
                            # no need to send msg anymore
                            already_synced = True
                            break
                        logger.debug(execute_resp.json())
                # Keep relaying sync messages until 'finish' arrives or the budget runs out.
                m = msg_max_wait_time
                while not already_synced:
                    msg_response = send_get_request(f"{api_url}/sync/{prompt_id}")
                    # logger.info(msg_response.json())
                    if msg_response.status_code == 200:
                        if m <= 0:
                            logger.error("there is no response from sync msg by timeout")
                            already_synced = True
                        elif 'data' not in msg_response.json() or not msg_response.json().get("data"):
                            logger.error("there is no response from sync msg")
                            time.sleep(1)
                            m = m - 1
                        else:
                            logger.debug(msg_response.json())
                            already_synced = handle_sync_messages(server_use, msg_response.json().get("data"))
                            logger.info(f"already_synced is :{already_synced}")
                            time.sleep(1)
                            m = m - 1
            logger.info(f"check if images are already synced {save_already}")
            if not save_already:
                # Results were not fetched above: poll /executes/{prompt_id} until
                # the job completes, fails, or max_wait_time is used up.
                logger.info("check if images are not already synced, please wait")
                execute_resp = execute_future.result()
                logger.debug(f"execute result :{execute_resp.json()}")
                if execute_resp.status_code == 200 or execute_resp.status_code == 201 or execute_resp.status_code == 202:
                    # NOTE(review): like thread_max_wait_time, max_wait_time is a str
                    # when MAX_WAIT_TIME is set in the environment.
                    i = max_wait_time
                    while i > 0:
                        images_response = send_get_request(f"{api_url}/executes/{prompt_id}")
                        response = images_response.json()
                        logger.debug(response)
                        if images_response.status_code == 404:
                            logger.info(f"{i} no images found already ,waiting sagemaker result .....")
                            i = i - 2
                            time.sleep(3)
                        elif response['data']['status'] == 'failed':
                            logger.error(
                                f"there is no response on sagemaker from execute result !!!!!!!! ")
                            if 'message' in response['data'] and response['data']['message']:
                                send_error_msg(executor, prompt_id,
                                               f"There may be some errors when valid or execute the prompt on the cloud. Please check the SageMaker logs. errors: {response['data']['message']}")
                                break
                            else:
                                logger.error(f"valid error on sagemaker :{response['data']}")
                                send_error_msg(executor, prompt_id,
                                               f"There may be some errors when valid or execute the prompt on the cloud. errors")
                                break
                        elif response['data']['status'] != 'Completed' and response['data']['status'] != 'success':
                            logger.info(f"{i} images not already ,waiting sagemaker result .....{response['data']['status'] }")
                            i = i - 1
                            time.sleep(3)
                        elif 'data' not in response or not response['data'] or 'status' not in response['data'] or not response['data']['status']:
                            # NOTE(review): unreachable for a missing 'data'/'status' (already
                            # indexed above); response['message'] may also be absent — verify.
                            logger.info(f"{i} there is no response from sync executes {response}")
                            send_error_msg(executor, prompt_id, f"There may be some errors when executing the prompt on the cloud. No images or videos generated. {response['message']}")
                            break
                        elif response['data']['status'] == 'Completed' or response['data']['status'] == 'success':
                            if ('temp_files' in images_response.json()['data'] and len(images_response.json()['data']['temp_files']) > 0) or (('output_files' in images_response.json()['data'] and len(images_response.json()['data']['output_files']) > 0)):
                                save_files(prompt_id, images_response.json(), 'temp_files', 'temp', False)
                                save_files(prompt_id, images_response.json(), 'output_files', 'output', True)
                                break
                            else:
                                send_error_msg(executor, prompt_id,
                                               "There may be some errors when executing the prompt on the cloud. Please check the SageMaker logs.")
                                break
                        else:
                            # logger.info(
                            #     f"{i} images not already other,waiting sagemaker result .....{response}")
                            # i = i - 1
                            # time.sleep(3)
                            send_error_msg(executor, prompt_id,
                                           "You have some errors when execute prompt on cloud . Please check your sagemaker logs.")
                            break
                else:
                    logger.error(f"get execute error: {execute_resp}")
                    send_error_msg(executor, prompt_id, "Please valid your prompt and try again.")
            logger.info("execute finished")
        # NOTE(review): redundant — leaving the `with` block already shut the pool down.
        executorThread.shutdown()
    return wrapper
# Route every PromptExecutor.execute call through the AWS proxy wrapper.
PromptExecutor.execute = execute_proxy(PromptExecutor.execute)
def send_sync_proxy(func):
    """Wrap PromptServer.send_sync so every outgoing sync message is logged.

    The wrapper forwards all arguments unchanged and returns *func*'s result.
    functools.wraps preserves the wrapped method's name, docstring and
    ``__wrapped__`` for introspection and debugging.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Lazy %-formatting: the (possibly large) args tuple is rendered only
        # when INFO logging is actually enabled.
        logger.info("Sending sync request----- %s", args)
        return func(*args, **kwargs)
    return wrapper
# Monkey-patch ComfyUI's PromptServer.send_sync so every message sent to the
# UI passes through the send_sync_proxy logging wrapper.
server.PromptServer.send_sync = send_sync_proxy(server.PromptServer.send_sync)
def compress_and_upload(folder_path, prepare_version):
    """Archive each immediate sub-directory of *folder_path* and upload it to S3.

    Each ``<subdir>.tar.gz`` is built in the current working directory and
    copied with s5cmd to
    ``s3://<bucket_name>/comfy/<comfy_endpoint>/<prepare_version>/custom_nodes/``.
    It is then deleted, even if archiving or uploading failed. s5cmd is given
    an argument list, not a shell string, so directory names containing spaces
    or shell metacharacters are passed through intact.

    Raises:
        StopIteration: if *folder_path* does not exist (from ``next(os.walk(...))``).
    """
    for subdir in next(os.walk(folder_path))[1]:
        subdir_path = os.path.join(folder_path, subdir)
        tar_filename = f"{subdir}.tar.gz"
        logger.info(f"Compressing the {tar_filename}")
        try:
            # Create the gzip-compressed tar archive.
            with tarfile.open(tar_filename, "w:gz") as tar:
                tar.add(subdir_path, arcname=os.path.basename(subdir_path))
            s3_target = f"s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/custom_nodes/"
            command = ["s5cmd", "--log=error", "cp", tar_filename, s3_target]
            logger.info(" ".join(command))
            try:
                subprocess.run(command, check=False)
            except OSError as e:
                # e.g. s5cmd not installed; report it and continue with the next folder.
                logger.error(f"upload of {tar_filename} failed: {e}")
        finally:
            if os.path.exists(tar_filename):
                logger.info(f"rm {tar_filename}")
                os.remove(tar_filename)
def sync_default_files():
    """Upload the full local environment to S3 and request a 'default' prepare.

    Custom-node folders are uploaded as archives via compress_and_upload; the
    input and model folders are mirrored with ``s5cmd sync --delete=true``.
    A prepare request with ``need_reboot`` is then POSTed with curl.

    Returns:
        curl's stdout (the prepare API response body), or None if any step raised.
    """
    import shlex

    try:
        timestamp = str(int(time.time() * 1000))
        prepare_version = PREPARE_ID if PREPARE_MODE == 'additional' else timestamp
        need_prepare = True
        prepare_type = 'default'
        need_reboot = True
        s3_prefix = f"s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}"
        logger.info(" sync custom nodes files")
        compress_and_upload(f"{DIR2}", prepare_version)
        # Local paths come from configuration and are shell-quoted in case they
        # contain spaces; plain names such as "input/" are left unchanged.
        logger.info(" sync input files")
        s5cmd_syn_input_command = f's5cmd --log=error sync --delete=true {shlex.quote(DIR3 + "/")} "{s3_prefix}/input/"'
        logger.info(f"sync input files start {s5cmd_syn_input_command}")
        os.system(s5cmd_syn_input_command)
        logger.info(" sync models files")
        s5cmd_syn_model_command = f's5cmd --log=error sync --delete=true {shlex.quote(DIR1 + "/")} "{s3_prefix}/models/"'
        logger.info(f"sync models files start {s5cmd_syn_model_command}")
        os.system(s5cmd_syn_model_command)
        logger.info(f"Files changed in:: {need_prepare} {DIR2} {DIR1} {DIR3}")
        # Join with exactly one '/' whether or not COMFY_API_URL ends with a
        # slash, matching the f"{api_url}/prepare/..." URLs used elsewhere.
        url = api_url.rstrip('/') + "/prepare"
        logger.info(f"URL:{url}")
        data = {"endpoint_name": comfy_endpoint, "need_reboot": need_reboot, "prepare_id": prepare_version,
                "prepare_type": prepare_type}
        logger.info(f"prepare params Data: {json.dumps(data, indent=4)}")
        result = subprocess.run(["curl", "--location", "--request", "POST", url, "--header",
                                 f"x-api-key: {api_token}", "--data-raw", json.dumps(data)],
                                capture_output=True, text=True)
        logger.info(result.stdout)
        return result.stdout
    except Exception as e:
        logger.info(f"sync_files error {e}")
        return None
def _path_in_sync_dir(directory, filepath, dir_name):
    """Return True if a change at *filepath* belongs to the synced folder *dir_name*.

    This matches when the parent *directory* ends with the folder name (an
    absolute *dir_name* as-is, a relative one as ``/<dir_name>``). It also
    matches when *filepath* is the folder itself, with or without a leading
    ``./``, or when ``<dir_name>/`` occurs anywhere in *filepath*.
    """
    suffix = dir_name if dir_name.startswith("/") else f"/{dir_name}"
    return (str(directory).endswith(suffix)
            or str(filepath) == dir_name
            or str(filepath) == f'./{dir_name}'
            or f"{dir_name}/" in filepath)


def _change_has_settled(filepath, is_folder):
    """Return True when *filepath* (a folder if *is_folder*) has stopped changing."""
    if bool(is_folder):
        return is_folder_unlocked(filepath)
    return is_file_unlocked(filepath)


def sync_files(filepath, is_folder, is_auto):
    """Sync the watched folder that *filepath* belongs to up to S3, then request a prepare.

    Folders are checked in order: custom nodes (DIR2), inputs (DIR3), models (DIR1).

    Args:
        filepath: Path reported by the file watcher (or another caller).
        is_folder: Whether *filepath* is a directory.
        is_auto: True for watcher-triggered syncs. The sync is then postponed
            (returns None) while the file or folder is still being written.

    Returns:
        curl's stdout from the prepare POST when a sync ran. Otherwise None:
        ignored path, still-changing file, no matching folder, or any error.
    """
    import shlex

    try:
        directory = os.path.dirname(filepath)
        logger.info(f"Directory changed in: {directory} {filepath}")
        if not directory:
            logger.info("root path no need to sync files by duplicate opt")
            return None
        logger.info(f"Files changed in: {filepath}")
        timestamp = str(int(time.time() * 1000))
        need_prepare = False
        prepare_type = 'default'
        need_reboot = False
        for ignore_item in no_need_sync_files:
            if filepath.endswith(ignore_item):
                logger.info(f"no need to sync files by ignore files {filepath} ends by {ignore_item}")
                return None
        prepare_version = PREPARE_ID if PREPARE_MODE == 'additional' else timestamp
        s3_prefix = f"s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}"
        # Local paths come from configuration and are shell-quoted in case they
        # contain spaces; plain names such as "models/" are left unchanged.
        if _path_in_sync_dir(directory, filepath, DIR2):
            logger.info(f" sync custom nodes files: {filepath}")
            s5cmd_syn_node_command = (f's5cmd --log=error sync --delete=true --exclude="*comfy_local_proxy.py" '
                                      f'{shlex.quote(DIR2 + "/")} "{s3_prefix}/custom_nodes/"')
            # The custom_nodes folder is still changing; sync on a later event.
            if is_auto and not is_folder_unlocked(directory):
                logger.info("sync custom_nodes files is changing ,waiting.... ")
                return None
            logger.info("sync custom_nodes files start")
            logger.info(s5cmd_syn_node_command)
            os.system(s5cmd_syn_node_command)
            need_prepare = True
            need_reboot = True
            prepare_type = 'nodes'
        elif _path_in_sync_dir(directory, filepath, DIR3):
            logger.info(f" sync input files: {filepath}")
            s5cmd_syn_input_command = (f's5cmd --log=error sync --delete=true '
                                       f'{shlex.quote(DIR3 + "/")} "{s3_prefix}/input/"')
            # Only sync once the file/folder has been completely written.
            if is_auto and not _change_has_settled(filepath, is_folder):
                logger.info("sync input files is changing ,waiting.... ")
                return None
            logger.info("sync input files start")
            logger.info(s5cmd_syn_input_command)
            os.system(s5cmd_syn_input_command)
            need_prepare = True
            prepare_type = 'inputs'
        elif _path_in_sync_dir(directory, filepath, DIR1):
            logger.info(f" sync models files: {filepath}")
            s5cmd_syn_model_command = (f's5cmd --log=error sync --delete=true '
                                       f'{shlex.quote(DIR1 + "/")} "{s3_prefix}/models/"')
            # Only sync once the file/folder has been completely written.
            if is_auto and not _change_has_settled(filepath, is_folder):
                logger.info("sync input models is changing ,waiting.... ")
                return None
            logger.info("sync models files start")
            logger.info(s5cmd_syn_model_command)
            os.system(s5cmd_syn_model_command)
            need_prepare = True
            prepare_type = 'models'
        logger.info(f"Files changed in:: {need_prepare} {str(directory)} {DIR2} {DIR1} {DIR3}")
        if not need_prepare:
            return None
        # Join with exactly one '/' whether or not COMFY_API_URL ends with a
        # slash, matching the f"{api_url}/prepare/..." URLs used elsewhere.
        url = api_url.rstrip('/') + "/prepare"
        logger.info(f"URL:{url}")
        data = {"endpoint_name": comfy_endpoint, "need_reboot": need_reboot, "prepare_id": prepare_version,
                "prepare_type": prepare_type}
        logger.info(f"prepare params Data: {json.dumps(data, indent=4)}")
        result = subprocess.run(["curl", "--location", "--request", "POST", url, "--header",
                                 f"x-api-key: {api_token}", "--data-raw", json.dumps(data)],
                                capture_output=True, text=True)
        logger.info(result.stdout)
        return result.stdout
    except Exception as e:
        logger.info(f"sync_files error {e}")
        return None
def is_folder_unlocked(directory):
    """Return True if *directory* appears to have stopped changing.

    A temporary watchdog observer watches the tree recursively for one second.
    If no event arrives in that time, the folder counts as stable. If events
    did arrive, the flag is cleared and the folder gets one more second; it is
    stable only if that second window stays quiet. Exceptions raised while
    checking are logged and yield False. The observer is stopped before
    returning.
    """
    watcher = MyHandlerWithCheck()
    observer = Observer()
    observer.schedule(watcher, directory, recursive=True)
    observer.start()
    time.sleep(1)
    stable = False
    try:
        if not watcher.file_changed:
            logger.info(f"folder {directory} not stopped")
            stable = True
        else:
            logger.info(f"folder {directory} is still changing..")
            watcher.file_changed = False
            time.sleep(1)
            if watcher.file_changed:
                logger.info(f"folder {directory} is still still changing..")
            else:
                logger.info(f"folder {directory} changing stopped")
                stable = True
    except (KeyboardInterrupt, Exception) as err:
        logger.info(f"folder {directory} changed exception {err}")
    observer.stop()
    return stable
def is_file_unlocked(file_path):
    """Return True if *file_path* has finished being written and is not locked.

    Size and mtime are sampled one second apart; any difference means a
    writer is still active. If both samples match, a non-blocking exclusive
    ``flock`` is attempted; it is released when the file is closed. Any error,
    including a lock already held elsewhere, is logged and returns False.
    """
    try:
        first_sample = (os.path.getsize(file_path), os.path.getmtime(file_path))
        time.sleep(1)
        second_sample = (os.path.getsize(file_path), os.path.getmtime(file_path))
        if second_sample != first_sample:
            logger.info(f"unlock file error {file_path} is changing")
            return False
        with open(file_path, 'r') as probe:
            fcntl.flock(probe, fcntl.LOCK_EX | fcntl.LOCK_NB)
        return True
    except Exception as err:
        logger.info(f"unlock file error {file_path} is writing")
        logger.error(err)
        return False
class MyHandlerWithCheck(FileSystemEventHandler):
    """Watchdog handler that only records whether any change event happened.

    ``file_changed`` is set on every create/modify/delete event; callers read
    it and reset it to False themselves.
    """

    def __init__(self):
        self.file_changed = False

    def _mark_changed(self, event):
        logger.info(f"custom_node folder is changing {event.src_path}")
        self.file_changed = True

    def on_modified(self, event):
        self._mark_changed(event)

    def on_deleted(self, event):
        self._mark_changed(event)

    def on_created(self, event):
        self._mark_changed(event)
class MyHandlerWithSync(FileSystemEventHandler):
    """Watchdog handler that runs an automatic sync_files call for every change event."""

    def _sync(self, action, event):
        logger.info(f"{datetime.datetime.now()} files {action} start to sync {event}")
        sync_files(event.src_path, event.is_directory, True)

    def on_modified(self, event):
        self._sync("modified", event)

    def on_created(self, event):
        self._sync("added", event)

    def on_deleted(self, event):
        self._sync("deleted", event)
# Set by signal_handler; tells the auto-sync loop in check_and_sync to stop.
stop_event = threading.Event()
# Handlers that signal_handler replaced, keyed by signal number, so the
# original shutdown behaviour still runs after the sync thread is told to stop.
_previous_signal_handlers = {}


def check_and_sync():
    """Watch the model, custom-node and input folders and auto-sync changes.

    Runs until ``stop_event`` is set, or until a KeyboardInterrupt reaches this
    thread. The observer is stopped and joined on the way out whenever it was
    started.
    """
    logger.info("check_and_sync start")
    event_handler = MyHandlerWithSync()
    observer = Observer()
    started = False
    try:
        observer.schedule(event_handler, DIR1, recursive=True)
        observer.schedule(event_handler, DIR2, recursive=True)
        observer.schedule(event_handler, DIR3, recursive=True)
        observer.start()
        started = True
        # Wake up every second until signal_handler sets stop_event.
        while not stop_event.wait(1):
            pass
    except KeyboardInterrupt:
        logger.info("sync Shutting down please restart ComfyUI")
    finally:
        if started:
            observer.stop()
            observer.join()


def signal_handler(sig, frame):
    """Stop the auto-sync thread, then hand the signal to the previous handler.

    Without the hand-off, installing this handler would swallow SIGINT/SIGTERM
    and the process would keep running despite logging "Exiting...".
    """
    logger.info("Received termination signal. Exiting...")
    stop_event.set()
    previous = _previous_signal_handlers.get(sig)
    if callable(previous):
        # e.g. signal.default_int_handler, which raises KeyboardInterrupt.
        previous(sig, frame)
    elif previous == signal.SIG_DFL:
        # Restore the default action and re-deliver the signal so it terminates.
        signal.signal(sig, signal.SIG_DFL)
        os.kill(os.getpid(), sig)


if os.environ.get('DISABLE_AUTO_SYNC') == 'false':
    _previous_signal_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, signal_handler)
    _previous_signal_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, signal_handler)
    check_sync_thread = threading.Thread(target=check_and_sync)
    check_sync_thread.start()
@server.PromptServer.instance.routes.post('/map_release')
async def map_release(request):
    """Bind a UI client to the release it selected.

    Expects a JSON body with non-empty ``clientId`` and ``releaseVersion``.
    The pair is stored in ``client_release_map``, which execute_proxy reads
    when the client runs a prompt. Replies 200 ``{"result": true}`` on success,
    or 500 ``{"result": false}`` when either field is missing or empty.
    """
    logger.info(f"start to map_release {request}")
    json_data = await request.json()
    required_fields = ('clientId', 'releaseVersion')
    if not json_data or any(field not in json_data or not json_data.get(field) for field in required_fields):
        return web.Response(status=500, content_type='application/json', body=json.dumps({"result": False}))
    global client_release_map
    client_release_map[json_data.get('clientId')] = json_data.get('releaseVersion')
    logger.info(f"client_release_map :{client_release_map}")
    return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
@server.PromptServer.instance.routes.get("/reboot")
async def restart(self):
    """Restart via supervisord's XML-RPC API (master process only).

    Refused while a workflow sync lock is held, or on non-master processes.
    If the RPC call fails, this falls back to re-exec'ing the current Python
    process.
    """
    if is_action_lock():
        refusal = {"result": False, "message": "reboot is not allowed during sync workflow"}
        return web.Response(status=200, content_type='application/json', body=json.dumps(refusal))
    if not is_master_process:
        refusal = {"result": False, "message": "only master can restart"}
        return web.Response(status=200, content_type='application/json', body=json.dumps(refusal))
    logger.info(f"start to reboot {self}")
    try:
        from xmlrpc.client import ServerProxy
        # Named supervisor_rpc so the imported `server` module is not shadowed.
        supervisor_rpc = ServerProxy('http://localhost:9001/RPC2')
        supervisor_rpc.supervisor.restart()
        return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
    except Exception as e:
        logger.info(f"error reboot {e}")
    return os.execv(sys.executable, [sys.executable] + sys.argv)
@server.PromptServer.instance.routes.get("/check_prepare")
async def check_prepare(self):
    """Report whether the cloud endpoint's prepare (environment sync) succeeded.

    Replies 200 ``{"result": true}`` when the prepare API answers HTTP 200
    with a truthy ``data.prepareSuccess``; otherwise 500 ``{"result": false}``.
    Request or parse errors are logged and also answered with 500. A failed
    status check never restarts the process.
    """
    logger.info(f"start to check_prepare {self}")
    try:
        get_response = requests.get(f"{api_url}/prepare/{comfy_endpoint}", headers=headers)
        response = get_response.json()
        logger.info(f"check sync response is {response}")
        data = response.get('data') if isinstance(response, dict) else None
        prepare_success = data.get('prepareSuccess') if isinstance(data, dict) else None
        if get_response.status_code == 200 and prepare_success:
            return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
        logger.info(f"check sync response is {response} {prepare_success}")
    except Exception as e:
        logger.info(f"error check_prepare {e}")
    return web.Response(status=500, content_type='application/json', body=json.dumps({"result": False}))
@server.PromptServer.instance.routes.get("/gc")
async def gc(self):
    """Free cached models and memory without restarting ComfyUI.

    Resets a fresh PromptExecutor, calls
    comfy.model_management.cleanup_models(), runs the Python garbage collector
    and then comfy.model_management.soft_empty_cache(). Replies 200
    ``{"result": true}`` on success or 500 ``{"result": false}`` on error.
    """
    # NOTE(review): this handler's name rebinds the module-level `gc` to this
    # coroutine function, so `gc.collect()` would fail here (and anywhere else
    # in the module that runs after this definition). The stdlib module is
    # therefore imported locally under another name.
    import gc as gc_module

    logger.info(f"start to gc {self}")
    try:
        logger.info(f"gc start: {time.time()}")
        server_instance = server.PromptServer.instance
        executor = execution.PromptExecutor(server_instance)
        executor.reset()
        comfy.model_management.cleanup_models()
        gc_module.collect()
        comfy.model_management.soft_empty_cache()
        logger.info(f"gc end: {time.time()}")
        return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
    except Exception as e:
        logger.info(f"error gc {e}")
    return web.Response(status=500, content_type='application/json', body=json.dumps({"result": False}))
def is_action_lock():
    """Return True while a workflow sync/release holds the lock file.

    /container/sync_lock holds the name written by action_lock. The lock
    counts as held only when that name is non-empty and
    ``check_workflow_exists(name)`` is falsy.
    """
    lock_path = '/container/sync_lock'
    if not os.path.exists(lock_path):
        return False
    with open(lock_path, 'r') as lock:
        holder = lock.read()
    return bool(holder) and not check_workflow_exists(holder)
def action_lock(name: str):
    """Write *name* to /container/sync_lock, marking a sync/release as in progress."""
    lock_path = '/container/sync_lock'
    with open(lock_path, 'w') as lock:
        lock.write(name)
def action_unlock():
    """Release the sync lock by truncating /container/sync_lock to an empty file."""
    lock_path = '/container/sync_lock'
    with open(lock_path, 'w') as lock:
        lock.write("")
@server.PromptServer.instance.routes.get("/restart")
async def restart(self):
    """Re-exec the current Python process unless a workflow sync lock is held.

    Before exec'ing, it tries ``sys.stdout.close_log()``; presumably stdout is
    wrapped by a logger that provides it (TODO confirm). Any failure there is
    logged and ignored.
    """
    if is_action_lock():
        refusal = {"result": False, "message": "restart is not allowed during sync workflow"}
        return web.Response(status=200, content_type='application/json', body=json.dumps(refusal))
    logger.info(f"start to restart {self}")
    try:
        sys.stdout.close_log()
    except Exception as e:
        logger.info(f"error restart {e}")
    return os.execv(sys.executable, [sys.executable] + sys.argv)
@server.PromptServer.instance.routes.get("/sync_env")
async def sync_env(request):
    """Upload the local environment to S3 and trigger a default prepare.

    Replies 200 ``{"result": true}`` when sync_default_files completed, and
    500 ``{"result": false}`` when it failed. sync_default_files reports
    failure by returning None rather than raising, so the return value is
    checked explicitly.
    """
    logger.info(f"start to sync_env {request}")
    try:
        # NOTE(review): the whole S3 sync runs synchronously on the aiohttp
        # event loop and blocks other requests until it finishes.
        result = sync_default_files()
        logger.debug(f"sync result is :{result}")
        if result is not None:
            return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
    except Exception as e:
        logger.info(f"error sync_env {e}")
    return web.Response(status=500, content_type='application/json', body=json.dumps({"result": False}))
@server.PromptServer.instance.routes.post("/change_env")
async def change_env(request):
    """Toggle the AWS proxy at runtime by writing DISABLE_AWS_PROXY into os.environ.

    A non-null ``DISABLE_AWS_PROXY`` value in the JSON body is stored as
    ``str(value)``, so JSON ``true`` becomes 'True', the value execute_proxy
    checks for. Always replies ``{"result": true}``.
    """
    logger.info(f"start to change_env {request}")
    json_data = await request.json()
    new_value = json_data[DISABLE_AWS_PROXY] if DISABLE_AWS_PROXY in json_data else None
    if new_value is not None:
        logger.info(f"origin evn key DISABLE_AWS_PROXY is :{os.environ.get(DISABLE_AWS_PROXY)} {str(new_value)}")
        os.environ[DISABLE_AWS_PROXY] = str(new_value)
        logger.info(f"now evn key DISABLE_AWS_PROXY is :{os.environ.get(DISABLE_AWS_PROXY)}")
    return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
@server.PromptServer.instance.routes.get("/get_env")
async def get_env(request):
    """Return the current DISABLE_AWS_PROXY setting as ``{"env": <value>}`` (default 'False')."""
    payload = {"env": os.environ.get(DISABLE_AWS_PROXY, 'False')}
    return web.Response(status=200, content_type='application/json', body=json.dumps(payload))
def get_directory_size(directory):
    """Return the total size in bytes of every non-symlink file under *directory*.

    Symbolic links are skipped: linked files are not counted, and os.walk does
    not descend into linked directories. A missing directory yields 0.
    """
    return sum(
        os.path.getsize(path)
        for path in (
            os.path.join(root, name)
            for root, _subdirs, names in os.walk(directory)
            for name in names
        )
        if not os.path.islink(path)
    )
@server.PromptServer.instance.routes.post("/workflows")
async def release_workflow(request):
    """Publish the currently running workflow directory to S3 and register it via the API.

    Expects a JSON body with ``name`` (required, must not be 'default') and an
    optional ``payload_json``. Only the master process may release, and never
    while a workflow sync holds the action lock. Expected rejections answer 200
    with ``result: false``; failures of the upload or of the API registration
    answer 500 with ``result: false``.
    """
    if is_action_lock():
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(
                                {"result": False, "message": "release is not allowed during sync workflow"}))
    if not is_master_process:
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": False, "message": "only master can release workflow"}))
    logger.info(f"start to release workflow {request}")
    try:
        json_data = await request.json()
        if 'name' not in json_data or not json_data['name']:
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": "name is required"}))
        workflow_name = json_data['name']
        if workflow_name == 'default':
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"{workflow_name} is not allowed"}))
        payload_json = json_data.get('payload_json', '')
        if check_workflow_exists(workflow_name):
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"{workflow_name} already exists"}))

        start_time = time.time()
        try:
            action_lock(workflow_name)
            # workflow_name comes from the request body, so it is never passed through
            # a shell: the marker files are written directly and s5cmd gets an argv list.
            with open("/container/image_target_name", "w") as target_file:
                target_file.write(f"{workflow_name}\n")
            with open("/container/image_base", "w") as base_file:
                base_file.write(f"{os.getenv('BASE_IMAGE')}\n")

            cur_workflow_name = os.getenv('WORKFLOW_NAME')
            source_path = f"/container/workflows/{cur_workflow_name}"
            logger.info(f"source_path is {source_path}")
            source_size = round(get_directory_size(source_path) / (1024 ** 3), 2)

            s3_prefix = f"s3://{bucket_name}/comfy/workflows/{workflow_name}/"
            # The wildcards are interpreted by s5cmd itself, not by a shell.
            sync_command = [
                's5cmd', 'sync',
                '--delete=true',
                '--exclude=*comfy.tar',
                '--exclude=*.log',
                '--exclude=*__pycache__*',
                '--exclude=*/ComfyUI/input/*',
                '--exclude=*/ComfyUI/output/*',
                f"{source_path}/*",
                s3_prefix,
            ]
            logger.info(f"sync workflows files start {sync_command}")
            subprocess.check_output(sync_command)
            # Upload a 'lock' marker next to the synced files (local file in the CWD).
            with open("lock", "w") as lock_file:
                lock_file.write("lock\n")
            subprocess.check_output(['s5cmd', 'sync', 'lock', f"{s3_prefix}lock"])
            cost_time = time.time() - start_time

            data = {
                "payload_json": payload_json,
                "image_uri": f"{os.getenv('IMAGE_HASH')}:{workflow_name}",
                "name": workflow_name,
                "size": str(source_size),
            }
            api_response = requests.post(f"{api_url}/workflows", headers=headers, data=json.dumps(data))
            logger.info(f"release workflow response is {api_response.status_code} {api_response.text}")
            if not api_response.ok:
                # The upload succeeded but the workflow was not registered; do not report success.
                return web.Response(status=500, content_type='application/json',
                                    body=json.dumps({"result": False, "message": 'Release workflow failed'}))
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": True, "message": "success", "cost_time": cost_time}))
        finally:
            action_unlock()
    except Exception as e:
        logger.exception(f"release workflow failed: {e}")
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": 'Release workflow failed'}))
@server.PromptServer.instance.routes.put("/workflows")
async def switch_workflow(request):
    """Record another workflow name in WORKFLOW_NAME_FILE and kill python3 processes.

    Expects a JSON body with ``name``. The name must differ from the running
    workflow; 'default' is reserved for the master process, and any other name
    must exist according to the API. Rejections answer 200 with
    ``result: false``; unexpected errors answer 500.
    """
    if is_action_lock():
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(
                                {"result": False, "message": "switch is not allowed during sync workflow"}))
    try:
        json_data = await request.json()
        if 'name' not in json_data or not json_data['name']:
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": "name is required"}))
        workflow_name = json_data['name']
        if workflow_name == os.getenv('WORKFLOW_NAME'):
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": "workflow is already in use"}))
        if workflow_name == 'default' and not is_master_process:
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": "slave can not use default workflow"}))
        if workflow_name != 'default' and not check_workflow_exists(workflow_name):
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"{workflow_name} not exists"}))
        name_file = os.getenv('WORKFLOW_NAME_FILE')
        if not name_file:
            logger.error("WORKFLOW_NAME_FILE is not set; cannot switch workflow")
            return web.Response(status=500, content_type='application/json',
                                body=json.dumps({"result": False, "message": 'Switch workflow failed'}))
        # Write the name directly instead of `echo ... > file` in a shell:
        # workflow_name comes from the request body.
        with open(name_file, "w") as name_fp:
            name_fp.write(f"{workflow_name}\n")
        # NOTE(review): `pkill -f python3` signals every process whose command line
        # contains "python3", which likely includes this server, so the response
        # below may not reach the client; presumably a supervisor restarts ComfyUI.
        subprocess.run(["pkill", "-f", "python3"])
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "message": "Please wait to restart"}))
    except Exception as e:
        logger.exception(f"switch workflow failed: {e}")
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": 'Switch workflow failed'}))
@server.PromptServer.instance.routes.get("/workflows")
async def get_workflows(request):
    """List the enabled workflows known to the API and mark the one this instance runs.

    Returns the API's ``data`` object with ``workflows`` reduced to enabled
    entries (name, size, payload_json, in_use) and ``current_workflow`` set to
    WORKFLOW_NAME. Any API or parsing failure answers 500.
    """
    try:
        current_workflow = os.getenv('WORKFLOW_NAME')
        response = requests.get(f"{api_url}/workflows", headers=headers, params={"limit": 1000})
        if response.status_code != 200:
            return web.Response(status=500, content_type='application/json',
                                body=json.dumps({"result": False, "message": 'Get workflows failed'}))
        data = response.json()['data']
        data['current_workflow'] = current_workflow
        data['workflows'] = [
            {
                "name": workflow['name'],
                "size": workflow['size'],
                "payload_json": workflow['payload_json'],
                "in_use": workflow['name'] == current_workflow,
            }
            for workflow in data['workflows']
            if workflow['status'] == 'Enabled'
        ]
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "data": data}))
    except Exception as e:
        logger.exception(f"get workflows failed: {e}")
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": 'Get workflows failed'}))
def check_workflow_exists(name: str):
    """Return True when the API answers HTTP 200 for the workflow called ``name``.

    The name is percent-encoded as a single URL path segment so characters such
    as '/', '?' or '#' in a user-supplied name cannot change which endpoint is
    queried.
    """
    from urllib.parse import quote

    encoded_name = quote(str(name), safe='')
    get_response = requests.get(f"{api_url}/workflows/{encoded_name}", headers=headers)
    return get_response.status_code == 200
def restore_commands():
    """Wait 5 seconds (via `sleep 5`), then kill every process matching 'python3'."""
    for command in (["sleep", "5"], ["pkill", "-f", "python3"]):
        subprocess.run(command)
@server.PromptServer.instance.routes.post("/restore")
async def release_rebuild_workflow(request):
    """Rebuild the 'default' workflow directory from /container/default.tar and restart.

    Allowed only for the master process, only while the 'default' workflow is
    active and no workflow sync holds the action lock. The two bundled
    checkpoints are parked in / during the wipe and moved back afterwards. On
    success a background thread runs restore_commands() (sleep 5s, then pkill
    python3). If the tar extraction fails, 500 is returned and nothing is killed.
    """
    if os.getenv('WORKFLOW_NAME') != 'default':
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": False, "message": "only default workflow can be restored"}))
    if is_action_lock():
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(
                                {"result": False, "message": "restore is not allowed during sync workflow"}))
    if not is_master_process:
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": False, "message": "only master can restore comfy"}))
    logger.info(f"start to restore EC2 {request}")
    try:
        # Park the bundled checkpoints outside the tree that is about to be wiped.
        os.system("mv /container/workflows/default/ComfyUI/models/checkpoints/v1-5-pruned-emaonly.ckpt /")
        os.system("mv /container/workflows/default/ComfyUI/models/animatediff_models/mm_sd_v15_v2.ckpt /")
        os.system("rm -rf /container/workflows/default")
        os.system("mkdir -p /container/workflows/default")
        extract_status = os.system("tar --overwrite -xf /container/default.tar -C /container/workflows/default/")
        # Put the parked checkpoints back whether or not the extraction succeeded.
        os.system("mkdir -p /container/workflows/default/ComfyUI/models/checkpoints")
        os.system("mkdir -p /container/workflows/default/ComfyUI/models/animatediff_models")
        os.system("mv /v1-5-pruned-emaonly.ckpt /container/workflows/default/ComfyUI/models/checkpoints/")
        os.system("mv /mm_sd_v15_v2.ckpt /container/workflows/default/ComfyUI/models/animatediff_models/")
        if extract_status != 0:
            # Do not restart into a half-restored tree; surface the failure instead.
            logger.error(f"restore failed: tar exited with status {extract_status}")
            return web.Response(status=500, content_type='application/json',
                                body=json.dumps({"result": False,
                                                 "message": "failed to extract /container/default.tar"}))
        threading.Thread(target=restore_commands).start()
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "message": "comfy will be restored in 5 seconds"}))
    except Exception as e:
        logger.error(f"error restore {e}")
        # Exception objects are not JSON serializable; send their text.
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": str(e)}))
if is_on_sagemaker:
    # Per-process state shared by the SageMaker handlers and the send_sync /
    # get_save_image_path proxies. This `if` is evaluated at module scope, so plain
    # assignments create the module globals (`global` statements are not needed).
    # need_sync and prompt_id get safe defaults so the proxies do not hit
    # NameError before the first /execute_proxy request sets them.
    need_sync = False  # set per request by /execute_proxy
    prompt_id = None  # set per request by /execute_proxy
    executing = False  # True while /execute_proxy is running a prompt
    reboot = False  # set by prepare_comfy_env from the sync record's need_reboot
    last_call_time = None  # timestamp of the previous /execute_proxy call (idle GC)
    gc_triggered = False  # whether the last /execute_proxy call ran a GC pass

    REGION = os.environ.get('AWS_REGION')
    BUCKET = os.environ.get('S3_BUCKET_NAME')
    QUEUE_URL = os.environ.get('COMFY_QUEUE_URL')
    # Prefer the instance id supplied by the endpoint; otherwise generate one per process.
    GEN_INSTANCE_ID = os.environ.get('ENDPOINT_INSTANCE_ID') or str(uuid.uuid4())
    ENDPOINT_NAME = os.environ.get('ENDPOINT_NAME')
    ENDPOINT_ID = os.environ.get('ENDPOINT_ID')
    INSTANCE_MONITOR_TABLE_NAME = os.environ.get('COMFY_INSTANCE_MONITOR_TABLE')
    SYNC_TABLE_NAME = os.environ.get('COMFY_SYNC_TABLE')

    dynamodb = boto3.resource('dynamodb', region_name=REGION)
    sync_table = dynamodb.Table(SYNC_TABLE_NAME)
    instance_monitor_table = dynamodb.Table(INSTANCE_MONITOR_TABLE_NAME)

    logger = logging.getLogger(__name__)
    # Accept level names in any case ("info" as well as "INFO").
    logger.setLevel((os.environ.get('LOG_LEVEL') or 'INFO').upper())

    ROOT_PATH = '/home/ubuntu/ComfyUI'
    sqs_client = boto3.client('sqs', region_name=REGION)
    GC_WAIT_TIME = 1800  # seconds between /execute_proxy calls before cached models are freed
def print_env():
    """Log every environment variable as ``NAME: value`` at INFO level.

    NOTE(review): the environment may contain credentials or API tokens, which
    this writes to the logs verbatim.
    """
    for env_name, env_value in os.environ.items():
        logger.info(f"{env_name}: {env_value}")
@dataclass
class ComfyResponse:
    """Simple response record: a status code, a message and an optional body."""

    statusCode: int  # HTTP-style status code
    message: str  # human-readable message
    body: Optional[dict]  # optional payload; may be None
def ok(body: dict):
    """Serialize ``body`` to JSON and return it as an HTTP 200 application/json response."""
    serialized = json.dumps(body)
    return web.Response(status=200, content_type='application/json', body=serialized)
def error(body: dict):
    """Serialize an error ``body`` to JSON and return it with HTTP status 200.

    TODO: errors intentionally use 200 instead of 500, because the caller needs a
    response body in every case rather than an exception.
    """
    serialized = json.dumps(body)
    return web.Response(status=200, content_type='application/json', body=serialized)
def sen_sqs_msg(message_body, prompt_id_key):
    """Send ``message_body`` as JSON to QUEUE_URL and return the SQS MessageId.

    ``prompt_id_key`` is used as the MessageGroupId (presumably a FIFO queue, so
    events of one prompt stay ordered).
    """
    send_result = sqs_client.send_message(
        QueueUrl=QUEUE_URL,
        MessageBody=json.dumps(message_body),
        MessageGroupId=prompt_id_key,
    )
    return send_result['MessageId']
def sen_finish_sqs_msg(prompt_id_key):
    """Publish a 'finish' event for ``prompt_id_key`` when SQS syncing is enabled.

    Does nothing unless the global ``need_sync`` is truthy and both QUEUE_URL
    and REGION are configured.
    """
    if not (need_sync and QUEUE_URL and REGION):
        return
    finish_event = {
        'prompt_id': prompt_id_key,
        'event': 'finish',
        'data': {"node": None, "prompt_id": prompt_id_key},
        'sid': None,
    }
    sent_id = sen_sqs_msg(finish_event, prompt_id_key)
    logger.info(f"finish message sent {sent_id}")
async def prepare_comfy_env(sync_item: dict):
    """Bring the local ComfyUI install in line with one sync request record.

    ``sync_item['prepare_type']`` selects the work:
      * 'default' / 'models' / 'inputs' / 'nodes': s5cmd-sync the matching
        ``{request_id}/<dir>/*`` prefix into ROOT_PATH ('default' does all three;
        custom nodes are additionally un-tarred);
      * 'custom': sync ``s3_source_path`` into ``ROOT_PATH/local_target_path``;
      * 'other': run ``sync_script`` if it starts with an allow-listed prefix,
        or apply an ``export KEY VALUE`` line to os.environ.
    Afterwards the module ``reboot`` flag, NEED_REBOOT, LAST_SYNC_REQUEST_ID and
    LAST_SYNC_REQUEST_TIME are updated from the record.

    Returns True when every requested sync step reported success, False
    otherwise, including when an unexpected exception is raised (it is logged).
    """
    global reboot
    allowed_script_prefixes = (
        "python3 -m pip", "python -m pip", "pip install", "apt",
        "os.environ", "ls", "env", "source", "curl", "wget", "print", "cat",
        "sudo chmod", "chmod", "/home/ubuntu/ComfyUI/venv/bin/python",
    )
    try:
        request_id = sync_item['request_id']
        logger.info(f"prepare_environment start sync_item:{sync_item}")
        prepare_type = sync_item['prepare_type']
        rlt = True
        if prepare_type in ('default', 'models'):
            if not sync_s3_files_or_folders_to_local(f'{request_id}/models/*', f'{ROOT_PATH}/models', False):
                rlt = False
        if prepare_type in ('default', 'inputs'):
            if not sync_s3_files_or_folders_to_local(f'{request_id}/input/*', f'{ROOT_PATH}/input', False):
                rlt = False
        if prepare_type in ('default', 'nodes'):
            if not sync_s3_files_or_folders_to_local(f'{request_id}/custom_nodes/*',
                                                     f'{ROOT_PATH}/custom_nodes', True):
                rlt = False
        if prepare_type == 'custom':
            sync_source_path = sync_item['s3_source_path']
            local_target_path = sync_item['local_target_path']
            if not sync_source_path or not local_target_path:
                logger.info("s3_source_path and local_target_path should not be empty")
            elif not sync_s3_files_or_folders_to_local(sync_source_path,
                                                       f'{ROOT_PATH}/{local_target_path}', False):
                rlt = False
        elif prepare_type == 'other':
            sync_script = sync_item['sync_script']
            logger.info(f"sync_script: {sync_script}")
            # s5cmd commands are deliberately not in the allow-list.
            # NOTE(security): the prefix check is weak: os.system runs the whole string
            # in a shell, so anything after an allowed prefix (e.g. "ls; <cmd>") also
            # runs. Only accept sync records from trusted sources.
            try:
                if sync_script and sync_script.startswith(allowed_script_prefixes):
                    exit_status = os.system(sync_script)
                    if exit_status != 0:
                        logger.warning(f"sync_script exited with status {exit_status}: {sync_script}")
                elif sync_script and sync_script.startswith("export ") and len(sync_script.split(" ")) > 2:
                    # Format: "export KEY VALUE" (space separated, no '=').
                    sync_script_key = sync_script.split(" ")[1]
                    sync_script_value = sync_script.split(" ")[2]
                    os.environ[sync_script_key] = sync_script_value
                    logger.info(os.environ.get(sync_script_key))
            except Exception as e:
                logger.error(f"Exception while execute sync_scripts : {sync_script} {e}")
                rlt = False
        # Only a value whose text is 'true' (any case), e.g. True or "True", requests a reboot.
        need_reboot = str(sync_item.get('need_reboot')).lower() == 'true'
        reboot = need_reboot
        os.environ['NEED_REBOOT'] = 'true' if need_reboot else 'false'
        logger.info("prepare_environment end")
        os.environ['LAST_SYNC_REQUEST_ID'] = sync_item['request_id']
        os.environ['LAST_SYNC_REQUEST_TIME'] = str(sync_item['request_time'])
        return rlt
    except Exception as e:
        logger.exception(f"prepare_environment failed: {e}")
        return False
def sync_s3_files_or_folders_to_local(s3_path, local_path, need_un_tar):
    """Mirror ``s3://BUCKET/comfy/ENDPOINT_NAME/<s3_path>`` into ``local_path`` with s5cmd.

    When ``need_un_tar`` is false the sync uses ``--delete=true``, so local files
    that are not in S3 are removed. When it is true, every ``*.tar.gz`` directly
    inside ``local_path`` is extracted there and then deleted. Archive members
    that would land outside ``local_path`` are skipped.

    s5cmd is run without a shell. Its non-zero exit status is logged but does not
    by itself make the function return False. Returns False only if starting
    s5cmd or the extraction raised, True otherwise.
    """
    logger.info("sync_s3_models_or_inputs_to_local start")
    s3_uri = f"s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}"
    s5cmd_command = ['s5cmd', 'sync']
    if not need_un_tar:
        s5cmd_command.append('--delete=true')
    s5cmd_command += [s3_uri, f"{local_path}/"]
    try:
        logger.info(s5cmd_command)
        completed = subprocess.run(s5cmd_command)
        if completed.returncode != 0:
            logger.warning(f"s5cmd exited with status {completed.returncode}: {s5cmd_command}")
        logger.info(f'Files copied from "{s3_uri}" to "{local_path}/"')
        if need_un_tar:
            _extract_tar_gz_archives(local_path)
        return True
    except Exception as e:
        logger.info(f"Error executing s5cmd command: {e}")
        return False


def _is_within_directory(directory, target):
    """Return True if ``target`` resolves to ``directory`` itself or a path inside it."""
    base = os.path.realpath(directory)
    resolved = os.path.realpath(target)
    return os.path.commonpath([base, resolved]) == base


def _extract_tar_gz_archives(local_path):
    """Extract, then delete, every ``*.tar.gz`` directly inside ``local_path``.

    Members whose destination or link target would resolve outside
    ``local_path`` (absolute names, ``..`` components, escaping links) are
    skipped with a warning to prevent path traversal from a crafted archive.
    """
    for filename in os.listdir(local_path):
        if not filename.endswith(".tar.gz"):
            continue
        tar_filepath = os.path.join(local_path, filename)
        with tarfile.open(tar_filepath, "r:gz") as tar:
            for member in tar.getmembers():
                destination = os.path.join(local_path, member.name)
                unsafe = not _is_within_directory(local_path, destination)
                if member.issym():
                    # Symlink targets are relative to the link's own directory.
                    link_target = os.path.join(os.path.dirname(destination), member.linkname)
                    unsafe = unsafe or not _is_within_directory(local_path, link_target)
                elif member.islnk():
                    # Hard-link targets are relative to the archive root.
                    link_target = os.path.join(local_path, member.linkname)
                    unsafe = unsafe or not _is_within_directory(local_path, link_target)
                if unsafe:
                    logger.warning(f"Skipping unsafe archive member {member.name!r} in {tar_filepath}")
                    continue
                tar.extract(member, path=local_path)
        os.remove(tar_filepath)
        logger.info(f'File {tar_filepath} extracted and removed')
def sync_local_outputs_to_s3(s3_path, local_path):
    """Upload everything under ``local_path`` to ``s3://BUCKET/comfy/<s3_path>/``, then delete it locally.

    ``local_path`` can be derived from the request's ``out_path`` in
    /execute_proxy. To avoid shell injection, s5cmd is therefore run without a
    shell, and the directory is removed with shutil instead of ``rm -rf``.
    Errors are logged and swallowed; nothing is returned.
    """
    import shutil

    logger.info("sync_local_outputs_to_s3 start")
    s3_uri = f"s3://{BUCKET}/comfy/{s3_path}/"
    # The trailing wildcard is expanded by s5cmd itself.
    s5cmd_command = ['s5cmd', 'sync', f"{local_path}/*", s3_uri]
    try:
        logger.info(s5cmd_command)
        subprocess.run(s5cmd_command)
        logger.info(f'Files copied from local "{local_path}/" to "{s3_uri}"')
        if os.path.isdir(local_path) and not os.path.islink(local_path):
            shutil.rmtree(local_path, ignore_errors=True)
        elif os.path.lexists(local_path):
            os.remove(local_path)
        logger.info(f'Files removed from local {local_path}')
    except Exception as e:
        logger.info(f"Error executing s5cmd command: {e}")
def sync_local_outputs_to_base64(local_path):
    """Return ``{file name: base64 content}`` for every file under ``local_path``, then delete it.

    Keys are bare file names, not relative paths, so files with the same name
    in different sub-directories overwrite each other (the last one walked
    wins). The directory is removed with shutil rather than a shell ``rm -rf``.
    On any error the error is logged and an empty dict is returned.
    """
    import shutil

    logger.info("sync_local_outputs_to_base64 start")
    try:
        result = {}
        for root, _dirs, files in os.walk(local_path):
            for file in files:
                with open(os.path.join(root, file), "rb") as f:
                    result[file] = base64.b64encode(f.read()).decode('utf-8')
        if os.path.isdir(local_path) and not os.path.islink(local_path):
            shutil.rmtree(local_path, ignore_errors=True)
        elif os.path.lexists(local_path):
            os.remove(local_path)
        logger.info(f'Files removed from local {local_path}')
        return result
    except Exception as e:
        logger.info(f"Error encoding local outputs to base64: {e}")
        return {}
@server.PromptServer.instance.routes.post("/execute_proxy")
async def execute_proxy(request):
    """Run one ComfyUI prompt on this instance and upload its outputs to S3.

    The JSON body must contain ``prompt``, ``prompt_id`` and ``need_sync``. It may
    also carry ``out_path``, ``number``, ``extra_data``, ``client_id`` and
    ``need_prepare``/``prepare_props`` (an environment sync run first). The
    prompt is executed synchronously. ROOT_PATH/output and ROOT_PATH/temp (or
    their ``out_path`` sub-directory) are pushed to
    ``s3://BUCKET/comfy/{output,temp}/<prompt_id>[/<out_path>]``. Every outcome
    also sends a 'finish' event via sen_finish_sqs_msg().

    Always answers HTTP 200 via ok()/error(); ``status`` is 'success' or 'fail'.
    """
    global need_sync, prompt_id, executing
    json_data = await request.json()
    # NOTE(review): out_path is joined into local and S3 paths without validation,
    # and the local path is deleted afterwards; confirm callers are trusted or
    # reject values containing '..'.
    out_path = json_data.get('out_path')
    logger.info(f"invocations start json_data:{json_data}")
    need_sync = json_data["need_sync"]
    prompt_id = json_data["prompt_id"]
    if executing is True:
        resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
                "message": "another inference is still executing on this instance, try again later"}
        sen_finish_sqs_msg(prompt_id)
        return error(resp)
    executor = None
    try:
        executing = True
        logger.info(f'bucket_name: {BUCKET}, region: {REGION}')
        if json_data.get('need_prepare') and json_data.get('prepare_props'):
            sync_already = await prepare_comfy_env(json_data['prepare_props'])
            if not sync_already:
                resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
                        "message": "the environment is not ready with sync"}
                executing = False
                sen_finish_sqs_msg(prompt_id)
                return error(resp)

        server_instance = server.PromptServer.instance
        if "number" in json_data:
            server_instance.number = float(json_data['number'])
        else:
            server_instance.number += 1

        valid = execution.validate_prompt(json_data['prompt'])
        logger.info(f"Validating prompt result is {valid}")
        if not valid[0]:
            resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
                    "message": "the environment is not ready valid[0] is false, need to resync"}
            executing = False
            sen_finish_sqs_msg(prompt_id)
            return error(resp)

        extra_data = json_data.get("extra_data", {})
        client_id = ''
        if 'client_id' in extra_data and extra_data['client_id']:
            client_id = extra_data['client_id']
        # A top-level client_id takes precedence over the one in extra_data.
        if json_data.get("client_id"):
            extra_data["client_id"] = json_data["client_id"]
            client_id = json_data["client_id"]
        server_instance.client_id = client_id
        server_instance.last_prompt_id = prompt_id

        executor = execution.PromptExecutor(server_instance)
        executor.execute(json_data['prompt'], prompt_id, extra_data, valid[2])

        if out_path is not None:
            s3_out_path = f'output/{prompt_id}/{out_path}'
            s3_temp_path = f'temp/{prompt_id}/{out_path}'
            local_out_path = f'{ROOT_PATH}/output/{out_path}'
            local_temp_path = f'{ROOT_PATH}/temp/{out_path}'
        else:
            s3_out_path = f'output/{prompt_id}'
            s3_temp_path = f'temp/{prompt_id}'
            local_out_path = f'{ROOT_PATH}/output'
            local_temp_path = f'{ROOT_PATH}/temp'
        logger.info(f"s3_out_path is {s3_out_path} and s3_temp_path is {s3_temp_path} and "
                    f"local_out_path is {local_out_path} and local_temp_path is {local_temp_path}")
        sync_local_outputs_to_s3(s3_out_path, local_out_path)
        sync_local_outputs_to_s3(s3_temp_path, local_temp_path)
        response_body = {
            "prompt_id": prompt_id,
            "instance_id": GEN_INSTANCE_ID,
            "status": "success",
            "output_path": f's3://{BUCKET}/comfy/{s3_out_path}',
            "temp_path": f's3://{BUCKET}/comfy/{s3_temp_path}',
        }
        sen_finish_sqs_msg(prompt_id)
        logger.info(f"execute inference response is {response_body}")
        executing = False
        return ok(response_body)
    except Exception as ecp:
        logger.info(f"exception occurred {ecp}")
        resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
                "message": f"exception occurred {ecp}"}
        executing = False
        # Send the 'finish' event here too, matching the other failure paths.
        try:
            sen_finish_sqs_msg(prompt_id)
        except Exception as sqs_error:
            logger.error(f"failed to send finish message for {prompt_id}: {sqs_error}")
        return error(resp)
    finally:
        _maybe_collect_garbage(executor)


def _maybe_collect_garbage(executor):
    """Free cached models when /execute_proxy calls are more than GC_WAIT_TIME seconds apart.

    Runs after each /execute_proxy request. If the previous call's timestamp is
    older than GC_WAIT_TIME and no prompt is executing, it resets ``executor``
    (when one was created), cleans up models, runs gc.collect() and empties the
    cache. It then records the current time. Errors are logged, never raised.
    """
    global last_call_time, gc_triggered
    logger.info(f"gc check: {time.time()}")
    try:
        gc_triggered = False
        if last_call_time is None:
            logger.info("gc check last time is NONE")
        elif time.time() - last_call_time > GC_WAIT_TIME and not executing:
            logger.info(f"gc start: {time.time()} - {last_call_time}")
            if executor is not None:
                executor.reset()
            comfy.model_management.cleanup_models()
            gc.collect()
            comfy.model_management.soft_empty_cache()
            gc_triggered = True
            logger.info(f"gc end: {time.time()} - {last_call_time}")
        last_call_time = time.time()
        logger.info(f"gc check end: {time.time()}")
    except Exception as exc:
        logger.info(f"gc error: {exc}")
def get_last_ddb_sync_record():
    """Return the newest sync record for ENDPOINT_NAME from the sync table, or None.

    Queries the ENDPOINT_NAME partition in descending sort-key order with
    ``Limit=1`` and returns the first item, if any.
    """
    sync_response = sync_table.query(
        KeyConditionExpression=Key('endpoint_name').eq(ENDPOINT_NAME),
        Limit=1,
        ScanIndexForward=False
    )
    items = sync_response.get('Items', [])
    latest_sync_record = items[0] if len(items) > 0 else None
    if not latest_sync_record:
        logger.info("no latest_sync_record found")
        return None
    logger.info(f"latest_sync_record is{latest_sync_record}")
    return latest_sync_record
def get_latest_ddb_instance_monitor_record():
    """Return this instance's row from the instance-monitor table, or None if absent.

    The row is looked up by ``endpoint_name = ENDPOINT_NAME`` and
    ``gen_instance_id = GEN_INSTANCE_ID``; the first returned item is used.
    """
    query_response = instance_monitor_table.query(
        KeyConditionExpression='endpoint_name = :endpoint_name_val AND gen_instance_id = :gen_instance_id_val',
        ExpressionAttributeValues={
            ':endpoint_name_val': ENDPOINT_NAME,
            ':gen_instance_id_val': GEN_INSTANCE_ID,
        },
    )
    items = query_response.get('Items', [])
    monitor_record = items[0] if len(items) > 0 else None
    if not monitor_record:
        logger.info("no instance_monitor_record found")
        return None
    logger.info(f"instance_monitor_record is {monitor_record}")
    return monitor_record
def save_sync_instance_monitor(last_sync_request_id: str, sync_status: str):
    """Create (or overwrite) this instance's row in the instance-monitor table.

    The row starts with an empty ``sync_list``; its sync, create and heartbeat
    times are set to the current local time in ISO format. Returns the raw
    put_item response.
    """
    def _now_iso():
        return datetime.datetime.now().isoformat()

    item = dict(
        endpoint_id=ENDPOINT_ID,
        endpoint_name=ENDPOINT_NAME,
        gen_instance_id=GEN_INSTANCE_ID,
        sync_status=sync_status,
        last_sync_request_id=last_sync_request_id,
        last_sync_time=_now_iso(),
        sync_list=[],
        create_time=_now_iso(),
        last_heartbeat_time=_now_iso(),
    )
    put_response = instance_monitor_table.put_item(Item=item)
    logger.info(f"save instance item {put_response}")
    return put_response
def update_sync_instance_monitor(instance_monitor_record):
    """Write the sync fields of ``instance_monitor_record`` back to the monitor table.

    Copies sync_status, last_sync_request_id and sync_list from the record and
    stamps last_sync_time / last_heartbeat_time with the current local time.
    Returns the raw update_item response.
    """
    # Update the existing row for this endpoint/instance pair.
    assignments = [
        "sync_status = :new_sync_status",
        "last_sync_request_id = :sync_request_id",
        "sync_list = :sync_list",
        "last_sync_time = :sync_time",
        "last_heartbeat_time = :heartbeat_time",
    ]
    attribute_values = {
        ":new_sync_status": instance_monitor_record['sync_status'],
        ":sync_request_id": instance_monitor_record['last_sync_request_id'],
        ":sync_list": instance_monitor_record['sync_list'],
    }
    attribute_values[":sync_time"] = datetime.datetime.now().isoformat()
    attribute_values[":heartbeat_time"] = datetime.datetime.now().isoformat()
    update_response = instance_monitor_table.update_item(
        Key={'endpoint_name': ENDPOINT_NAME, 'gen_instance_id': GEN_INSTANCE_ID},
        UpdateExpression="SET " + ", ".join(assignments),
        ExpressionAttributeValues=attribute_values
    )
    logger.info(f"update_sync_instance_monitor :{update_response}")
    return update_response
def sync_instance_monitor_status(need_save: bool):
    """Create (``need_save``) or heartbeat (otherwise) this instance's monitor row.

    Creation goes through save_sync_instance_monitor('', 'init'); the heartbeat
    only sets last_heartbeat_time. Any error is logged and swallowed.
    """
    try:
        logger.info(f"sync_instance_monitor_status {datetime.datetime.now()}")
        if need_save:
            save_sync_instance_monitor('', 'init')
            return
        heartbeat_values = {":heartbeat_time": datetime.datetime.now().isoformat()}
        instance_monitor_table.update_item(
            Key={'endpoint_name': ENDPOINT_NAME, 'gen_instance_id': GEN_INSTANCE_ID},
            UpdateExpression="SET last_heartbeat_time = :heartbeat_time",
            ExpressionAttributeValues=heartbeat_values
        )
    except Exception as e:
        logger.info(f"sync_instance_monitor_status error :{e}")
@server.PromptServer.instance.routes.post("/reboot")
async def restart(self):
    """Re-exec the Python process when a previous sync requested a reboot.

    ``self`` is the aiohttp request. Declines with a message while an inference
    is executing, when NEED_REBOOT is set to anything other than 'true', or when
    the module ``reboot`` flag is False. Otherwise the current process is
    replaced via os.execv, so that path never returns a response.
    """
    logger.debug(f"start to reboot!!!!!!!! {self}")
    if executing is True:
        logger.info("other inference doing cannot reboot!!!!!!!!")
        return ok({"message": "other inference doing cannot reboot"})
    env_flag = os.environ.get('NEED_REBOOT')
    if env_flag and env_flag.lower() != 'true':
        logger.info("no need to reboot by os")
        return ok({"message": "no need to reboot by os"})
    if reboot is False:
        logger.info("no need to reboot by global constant")
        return ok({"message": "no need to reboot by constant"})
    logger.debug("rebooting !!!!!!!!")
    try:
        sys.stdout.close_log()
    except Exception as e:
        logger.info(f"error reboot!!!!!!!! {e}")
    return os.execv(sys.executable, [sys.executable] + sys.argv)
# must be sync invoke and use the env to check
@server.PromptServer.instance.routes.post("/sync_instance")
async def sync_instance(request):
    """Apply the latest DynamoDB sync record to this instance unless it is already applied.

    The ALREADY_SYNC environment variable is a per-process "sync in progress"
    flag: it is 'false' while a sync runs, so overlapping calls return early. It
    is reset to 'true' on every exit from the sync itself. Whether a record is
    already applied is decided from the LAST_SYNC_* environment variables, then
    from the instance-monitor table, which is updated with the sync outcome.
    Always answers via ok()/error() (HTTP 200).
    """
    if not BUCKET:
        logger.error("No bucket provided ,wait and try again")
        return ok({"status": "success", "message": "syncing"})
    if os.environ.get('ALREADY_SYNC', '').lower() == 'false':
        logger.error("other process doing ,wait and try again")
        return ok({"status": "success", "message": "syncing"})
    os.environ['ALREADY_SYNC'] = 'false'
    logger.info(f"sync_instance start {datetime.datetime.now().isoformat()} {request}")
    try:
        last_sync_record = get_last_ddb_sync_record()
        if not last_sync_record:
            logger.info("no last sync record found do not need sync")
            sync_instance_monitor_status(True)
            return ok({"status": "success", "message": "no sync"})
        if ('request_id' in last_sync_record and last_sync_record['request_id']
                and os.environ.get('LAST_SYNC_REQUEST_ID')
                and os.environ.get('LAST_SYNC_REQUEST_ID') == last_sync_record['request_id']
                and os.environ.get('LAST_SYNC_REQUEST_TIME')
                and os.environ.get('LAST_SYNC_REQUEST_TIME') == str(last_sync_record['request_time'])):
            logger.info("last sync record already sync by os check")
            sync_instance_monitor_status(False)
            return ok({"status": "success", "message": "no sync env"})

        instance_monitor_record = get_latest_ddb_instance_monitor_record()
        if not instance_monitor_record:
            sync_already = await prepare_comfy_env(last_sync_record)
            # Record the outcome either way so a failed first sync is visible and retried.
            sync_status = 'success' if sync_already else 'failed'
            logger.info(f"should init prepare instance_monitor_record ({sync_status})")
            save_sync_instance_monitor(last_sync_record['request_id'], sync_status)
        else:
            if (instance_monitor_record.get('last_sync_request_id')
                    and instance_monitor_record['last_sync_request_id'] == last_sync_record['request_id']
                    and instance_monitor_record['sync_status']
                    and instance_monitor_record['sync_status'] == 'success'
                    and os.environ.get('LAST_SYNC_REQUEST_TIME')
                    and os.environ.get('LAST_SYNC_REQUEST_TIME') == str(last_sync_record['request_time'])):
                logger.info("last sync record already sync")
                sync_instance_monitor_status(False)
                return ok({"status": "success", "message": "no sync ddb"})
            sync_already = await prepare_comfy_env(last_sync_record)
            instance_monitor_record['sync_status'] = 'success' if sync_already else 'failed'
            instance_monitor_record['last_sync_request_id'] = last_sync_record['request_id']
            sync_list = instance_monitor_record.get('sync_list') or []
            sync_list.append(last_sync_record['request_id'])
            instance_monitor_record['sync_list'] = sync_list
            logger.info("should update prepare instance_monitor_record")
            update_sync_instance_monitor(instance_monitor_record)
        return ok({"status": "success", "message": "sync"})
    except Exception as e:
        # Use an f-string: passing `e` as an extra logging argument without a
        # placeholder raises a formatting error and loses the exception text.
        logger.exception(f"exception occurred {e}")
        return error({"status": "fail", "message": "sync"})
    finally:
        os.environ['ALREADY_SYNC'] = 'true'
def validate_prompt_proxy(func):
    """Wrap ``func`` (ComfyUI's validate_prompt) with start/end log lines.

    Arguments and the return value pass through unchanged; functools.wraps keeps
    the wrapped function's name and docstring.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Hook point for proxy logic before validation.
        logger.info("validate_prompt_proxy start...")
        result = func(*args, **kwargs)
        # Hook point for post-validation handling.
        logger.info("validate_prompt_proxy end...")
        return result
    return wrapper


execution.validate_prompt = validate_prompt_proxy(execution.validate_prompt)
def send_sync_proxy(func):
    """Wrap ``PromptServer.send_sync`` so that events are also forwarded to SQS.

    The original send_sync always runs first and its return value is passed
    through. When the module-level ``need_sync`` is truthy and QUEUE_URL/REGION
    are configured, the event is also published with the current ``prompt_id``
    as message group. A publish failure (for example data that is not JSON
    serializable) is logged instead of propagating into ComfyUI.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger.debug(f"Sending sync request!!!!!!! {args}")
        # need_sync / prompt_id may not be assigned yet when send_sync fires before
        # the first /execute_proxy call; read them defensively to avoid NameError.
        sync_enabled = globals().get('need_sync', False)
        current_prompt_id = globals().get('prompt_id')
        logger.info(f"send_sync_proxy start... {sync_enabled},{current_prompt_id} {args}")
        result = func(*args, **kwargs)
        if sync_enabled and QUEUE_URL and REGION:
            logger.debug(f"send_sync_proxy params... {QUEUE_URL},{REGION},{sync_enabled},{current_prompt_id}")
            # Presumably send_sync(self, event, data, sid=None); accept positional or keyword use.
            event = args[1] if len(args) > 1 else kwargs.get('event')
            data = args[2] if len(args) > 2 else kwargs.get('data')
            sid = args[3] if len(args) > 3 else kwargs.get('sid')
            message_body = {'prompt_id': current_prompt_id, 'event': event, 'data': data, 'sid': sid}
            try:
                message_id = sen_sqs_msg(message_body, current_prompt_id)
                logger.info(f'send_sync_proxy message_id :{message_id} message_body: {message_body}')
            except Exception as exc:
                logger.error(f"send_sync_proxy failed to publish event {event}: {exc}")
        logger.debug("send_sync_proxy end...")
        return result
    return wrapper


server.PromptServer.send_sync = send_sync_proxy(server.PromptServer.send_sync)
def get_save_imge_path_proxy(func):
    """Wrap ``folder_paths.get_save_image_path`` to append ``_<prompt_id>`` to the returned prefix.

    Only the fifth element of the returned tuple (``filename_prefix``) changes;
    the folder, file name, counter and subfolder from the wrapped call are
    returned untouched.
    NOTE(review): filename and counter are still derived from the un-suffixed
    prefix; confirm callers actually use the returned filename_prefix.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger.info(f"get_save_imge_path_proxy args : {args} kwargs : {kwargs}")
        full_output_folder, filename, counter, subfolder, filename_prefix = func(*args, **kwargs)
        filename_prefix_new = filename_prefix + "_" + str(prompt_id)
        logger.info(f"get_save_imge_path_proxy filename_prefix new : {filename_prefix_new}")
        return full_output_folder, filename, counter, subfolder, filename_prefix_new
    return wrapper


folder_paths.get_save_image_path = get_save_imge_path_proxy(folder_paths.get_save_image_path)