1674 lines
80 KiB
Python
Executable File
1674 lines
80 KiB
Python
Executable File
import concurrent.futures
|
||
import signal
|
||
import threading
|
||
|
||
import boto3
|
||
import requests
|
||
from aiohttp import web
|
||
|
||
import folder_paths
|
||
import server
|
||
from execution import PromptExecutor
|
||
import execution
|
||
import comfy
|
||
|
||
from watchdog.observers import Observer
|
||
from watchdog.events import FileSystemEventHandler
|
||
import subprocess
|
||
from dotenv import load_dotenv
|
||
|
||
import fcntl
|
||
import hashlib
|
||
|
||
import base64
|
||
import datetime
|
||
import json
|
||
import logging
|
||
import os
|
||
import sys
|
||
import tarfile
|
||
import time
|
||
import uuid
|
||
import gc
|
||
from dataclasses import dataclass
|
||
from typing import Optional
|
||
|
||
|
||
from boto3.dynamodb.conditions import Key
|
||
|
||
|
||
# Name of the environment variable that, when set to 'True', makes the
# execute proxy run prompts locally instead of on the cloud endpoint.
DISABLE_AWS_PROXY = 'DISABLE_AWS_PROXY'
# Payloads of sync messages already forwarded to the UI (used for de-duplication).
sync_msg_list = []
# Maps a front-end client id to the release/workflow name selected via /map_release.
client_release_map = {}

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

is_on_sagemaker = os.getenv('ON_SAGEMAKER') == 'true'
is_on_ec2 = os.getenv('ON_EC2') == 'true'

# NOTE(review): the indentation of this module was lost in the copy under
# review. Only the env-file loading is assumed to be guarded by ON_EC2 here;
# confirm against the upstream file whether more of the module is nested.
if is_on_ec2:
    env_path = '/etc/environment'
    if os.environ.get('ENV_FILE_PATH'):
        env_path = os.environ.get('ENV_FILE_PATH')
    # Load the resolved path; previously '/etc/environment' was always loaded
    # and the ENV_FILE_PATH override had no effect.
    load_dotenv(env_path)
    logger.info(f"env_path{env_path}")

env_keys = ['ENV_FILE_PATH', 'COMFY_INPUT_PATH', 'COMFY_MODEL_PATH', 'COMFY_NODE_PATH', 'COMFY_API_URL',
            'COMFY_API_TOKEN', 'COMFY_ENDPOINT', 'COMFY_NEED_SYNC', 'COMFY_NEED_PREPARE', 'COMFY_BUCKET_NAME',
            'MAX_WAIT_TIME', 'MSG_MAX_WAIT_TIME', 'THREAD_MAX_WAIT_TIME', DISABLE_AWS_PROXY, 'DISABLE_AUTO_SYNC']

for item in os.environ.keys():
    if item in env_keys:
        # Never write the API token to the log in clear text.
        shown_value = '***' if item == 'COMFY_API_TOKEN' else os.environ.get(item)
        logger.info(f'evn key: {item} {shown_value}')


def _env_flag(name, default):
    """Read a boolean environment variable ('true'/'1'/'yes' are true, case-insensitive)."""
    raw = os.environ.get(name)
    if raw is None or raw == '':
        return default
    return raw.strip().lower() in ('1', 'true', 'yes')


def _env_int(name, default):
    """Read an integer environment variable, falling back to ``default`` when unset or empty."""
    raw = os.environ.get(name)
    if raw is None or raw == '':
        return default
    return int(raw)


# Local directories mirrored to S3 (overridable through the environment).
DIR3 = os.environ.get('COMFY_INPUT_PATH') or "input"
DIR1 = os.environ.get('COMFY_MODEL_PATH') or "models"
DIR2 = os.environ.get('COMFY_NODE_PATH') or "custom_nodes"

api_url = os.environ.get('COMFY_API_URL')
api_token = os.environ.get('COMFY_API_TOKEN')
comfy_endpoint = os.environ.get('COMFY_ENDPOINT')
# Parsed as real booleans: a raw 'false' string used to be truthy.
comfy_need_sync = _env_flag('COMFY_NEED_SYNC', True)
comfy_need_prepare = _env_flag('COMFY_NEED_PREPARE', False)
bucket_name = os.environ.get('COMFY_BUCKET_NAME')
# Polling budgets are counters used in arithmetic, so they must be integers
# even when they come from the environment (where every value is a string).
thread_max_wait_time = _env_int('THREAD_MAX_WAIT_TIME', 60)
max_wait_time = _env_int('MAX_WAIT_TIME', 86400)
msg_max_wait_time = _env_int('MSG_MAX_WAIT_TIME', 86400)
is_master_process = os.getenv('MASTER_PROCESS') == 'true'
program_name = os.getenv('PROGRAM_NAME')
# File-name suffixes that never trigger an S3 sync (editor/autosave artefacts).
no_need_sync_files = ['.autosave', '.cache', '.autosave1', '~', '.swp']

# Message batches that contained an 'executed' event (see handle_sync_messages).
need_resend_msg_result = []
PREPARE_ID = 'default'
# additional
PREPARE_MODE = 'additional'

if not api_url:
    raise ValueError("COMFY_API_URL environment variable must be set.")

if not api_token:
    raise ValueError("COMFY_API_TOKEN environment variable must be set.")

if not comfy_endpoint:
    raise ValueError("COMFY_ENDPOINT environment variable must be set.")

# Headers sent with every request to the cloud API.
headers = {"x-api-key": api_token, "Content-Type": "application/json"}
|
||
|
||
|
||
def save_images_locally(response_json, local_folder):
    """Download the files listed in an execute response into ``local_folder``.

    Reads ``data.prompt_id`` and ``data.image_video_data`` (a mapping of file
    name to URL) from ``response_json`` and stores each downloaded file under
    ``<local_folder>/<prompt_id>/``. Failures are logged, never raised.

    :param response_json: decoded JSON response (a dict).
    :param local_folder: directory that receives one sub-folder per prompt.
    """
    try:
        data = response_json.get("data", {})
        prompt_id = data.get("prompt_id")
        image_video_data = data.get("image_video_data", {})

        if not prompt_id or not image_video_data:
            logger.info("Missing prompt_id or image_video_data in the response.")
            return

        folder_path = os.path.join(local_folder, prompt_id)
        os.makedirs(folder_path, exist_ok=True)

        for image_name, image_url in image_video_data.items():
            # The name comes from the remote response: keep only its final
            # component so it cannot escape the prompt folder.
            safe_name = os.path.basename(str(image_name))
            if not safe_name:
                logger.info(f"Skipping '{image_name}': not a usable file name")
                continue

            # (connect, read) timeouts so a stalled download cannot hang forever.
            image_response = requests.get(image_url, timeout=(10, 120))
            if image_response.status_code == 200:
                image_path = os.path.join(folder_path, safe_name)
                with open(image_path, "wb") as image_file:
                    image_file.write(image_response.content)
                logger.info(f"Image '{image_name}' saved to {image_path}")
            else:
                logger.info(
                    f"Failed to download image '{image_name}' from {image_url}. Status code: {image_response.status_code}")

    except Exception as e:
        # Best effort by design: report and carry on.
        logger.info(f"Error saving images locally: {e}")
|
||
|
||
|
||
def calculate_file_hash(file_path):
    """Return the hexadecimal SHA-256 digest of the file at ``file_path``.

    The file is read in 64 KiB chunks so it never has to fit in memory.
    """
    chunk_size = 65536  # 64 KiB read buffer
    digest = hashlib.sha256()
    with open(file_path, 'rb') as stream:
        # iter() with a sentinel stops at the first empty read (end of file).
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
|
||
|
||
|
||
def save_files(prefix, execute, key, target_dir, need_prefix):
    """Download the URLs listed under ``execute['data'][key]`` into ``target_dir``.

    :param prefix: string prepended (``<prefix>_<name>``) to the extra copy
        written when ``need_prefix`` is true.
    :param execute: decoded execute response; must contain a ``data`` mapping.
    :param key: name of the URL list inside ``execute['data']``.
    :param target_dir: destination directory, created on demand.
    :param need_prefix: when true, each file is written twice: once with the
        prefix and once under its plain name (overwriting any existing file).
    """
    if key not in execute['data']:
        return

    # A null list in the response is treated like an empty one.
    for url in execute['data'][key] or []:
        loca_file = get_file_name(url)
        # Skip the placeholder entry before spending a request on it.
        if loca_file.endswith("output_images_will_be_put_here"):
            continue

        response = requests.get(url, timeout=(10, 120))
        if response.status_code != 200:
            # Do not persist an error page as if it were the result file.
            logger.error(f"Failed to download {url}. Status code: {response.status_code}")
            continue

        os.makedirs(target_dir, exist_ok=True)
        logger.info(f"Saving file {loca_file} to {target_dir}")

        if need_prefix:
            with open(os.path.join(target_dir, f"{prefix}_{loca_file}"), 'wb') as f:
                f.write(response.content)
        # The un-prefixed copy is always written and overrides an existing file.
        with open(os.path.join(target_dir, loca_file), 'wb') as f:
            f.write(response.content)
|
||
|
||
|
||
def get_file_name(url: str):
    """Return the last path segment of ``url`` with any query string removed."""
    last_segment = url.rsplit('/', 1)[-1]
    name, _, _ = last_segment.partition('?')
    return name
|
||
|
||
|
||
def send_service_msg(server_use, msg):
    """Forward one message dict to ``server_use.send_sync``.

    ``msg`` supplies ``event`` and ``data``; ``sid`` is optional and passed
    as ``None`` when absent.
    """
    session_id = msg.get('sid') if 'sid' in msg else None
    server_use.send_sync(msg.get('event'), msg.get('data'), session_id)
|
||
|
||
|
||
def handle_sync_messages(server_use, msg_array):
    """Forward not-yet-seen sync messages to the UI.

    ``msg_array`` is an iterable of message batches; each batch is an iterable
    of dicts with ``event``, ``data`` and an optional ``sid``. A message whose
    ``data`` is already in the module-level ``sync_msg_list`` is skipped;
    otherwise its data is recorded and the message is sent through
    ``server_use.send_sync``. A batch containing an ``executed`` event is
    appended to ``need_resend_msg_result``.

    :return: True when a ``finish`` event was forwarded, else False.
    """
    global sync_msg_list, need_resend_msg_result
    finished = False
    for batch in msg_array:
        for entry in batch:
            payload = entry.get('data')
            if payload in sync_msg_list:
                continue
            sync_msg_list.append(payload)

            event_name = entry.get('event')
            session_id = entry.get('sid') if 'sid' in entry else None
            if event_name == 'finish':
                finished = True
            elif event_name == 'executed':
                # The whole batch is recorded, not just this entry.
                need_resend_msg_result.append(batch)
            server_use.send_sync(event_name, payload, session_id)

    return finished
|
||
|
||
|
||
def execute_proxy(func):
    """Wrap a prompt-execution callable so prompts are run on the cloud endpoint.

    The returned wrapper expects the positional arguments
    ``(executor, prompt, prompt_id, extra_data, ...)``. When the environment
    variable named by ``DISABLE_AWS_PROXY`` equals ``'True'`` it simply calls
    ``func``. Otherwise it:

    1. checks that the endpoint reports ``prepareSuccess``,
    2. posts the prompt to ``{api_url}/executes``,
    3. (if ``comfy_need_sync``) relays sync messages from
       ``{api_url}/sync/{prompt_id}`` to the UI,
    4. polls ``{api_url}/executes/{prompt_id}`` and stores the result files
       with ``save_files``.

    Problems are reported to the UI as ``execution_error`` messages.

    :param func: the original execute callable.
    :return: the wrapping callable.
    """
    import functools

    accepted_codes = (200, 201, 202)
    finished_states = ('Completed', 'success')

    def send_error_msg(executor, prompt_id, msg):
        """Broadcast an ``execution_error`` message carrying ``msg``."""
        mes = {
            "prompt_id": prompt_id,
            "node_id": "",
            "node_type": "on cloud",
            "executed": [],
            "exception_message": msg,
            "exception_type": "",
            "traceback": [],
            "current_inputs": "",
            "current_outputs": "",
        }
        executor.add_message("execution_error", mes, broadcast=True)

    def send_post_request(url, params):
        """POST ``params`` as JSON with the API headers."""
        logger.debug(f"sending post request {url} , params {params}")
        return requests.post(url, json=params, headers=headers)

    def send_get_request(url):
        """GET ``url`` with the API headers."""
        return requests.get(url, headers=headers)

    def check_if_sync_is_already(url):
        """Return True when the prepare endpoint reports ``prepareSuccess``."""
        prepare_response = send_get_request(url).json()
        prepare_data = prepare_response.get('data') if isinstance(prepare_response, dict) else None
        if (isinstance(prepare_data, dict) and prepare_response.get('statusCode') == 200
                and prepare_data.get('prepareSuccess')):
            logger.info(f"sync available")
            return True
        logger.info(f"no sync available for {url} response {prepare_response}")
        return False

    def job_status(response):
        """Return ``response['data']['status']`` or None when the body is malformed."""
        data = response.get('data') if isinstance(response, dict) else None
        if not isinstance(data, dict):
            return None
        return data.get('status') or None

    def has_result_files(data):
        """True when the job data lists at least one temp or output file."""
        return bool(data.get('temp_files')) or bool(data.get('output_files'))

    def store_result_files(prompt_id, response):
        """Download temp files as-is and output files with the prompt-id prefix."""
        save_files(prompt_id, response, 'temp_files', 'temp', False)
        save_files(prompt_id, response, 'output_files', 'output', True)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global sync_msg_list

        if 'True' == os.environ.get(DISABLE_AWS_PROXY):
            logger.info("disabled aws proxy, use local")
            return func(*args, **kwargs)
        logger.info(f"enable aws proxy, use aws {comfy_endpoint}")

        executor = args[0]
        server_use = executor.server
        prompt = args[1]
        prompt_id = args[2]
        extra_data = args[3]

        client_id = extra_data['client_id'] if 'client_id' in extra_data else None
        if not client_id:
            send_error_msg(executor, prompt_id,
                           "Something went wrong when execute,please check your client_id and try again")
            return web.Response()

        workflow_name = client_release_map.get(client_id)
        if not workflow_name and not is_master_process:
            send_error_msg(executor, prompt_id, "Please choose a release env before you execute prompt")
            return web.Response()

        payload = {
            "number": str(server.PromptServer.instance.number),
            "prompt": prompt,
            "prompt_id": prompt_id,
            "extra_data": extra_data,
            "endpoint_name": comfy_endpoint,
            "need_prepare": comfy_need_prepare,
            "need_sync": comfy_need_sync,
            "multi_async": False,
            "workflow_name": workflow_name,
        }
        logger.debug(f"payload is: {payload}")

        is_synced = check_if_sync_is_already(f"{api_url}/prepare/{comfy_endpoint}")
        if not is_synced:
            logger.debug(f"is_synced is {is_synced} stop cloud prompt")
            send_error_msg(executor, prompt_id, "Your local environment has not compleated to synchronized on cloud already. Please wait for a moment or click the 'Synchronize' button .")
            return

        executes_url = f"{api_url}/executes/{prompt_id}"
        sync_url = f"{api_url}/sync/{prompt_id}"

        with concurrent.futures.ThreadPoolExecutor() as pool:
            execute_future = pool.submit(send_post_request, f"{api_url}/executes", payload)
            save_already = False

            if comfy_need_sync:
                msg_future = pool.submit(send_get_request, sync_url)
                concurrent.futures.wait([execute_future, msg_future],
                                        return_when=concurrent.futures.ALL_COMPLETED)
                already_synced = False
                sync_msg_list = []

                # Handle the two results in a fixed order (sync first). The
                # previous iteration over a set was order-dependent and could
                # skip the sync messages.
                msg_response = msg_future.result()
                msg_body = msg_response.json()
                logger.info(f"get syc msg: {msg_body}")
                if msg_response.status_code == 200:
                    messages = msg_body.get("data") if isinstance(msg_body, dict) else None
                    if not messages:
                        logger.error("there is no response from sync msg by thread ")
                        time.sleep(1)
                    else:
                        logger.debug(msg_body)
                        already_synced = handle_sync_messages(server_use, messages)

                execute_resp = execute_future.result()
                logger.info(f"get execute status: {execute_resp.status_code}")
                if execute_resp.status_code in accepted_codes:
                    # int(): the budget may arrive from the environment as a string.
                    remaining = int(thread_max_wait_time)
                    while remaining > 0:
                        images_response = send_get_request(executes_url)
                        logger.info(f"get execute images: {images_response.status_code}")
                        if images_response.status_code == 404:
                            logger.info("no images found already ,waiting sagemaker thread result .....")
                            time.sleep(3)
                            remaining -= 2
                            continue

                        response = images_response.json()
                        status = job_status(response)
                        # Validate the shape BEFORE reading the status; this
                        # check used to be unreachable.
                        if status is None:
                            logger.error(f"there is no response from execute thread result !!!!!!!! {response}")
                            # no need to send msg anymore
                            already_synced = True
                            break
                        if status == 'failed':
                            logger.error(f"there is no response on sagemaker from execute thread result !!!!!!!! ")
                            # no need to send msg anymore; the error itself is
                            # reported by the final polling phase below.
                            already_synced = True
                            break
                        if status not in finished_states:
                            logger.info(f"no images found already ,waiting sagemaker thread result, current status is {status}")
                            time.sleep(2)
                            remaining -= 1
                            continue

                        if has_result_files(response['data']):
                            store_result_files(prompt_id, response)
                        else:
                            send_error_msg(executor, prompt_id,
                                           "There may be some errors when executing the prompt on the cloud. Please check the SageMaker logs.")
                            # no need to send msg anymore
                            already_synced = True
                        logger.debug(response)
                        save_already = True
                        break
                    logger.debug(execute_resp.json())
                else:
                    logger.error(f"get execute error: {execute_resp}")
                    # no need to send msg anymore
                    already_synced = True

                # Drain remaining sync messages until 'finish' or the budget is
                # spent. The budget is charged on every iteration so that a
                # non-200 answer can no longer spin forever without sleeping.
                remaining_msgs = int(msg_max_wait_time)
                while not already_synced:
                    if remaining_msgs <= 0:
                        logger.error("there is no response from sync msg by timeout")
                        break
                    msg_response = send_get_request(sync_url)
                    if msg_response.status_code == 200:
                        msg_body = msg_response.json()
                        messages = msg_body.get("data") if isinstance(msg_body, dict) else None
                        if not messages:
                            logger.error("there is no response from sync msg")
                        else:
                            logger.debug(msg_body)
                            already_synced = handle_sync_messages(server_use, messages)
                            logger.info(f"already_synced is :{already_synced}")
                    time.sleep(1)
                    remaining_msgs -= 1
                logger.info(f"check if images are already synced {save_already}")

            if not save_already:
                logger.info("check if images are not already synced, please wait")
                execute_resp = execute_future.result()
                logger.debug(f"execute result :{execute_resp.json()}")
                if execute_resp.status_code in accepted_codes:
                    remaining = int(max_wait_time)
                    while remaining > 0:
                        images_response = send_get_request(executes_url)
                        if images_response.status_code == 404:
                            logger.info(f"{remaining} no images found already ,waiting sagemaker result .....")
                            remaining -= 2
                            time.sleep(3)
                            continue

                        response = images_response.json()
                        logger.debug(response)
                        status = job_status(response)
                        if status is None:
                            logger.info(f"{remaining} there is no response from sync executes {response}")
                            detail = response.get('message', '') if isinstance(response, dict) else ''
                            send_error_msg(executor, prompt_id, f"There may be some errors when executing the prompt on the cloud. No images or videos generated. {detail}")
                            break
                        if status == 'failed':
                            logger.error(f"there is no response on sagemaker from execute result !!!!!!!! ")
                            failure_message = response['data'].get('message')
                            if failure_message:
                                send_error_msg(executor, prompt_id,
                                               f"There may be some errors when valid or execute the prompt on the cloud. Please check the SageMaker logs. errors: {failure_message}")
                            else:
                                logger.error(f"valid error on sagemaker :{response['data']}")
                                send_error_msg(executor, prompt_id,
                                               f"There may be some errors when valid or execute the prompt on the cloud. errors")
                            break
                        if status not in finished_states:
                            logger.info(f"{remaining} images not already ,waiting sagemaker result .....{status}")
                            remaining -= 1
                            time.sleep(3)
                            continue

                        if has_result_files(response['data']):
                            store_result_files(prompt_id, response)
                        else:
                            send_error_msg(executor, prompt_id,
                                           "There may be some errors when executing the prompt on the cloud. Please check the SageMaker logs.")
                        break
                else:
                    logger.error(f"get execute error: {execute_resp}")
                    send_error_msg(executor, prompt_id, "Please valid your prompt and try again.")
            logger.info("execute finished")
            # The pool is shut down by the ``with`` block.

    return wrapper
|
||
|
||
|
||
# Monkey-patch: route every PromptExecutor.execute call through execute_proxy.
PromptExecutor.execute = execute_proxy(PromptExecutor.execute)
|
||
|
||
|
||
def send_sync_proxy(func):
    """Wrap ``func`` so each call logs its positional arguments first.

    The wrapper keeps the wrapped callable's name and docstring
    (``functools.wraps``) so the patched method still identifies itself.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger.info(f"Sending sync request----- {args}")
        return func(*args, **kwargs)

    return wrapper
|
||
|
||
|
||
# Monkey-patch: log every PromptServer.send_sync call via send_sync_proxy.
server.PromptServer.send_sync = send_sync_proxy(server.PromptServer.send_sync)
|
||
|
||
|
||
def compress_and_upload(folder_path, prepare_version):
    """Archive each immediate sub-directory of ``folder_path`` and upload it to S3.

    For every sub-directory ``<name>`` a ``<name>.tar.gz`` is created in the
    current working directory, copied with ``s5cmd`` to
    ``s3://<bucket>/comfy/<endpoint>/<prepare_version>/custom_nodes/`` and then
    deleted locally, even when archiving or uploading fails.

    :param folder_path: directory whose sub-directories are uploaded; a missing
        or empty directory is a no-op.
    :param prepare_version: path component identifying the prepare run.
    """
    # os.walk yields nothing for a missing directory; default to "no sub-dirs".
    subdirs = next(os.walk(folder_path), (folder_path, [], []))[1]
    for subdir in subdirs:
        subdir_path = os.path.join(folder_path, subdir)
        tar_filename = f"{subdir}.tar.gz"
        logger.info(f"Compressing the {tar_filename}")
        try:
            # Create the tar archive.
            with tarfile.open(tar_filename, "w:gz") as tar:
                tar.add(subdir_path, arcname=os.path.basename(subdir_path))

            # Argument list, not a shell string: directory names reach s5cmd
            # verbatim and cannot be interpreted by a shell.
            upload_cmd = ['s5cmd', '--log=error', 'cp', tar_filename,
                          f"s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/custom_nodes/"]
            logger.info(" ".join(upload_cmd))
            completed = subprocess.run(upload_cmd, check=False)
            if completed.returncode != 0:
                logger.error(f"upload of {tar_filename} failed with exit code {completed.returncode}")
        finally:
            if os.path.exists(tar_filename):
                logger.info(f"rm {tar_filename}")
                os.remove(tar_filename)
|
||
|
||
|
||
def sync_default_files():
    """Upload custom nodes, inputs and models to S3, then request a prepare run.

    Custom nodes are uploaded as per-directory archives through
    ``compress_and_upload``; ``DIR3`` (inputs) and ``DIR1`` (models) are
    mirrored with ``s5cmd sync --delete=true``. Afterwards the ``prepare``
    API is called with ``need_reboot=True`` and ``prepare_type='default'``.

    :return: the body printed by the prepare call, or None when any step raised.
    """
    try:
        timestamp = str(int(time.time() * 1000))
        prepare_version = PREPARE_ID if PREPARE_MODE == 'additional' else timestamp
        need_prepare = True
        prepare_type = 'default'
        need_reboot = True
        remote_root = f"s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}"

        logger.info(f" sync custom nodes files")
        compress_and_upload(f"{DIR2}", prepare_version)

        for label, local_dir, remote_dir in (("input", DIR3, "input"), ("models", DIR1, "models")):
            logger.info(f" sync {label} files")
            # Argument list instead of a shell string: paths are passed verbatim.
            sync_cmd = ['s5cmd', '--log=error', 'sync', '--delete=true',
                        f"{local_dir}/", f"{remote_root}/{remote_dir}/"]
            logger.info(f"sync {label} files start {' '.join(sync_cmd)}")
            subprocess.run(sync_cmd, check=False)

        logger.info(f"Files changed in:: {need_prepare} {DIR2} {DIR1} {DIR3}")
        # Join with exactly one slash whether or not api_url ends with one.
        url = api_url.rstrip('/') + "/prepare"
        logger.info(f"URL:{url}")
        data = {"endpoint_name": comfy_endpoint, "need_reboot": need_reboot, "prepare_id": prepare_version,
                "prepare_type": prepare_type}
        logger.info(f"prepare params Data: {json.dumps(data, indent=4)}")
        result = subprocess.run(["curl", "--location", "--request", "POST", url, "--header",
                                 f"x-api-key: {api_token}", "--data-raw", json.dumps(data)],
                                capture_output=True, text=True)
        logger.info(result.stdout)
        return result.stdout
    except Exception as e:
        logger.info(f"sync_files error {e}")
        return None
|
||
|
||
|
||
def sync_files(filepath, is_folder, is_auto):
    """Mirror the tree containing ``filepath`` to S3 and request a prepare run.

    The path is matched, in this order, against the custom-nodes (``DIR2``),
    input (``DIR3``) and model (``DIR1``) directories; the first match is
    mirrored with ``s5cmd sync --delete=true`` and followed by a call to the
    ``prepare`` API. Custom-node changes also request a reboot.

    :param filepath: path reported by the file-system watcher.
    :param is_folder: truthy when ``filepath`` is a directory.
    :param is_auto: when true, skip the sync while the file or folder is still
        being written (the watcher will fire again).
    :return: the body printed by the prepare call, or None when nothing was
        synced or an error occurred.
    """
    try:
        directory = os.path.dirname(filepath)
        logger.info(f"Directory changed in: {directory} {filepath}")
        if not directory:
            logger.info("root path no need to sync files by duplicate opt")
            return None
        logger.info(f"Files changed in: {filepath}")

        ignored_suffix = next((s for s in no_need_sync_files if filepath.endswith(s)), None)
        if ignored_suffix is not None:
            logger.info(f"no need to sync files by ignore files {filepath} ends by {ignored_suffix}")
            return None

        timestamp = str(int(time.time() * 1000))
        prepare_version = PREPARE_ID if PREPARE_MODE == 'additional' else timestamp
        need_prepare = False
        prepare_type = 'default'
        need_reboot = False

        def in_tree(root):
            """True when the changed path lies in (or is) the directory ``root``."""
            suffix = f"{root}" if root.startswith("/") else f"/{root}"
            return (str(directory).endswith(suffix) or str(filepath) == root
                    or str(filepath) == f'./{root}' or f"{root}/" in filepath)

        def is_settled():
            """True when the changed file or folder is no longer being written."""
            if bool(is_folder):
                return is_folder_unlocked(filepath)
            return is_file_unlocked(filepath)

        def run_sync(local_root, remote_dir, extra_args=()):
            """Mirror ``local_root`` to the S3 prefix for ``remote_dir``."""
            # Argument list instead of a shell string: paths are passed verbatim.
            sync_cmd = ['s5cmd', '--log=error', 'sync', '--delete=true', *extra_args, f"{local_root}/",
                        f"s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/{remote_dir}/"]
            logger.info(" ".join(sync_cmd))
            subprocess.run(sync_cmd, check=False)

        if in_tree(DIR2):
            logger.info(f" sync custom nodes files: {filepath}")
            # The custom_nodes folder is still changing: sync later.
            if is_auto and not is_folder_unlocked(directory):
                logger.info("sync custom_nodes files is changing ,waiting.... ")
                return None
            logger.info("sync custom_nodes files start")
            run_sync(DIR2, 'custom_nodes', ('--exclude=*comfy_local_proxy.py',))
            need_prepare = True
            need_reboot = True
            prepare_type = 'nodes'
        elif in_tree(DIR3):
            logger.info(f" sync input files: {filepath}")
            # Only sync once the file has been completely written.
            if is_auto and not is_settled():
                logger.info("sync input files is changing ,waiting.... ")
                return None
            logger.info("sync input files start")
            run_sync(DIR3, 'input')
            need_prepare = True
            prepare_type = 'inputs'
        elif in_tree(DIR1):
            logger.info(f" sync models files: {filepath}")
            # Only sync once the file has been completely written.
            if is_auto and not is_settled():
                logger.info("sync input models is changing ,waiting.... ")
                return None
            logger.info("sync models files start")
            run_sync(DIR1, 'models')
            need_prepare = True
            prepare_type = 'models'

        logger.info(f"Files changed in:: {need_prepare} {str(directory)} {DIR2} {DIR1} {DIR3}")
        if not need_prepare:
            return None

        # Join with exactly one slash whether or not api_url ends with one.
        url = api_url.rstrip('/') + "/prepare"
        logger.info(f"URL:{url}")
        data = {"endpoint_name": comfy_endpoint, "need_reboot": need_reboot, "prepare_id": prepare_version,
                "prepare_type": prepare_type}
        logger.info(f"prepare params Data: {json.dumps(data, indent=4)}")
        result = subprocess.run(["curl", "--location", "--request", "POST", url, "--header",
                                 f"x-api-key: {api_token}", "--data-raw", json.dumps(data)],
                                capture_output=True, text=True)
        logger.info(result.stdout)
        return result.stdout
    except Exception as e:
        logger.info(f"sync_files error {e}")
        return None
|
||
|
||
|
||
def is_folder_unlocked(directory):
    """Return True when ``directory`` has stopped changing.

    A temporary recursive watcher observes the folder for one second. If
    nothing changed the folder counts as settled; if something did, a second
    one-second window is observed and the folder counts as settled only when
    that window stays quiet.

    :param directory: folder to observe (recursively).
    :return: True when settled, False when still changing or on error.
    """
    event_handler = MyHandlerWithCheck()
    observer = Observer()
    observer.schedule(event_handler, directory, recursive=True)
    observer.start()
    result = False
    try:
        time.sleep(1)
        if not event_handler.file_changed:
            # Quiet during the first window (the old message said "not stopped").
            logger.info(f"folder {directory} is not changing")
            result = True
        else:
            logger.info(f"folder {directory} is still changing..")
            event_handler.file_changed = False
            time.sleep(1)
            if event_handler.file_changed:
                logger.info(f"folder {directory} is still still changing..")
            else:
                logger.info(f"folder {directory} changing stopped")
                result = True
    except (KeyboardInterrupt, Exception) as e:
        logger.info(f"folder {directory} changed exception {e}")
    finally:
        # Always stop AND join so no watcher thread outlives this call.
        observer.stop()
        observer.join()
    return result
|
||
|
||
|
||
def is_file_unlocked(file_path):
    """Return True when ``file_path`` appears to be completely written.

    The file's size and modification time are sampled one second apart; if
    either differs the file is considered busy. Otherwise a non-blocking
    exclusive ``flock`` is attempted. Any error is logged and reported as
    "not unlocked".
    """

    def snapshot():
        # (size, mtime) pair used to detect an ongoing write.
        return os.path.getsize(file_path), os.path.getmtime(file_path)

    try:
        before = snapshot()
        time.sleep(1)
        after = snapshot()
        if after != before:
            logger.info(f"unlock file error {file_path} is changing")
            return False

        with open(file_path, 'r') as handle:
            # Raises immediately if another process holds a lock on the file.
            fcntl.flock(handle, fcntl.LOCK_EX | fcntl.LOCK_NB)
        return True
    except (IOError, OSError, Exception) as e:
        logger.info(f"unlock file error {file_path} is writing")
        logger.error(e)
        return False
|
||
|
||
|
||
class MyHandlerWithCheck(FileSystemEventHandler):
    """Watchdog handler that records whether any modify/delete/create event occurred."""

    def __init__(self):
        # Set to True by any observed event; callers may reset it to False.
        self.file_changed = False

    def _mark_changed(self, event):
        """Log the event path and raise the ``file_changed`` flag."""
        logger.info(f"custom_node folder is changing {event.src_path}")
        self.file_changed = True

    def on_modified(self, event):
        self._mark_changed(event)

    def on_deleted(self, event):
        self._mark_changed(event)

    def on_created(self, event):
        self._mark_changed(event)
|
||
|
||
|
||
class MyHandlerWithSync(FileSystemEventHandler):
    """Watchdog handler that triggers an automatic ``sync_files`` on every change."""

    def _sync(self, action, event):
        """Log ``action`` for ``event`` and run an automatic sync for its path."""
        logger.info(f"{datetime.datetime.now()} files {action} ,start to sync {event}")
        sync_files(event.src_path, event.is_directory, True)

    def on_modified(self, event):
        self._sync("modified", event)

    def on_created(self, event):
        self._sync("added", event)

    def on_deleted(self, event):
        self._sync("deleted", event)
|
||
|
||
|
||
# Set by signal_handler when SIGINT/SIGTERM is received.
stop_event = threading.Event()
|
||
|
||
|
||
def check_and_sync():
    """Watch the model, custom-node and input directories until asked to stop.

    Schedules a ``MyHandlerWithSync`` on ``DIR1``, ``DIR2`` and ``DIR3``
    (recursively) and then blocks until ``stop_event`` is set, which the
    signal handler does on SIGINT/SIGTERM. The observer is stopped and joined
    on the way out.
    """
    logger.info("check_and_sync start")
    event_handler = MyHandlerWithSync()
    observer = Observer()
    try:
        for watched_dir in (DIR1, DIR2, DIR3):
            observer.schedule(event_handler, watched_dir, recursive=True)
        observer.start()
        # wait() returns True as soon as the event is set. The former
        # ``while True`` loop ignored stop_event and could never end.
        while not stop_event.wait(1):
            pass
        logger.info("sync Shutting down please restart ComfyUI")
    except KeyboardInterrupt:
        logger.info("sync Shutting down please restart ComfyUI")
    finally:
        # Joining a thread that never started raises, so check first.
        if observer.is_alive():
            observer.stop()
            observer.join()
|
||
|
||
|
||
def signal_handler(sig, frame):
    """Signal callback: log the termination request and set ``stop_event``.

    Both arguments are supplied by the ``signal`` module and are not used.
    """
    logger.info("Received termination signal. Exiting...")
    stop_event.set()
|
||
|
||
|
||
# Auto-sync runs only when DISABLE_AUTO_SYNC is exactly the string 'false':
# install the termination handlers, then start the watcher thread.
if os.environ.get('DISABLE_AUTO_SYNC') == 'false':
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    check_sync_thread = threading.Thread(target=check_and_sync)
    check_sync_thread.start()
|
||
|
||
@server.PromptServer.instance.routes.post('/map_release')
async def map_release(request):
    """POST /map_release: remember which release version a client selected.

    Expects a JSON body with non-empty ``clientId`` and ``releaseVersion``.
    Responds 200 ``{"result": true}`` after storing the mapping in
    ``client_release_map``, or 500 ``{"result": false}`` when a field is
    missing or empty.
    """
    logger.info(f"start to map_release {request}")
    json_data = await request.json()

    def present(field):
        # The field must exist and hold a truthy value.
        return field in json_data and bool(json_data.get(field))

    if not (json_data and present('clientId') and present('releaseVersion')):
        return web.Response(status=500, content_type='application/json', body=json.dumps({"result": False}))

    client_release_map[json_data.get('clientId')] = json_data.get('releaseVersion')
    logger.info(f"client_release_map :{client_release_map}")
    return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/reboot")
async def restart(self):
    """GET /reboot: restart through the supervisor XML-RPC interface.

    Refused (200 with ``result: false``) while a workflow sync holds the
    action lock or when this is not the master process. If the supervisor
    call fails, the current process re-executes itself instead.
    """
    if is_action_lock():
        refusal = {"result": False, "message": "reboot is not allowed during sync workflow"}
        return web.Response(status=200, content_type='application/json', body=json.dumps(refusal))

    if not is_master_process:
        refusal = {"result": False, "message": "only master can restart"}
        return web.Response(status=200, content_type='application/json', body=json.dumps(refusal))

    logger.info(f"start to reboot {self}")
    try:
        from xmlrpc.client import ServerProxy
        supervisor_rpc = ServerProxy('http://localhost:9001/RPC2')
        supervisor_rpc.supervisor.restart()
        return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
    except Exception as e:
        logger.info(f"error reboot {e}")
    # Fallback: replace the current process image with a fresh interpreter.
    return os.execv(sys.executable, [sys.executable] + sys.argv)
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/check_prepare")
async def check_prepare(self):
    """GET /check_prepare: report whether the cloud endpoint finished preparing.

    Queries ``{api_url}/prepare/{comfy_endpoint}`` and responds 200
    ``{"result": true}`` when the call returned 200 with
    ``data.prepareSuccess`` truthy, otherwise 500 ``{"result": false}``.
    A failed or malformed upstream response is answered with 500 as well;
    it no longer restarts the process.
    """
    logger.info(f"start to check_prepare {self}")
    prepared = False
    try:
        # NOTE(review): requests is blocking inside an async handler; the
        # timeout at least bounds how long the event loop can be held.
        get_response = requests.get(f"{api_url}/prepare/{comfy_endpoint}", headers=headers, timeout=30)
        response = get_response.json()
        logger.info(f"check sync response is {response}")
        data = response.get('data') if isinstance(response, dict) else None
        prepared = (get_response.status_code == 200 and isinstance(data, dict)
                    and bool(data.get('prepareSuccess')))
    except Exception as e:
        logger.error(f"error check_prepare {e}")

    return web.Response(status=200 if prepared else 500, content_type='application/json',
                        body=json.dumps({"result": prepared}))
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/gc")
async def gc(self):
    """GET /gc: reset the executor and release cached models and memory.

    Responds 200 ``{"result": true}`` on success. If the cleanup raises, the
    process re-executes itself as a fallback.
    """
    # This handler is itself named ``gc`` and therefore shadows the
    # module-level ``import gc``; import the real module under another name.
    import gc as gc_module

    logger.info(f"start to gc {self}")
    try:
        logger.info(f"gc start: {time.time()}")
        prompt_executor = execution.PromptExecutor(server.PromptServer.instance)
        prompt_executor.reset()
        comfy.model_management.cleanup_models()
        gc_module.collect()
        comfy.model_management.soft_empty_cache()
        logger.info(f"gc end: {time.time()}")
        return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
    except Exception as e:
        logger.info(f"error gc {e}")
    # Fallback: replace the current process image with a fresh interpreter.
    return os.execv(sys.executable, [sys.executable] + sys.argv)
|
||
|
||
|
||
def is_action_lock():
    """Return True while a workflow named in the lock file is not yet registered.

    Reads ``/container/sync_lock``. The lock counts as active only when the
    file exists, is non-empty, and ``check_workflow_exists`` reports that the
    name stored in it does not exist yet. In every other case returns False.
    """
    lock_path = f'/container/sync_lock'
    if not os.path.exists(lock_path):
        return False

    with open(lock_path, 'r') as handle:
        pending_name = handle.read()

    if not pending_name:
        return False
    return not check_workflow_exists(pending_name)
|
||
|
||
|
||
def action_lock(name: str):
    """Write ``name`` into ``/container/sync_lock``, replacing any previous content."""
    lock_path = f'/container/sync_lock'
    with open(lock_path, 'w') as handle:
        handle.write(name)
|
||
|
||
|
||
def action_unlock():
    """Truncate ``/container/sync_lock`` to an empty file (creating it if absent)."""
    lock_path = f'/container/sync_lock'
    with open(lock_path, 'w') as handle:
        handle.write("")
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/restart")
async def restart(self):
    """Handle ``GET /restart``: re-execute the current interpreter in place.

    Refused (HTTP 200 with ``"result": false``) while the sync lock is active.
    Otherwise tries ``sys.stdout.close_log()``; a failure there is only
    logged. The process is then replaced via ``os.execv`` using the original
    command line.

    ``self`` is the single argument the route passes in; it is only logged.
    """
    if is_action_lock():
        refusal = {"result": False, "message": "restart is not allowed during sync workflow"}
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(refusal))

    logger.info(f"start to restart {self}")
    try:
        # NOTE(review): close_log is not a standard stream method; presumably
        # stdout has been replaced by a logging wrapper elsewhere - confirm.
        sys.stdout.close_log()
    except Exception as close_error:
        logger.info(f"error restart {close_error}")

    relaunch_argv = [sys.executable] + sys.argv
    return os.execv(sys.executable, relaunch_argv)
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/sync_env")
async def sync_env(request):
    """Handle ``GET /sync_env``: run ``sync_default_files`` and report the outcome.

    Answers HTTP 200 ``{"result": true}`` when the call returns, and HTTP 500
    ``{"result": false}`` when it raises. The value returned by
    ``sync_default_files`` is only logged at DEBUG level.
    """
    logger.info(f"start to sync_env {request}")
    succeeded = False
    try:
        outcome = sync_default_files()
        logger.debug(f"sync result is :{outcome}")
        succeeded = True
    except Exception as sync_error:
        logger.info(f"error sync_env {sync_error}")

    http_status = 200 if succeeded else 500
    return web.Response(status=http_status, content_type='application/json',
                        body=json.dumps({"result": succeeded}))
|
||
|
||
|
||
@server.PromptServer.instance.routes.post("/change_env")
async def change_env(request):
    """Handle ``POST /change_env``: update the ``DISABLE_AWS_PROXY`` environment variable.

    When the JSON body contains a non-null value under the ``DISABLE_AWS_PROXY``
    key, its string form is stored in ``os.environ``. The response is always
    HTTP 200 ``{"result": true}``.
    """
    logger.info(f"start to change_env {request}")
    payload = await request.json()

    has_new_value = DISABLE_AWS_PROXY in payload and payload[DISABLE_AWS_PROXY] is not None
    if has_new_value:
        new_value = str(payload[DISABLE_AWS_PROXY])
        logger.info(f"origin evn key DISABLE_AWS_PROXY is :{os.environ.get(DISABLE_AWS_PROXY)} {new_value}")
        os.environ[DISABLE_AWS_PROXY] = new_value
        logger.info(f"now evn key DISABLE_AWS_PROXY is :{os.environ.get(DISABLE_AWS_PROXY)}")

    return web.Response(status=200, content_type='application/json',
                        body=json.dumps({"result": True}))
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/get_env")
async def get_env(request):
    """Handle ``GET /get_env``: return the current ``DISABLE_AWS_PROXY`` setting.

    The value is read from ``os.environ`` and defaults to the string
    ``'False'`` when unset; it is returned as ``{"env": <value>}``.
    """
    current_value = os.environ.get(DISABLE_AWS_PROXY, 'False')
    body_text = json.dumps({"env": current_value})
    return web.Response(status=200, content_type='application/json', body=body_text)
|
||
|
||
def get_directory_size(directory):
    """Return the total size in bytes of all files below ``directory``.

    The tree is walked with ``os.walk``; entries that are symbolic links are
    skipped so their targets are not counted. An empty or missing directory
    yields 0.
    """
    return sum(
        os.path.getsize(file_path)
        for parent, _subdirs, names in os.walk(directory)
        for file_path in (os.path.join(parent, name) for name in names)
        if not os.path.islink(file_path)  # only count files that are not symlinks
    )
|
||
|
||
@server.PromptServer.instance.routes.post("/workflows")
async def release_workflow(request):
    """Handle ``POST /workflows``: publish the current workflow directory under a new name.

    Expects a JSON body with ``name`` (required, must not be ``'default'`` or
    an already existing workflow) and optional ``payload_json``.

    Steps, in order:
      1. take the sync lock for ``name``;
      2. write ``name`` and ``$BASE_IMAGE`` into ``/container/image_target_name``
         and ``/container/image_base``;
      3. ``s5cmd sync`` ``/container/workflows/$WORKFLOW_NAME`` to
         ``s3://{bucket_name}/comfy/workflows/{name}/`` followed by a ``lock``
         marker object;
      4. register the workflow with ``POST {api_url}/workflows``;
      5. release the sync lock.

    Validation failures answer HTTP 200 with ``"result": false``; any exception
    answers HTTP 500 after releasing the lock.

    ``name`` comes straight from the request, so nothing here is passed through
    a shell: files are written directly and ``s5cmd`` receives an argument list.
    """

    def _reply(status, payload):
        # All responses of this handler are JSON documents.
        return web.Response(status=status, content_type='application/json',
                            body=json.dumps(payload))

    if is_action_lock():
        return _reply(200, {"result": False, "message": "release is not allowed during sync workflow"})

    if not is_master_process:
        return _reply(200, {"result": False, "message": "only master can release workflow"})

    logger.info(f"start to release workflow {request}")
    try:
        json_data = await request.json()
        if 'name' not in json_data or not json_data['name']:
            return _reply(200, {"result": False, "message": f"name is required"})

        workflow_name = json_data['name']
        if workflow_name == 'default':
            return _reply(200, {"result": False, "message": f"{workflow_name} is not allowed"})

        payload_json = ''
        if 'payload_json' in json_data:
            payload_json = json_data['payload_json']

        if check_workflow_exists(workflow_name):
            return _reply(200, {"result": False, "message": f"{workflow_name} already exists"})

        start_time = time.time()

        action_lock(workflow_name)

        base_image = os.getenv('BASE_IMAGE')
        # Same file content as ``echo <value> > <file>`` (value plus newline),
        # but without handing request data to a shell.
        with open('/container/image_target_name', 'w') as target_file:
            target_file.write(f"{workflow_name}\n")
        with open('/container/image_base', 'w') as base_file:
            base_file.write(f"{base_image}\n")

        cur_workflow_name = os.getenv('WORKFLOW_NAME')
        source_path = f"/container/workflows/{cur_workflow_name}"

        print(f"source_path is {source_path}")
        total_size_bytes = get_directory_size(source_path)
        source_size = round(total_size_bytes / (1024 ** 3), 2)  # GiB, two decimals

        s3_prefix = f"s3://{bucket_name}/comfy/workflows/{workflow_name}/"
        # The wildcards are passed to s5cmd literally (no shell expansion).
        s5cmd_sync_command = [
            's5cmd', 'sync',
            '--delete=true',
            '--exclude=*comfy.tar',
            '--exclude=*.log',
            '--exclude=*__pycache__*',
            '--exclude=*/ComfyUI/input/*',
            '--exclude=*/ComfyUI/output/*',
            f'{source_path}/*',
            s3_prefix,
        ]
        s5cmd_lock_command = ['s5cmd', 'sync', 'lock', f'{s3_prefix}lock']

        logger.info(f"sync workflows files start {s5cmd_sync_command}")

        # NOTE(review): these calls block the event loop until s5cmd finishes.
        subprocess.check_output(s5cmd_sync_command)
        # Marker file created in the current working directory, then uploaded.
        with open('lock', 'w') as marker:
            marker.write("lock\n")
        subprocess.check_output(s5cmd_lock_command)

        end_time = time.time()
        cost_time = end_time - start_time
        image_hash = os.getenv('IMAGE_HASH')
        image_uri = f"{image_hash}:{workflow_name}"

        data = {
            "payload_json": payload_json,
            "image_uri": image_uri,
            "name": workflow_name,
            "size": str(source_size),
        }
        get_response = requests.post(f"{api_url}/workflows", headers=headers, data=json.dumps(data))
        response = get_response.json()
        logger.info(f"release workflow response is {response}")
        action_unlock()

        return _reply(200, {"result": True, "message": "success", "cost_time": cost_time})
    except Exception as e:
        logger.info(e)
        action_unlock()
        return _reply(500, {"result": False, "message": 'Release workflow failed'})
|
||
|
||
|
||
@server.PromptServer.instance.routes.put("/workflows")
async def switch_workflow(request):
    """Handle ``PUT /workflows``: select another workflow and kill python3 processes.

    Expects a JSON body with ``name``. The request is rejected (HTTP 200 with
    ``"result": false``) while the sync lock is active, when ``name`` already
    equals ``$WORKFLOW_NAME``, when a non-master asks for ``'default'``, or
    when a non-default workflow does not exist. Otherwise ``name`` is written
    to the file named by ``$WORKFLOW_NAME_FILE`` and ``pkill -f python3`` is
    run. Any exception (including a missing ``name`` or an unset
    ``WORKFLOW_NAME_FILE``) answers HTTP 500.

    The name is written with a plain file write rather than through a shell,
    because it originates from the request body.
    """

    def _reply(status, payload):
        # All responses of this handler are JSON documents.
        return web.Response(status=status, content_type='application/json',
                            body=json.dumps(payload))

    if is_action_lock():
        return _reply(200, {"result": False, "message": "switch is not allowed during sync workflow"})

    try:
        json_data = await request.json()
        if 'name' not in json_data or not json_data['name']:
            raise ValueError("name is required")

        workflow_name = json_data['name']

        if workflow_name == os.getenv('WORKFLOW_NAME'):
            return _reply(200, {"result": False, "message": "workflow is already in use"})

        if workflow_name == 'default' and not is_master_process:
            return _reply(200, {"result": False, "message": "slave can not use default workflow"})

        if workflow_name != 'default':
            if not check_workflow_exists(workflow_name):
                return _reply(200, {"result": False, "message": f"{workflow_name} not exists"})

        name_file = os.getenv('WORKFLOW_NAME_FILE')
        if not name_file:
            # Without a target file the switch cannot be recorded; fail
            # instead of writing to a file literally named "None".
            raise ValueError("WORKFLOW_NAME_FILE is not set")
        # Same content as ``echo <name> > <file>``: the name plus a newline.
        with open(name_file, 'w') as target:
            target.write(f"{workflow_name}\n")

        # NOTE(review): this matches every process whose command line contains
        # "python3", presumably including this server - confirm the response
        # below is still expected to reach the client.
        subprocess.run(["pkill", "-f", "python3"])

        return _reply(200, {"result": True, "message": "Please wait to restart"})
    except Exception as e:
        logger.info(e)
        return _reply(500, {"result": False, "message": 'Switch workflow failed'})
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/workflows")
async def get_workflows(request):
    """Handle ``GET /workflows``: list enabled workflows known to the API.

    Fetches up to 1000 workflows from ``{api_url}/workflows`` and keeps those
    whose ``status`` is ``'Enabled'``. Each entry carries ``name``, ``size``,
    ``payload_json`` and an ``in_use`` flag that is true when the name equals
    ``$WORKFLOW_NAME``. The API's ``data`` object is returned with
    ``workflows`` replaced by that filtered list and ``current_workflow``
    added. A non-200 API reply or any exception answers HTTP 500.
    """
    failure = {"result": False, "message": 'Get workflows failed'}
    try:
        workflow_name = os.getenv('WORKFLOW_NAME')

        response = requests.get(f"{api_url}/workflows", headers=headers, params={"limit": 1000})
        if response.status_code != 200:
            return web.Response(status=500, content_type='application/json',
                                body=json.dumps(failure))

        data = response.json()['data']
        enabled_workflows = [
            {
                "name": workflow['name'],
                "size": workflow['size'],
                "payload_json": workflow['payload_json'],
                "in_use": workflow['name'] == workflow_name,
            }
            for workflow in data['workflows']
            if workflow['status'] == 'Enabled'
        ]

        data['current_workflow'] = workflow_name
        data['workflows'] = enabled_workflows

        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "data": data}))
    except Exception as e:
        logger.info(e)
        # Report the failure under this endpoint's own name.
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps(failure))
|
||
|
||
|
||
def check_workflow_exists(name: str):
    """Return True when ``GET {api_url}/workflows/{name}`` answers with HTTP 200."""
    lookup = requests.get(f"{api_url}/workflows/{name}", headers=headers)
    found = lookup.status_code == 200
    return found
|
||
|
||
|
||
def restore_commands():
    """Run ``sleep 5`` and then ``pkill -f python3`` as child processes, in that order."""
    delayed_kill_sequence = (
        ["sleep", "5"],
        ["pkill", "-f", "python3"],
    )
    for argv in delayed_kill_sequence:
        subprocess.run(argv)
|
||
|
||
@server.PromptServer.instance.routes.post("/restore")
async def release_rebuild_workflow(request):
    """Handle ``POST /restore``: rebuild the default workflow directory from its archive.

    Only allowed when ``$WORKFLOW_NAME`` is ``'default'``, the sync lock is not
    active and this process is the master; otherwise HTTP 200 with
    ``"result": false`` is returned.

    The restore moves two checkpoint files out of
    ``/container/workflows/default`` to ``/``, deletes and recreates that
    directory, unpacks ``/container/default.tar`` into it, moves the
    checkpoints back, and finally starts a thread running
    ``restore_commands``.
    """

    def _reply(status, payload):
        # All responses of this handler are JSON documents.
        return web.Response(status=status, content_type='application/json',
                            body=json.dumps(payload))

    if os.getenv('WORKFLOW_NAME') != 'default':
        return _reply(200, {"result": False, "message": "only default workflow can be restored"})

    if is_action_lock():
        return _reply(200, {"result": False, "message": "restore is not allowed during sync workflow"})

    if not is_master_process:
        return _reply(200, {"result": False, "message": "only master can restore comfy"})

    logger.info(f"start to restore EC2 {request}")

    # Executed strictly in this order. os.system does not raise when a command
    # exits non-zero, so a failing step does not stop the sequence.
    restore_steps = (
        "mv /container/workflows/default/ComfyUI/models/checkpoints/v1-5-pruned-emaonly.ckpt /",
        "mv /container/workflows/default/ComfyUI/models/animatediff_models/mm_sd_v15_v2.ckpt /",
        "rm -rf /container/workflows/default",
        "mkdir -p /container/workflows/default",
        "tar --overwrite -xf /container/default.tar -C /container/workflows/default/",
        "mkdir -p /container/workflows/default/ComfyUI/models/checkpoints",
        "mkdir -p /container/workflows/default/ComfyUI/models/animatediff_models",
        "mv /v1-5-pruned-emaonly.ckpt /container/workflows/default/ComfyUI/models/checkpoints/",
        "mv /mm_sd_v15_v2.ckpt /container/workflows/default/ComfyUI/models/animatediff_models/",
    )
    try:
        for step in restore_steps:
            os.system(step)
        thread = threading.Thread(target=restore_commands)
        thread.start()
        return _reply(200, {"result": True, "message": "comfy will be restored in 5 seconds"})
    except Exception as e:
        # An exception object is not JSON serializable; send its text instead.
        return _reply(500, {"result": False, "message": str(e)})
|
||
|
||
|
||
if is_on_sagemaker:
|
||
|
||
# Mutable module state shared by the SageMaker handlers below.
# NOTE(review): need_sync and prompt_id are only declared here; they are first
# assigned inside execute_proxy.
global need_sync, prompt_id, executing, reboot, last_call_time, gc_triggered
executing = False        # True while execute_proxy is processing a prompt
reboot = False           # set by prepare_comfy_env from the item's need_reboot
last_call_time = None    # time.time() of the previous execute_proxy call
gc_triggered = False

# Deployment settings, all read from the environment.
REGION = os.getenv('AWS_REGION')
BUCKET = os.getenv('S3_BUCKET_NAME')
QUEUE_URL = os.getenv('COMFY_QUEUE_URL')

# Use the configured instance id when present and non-empty, else a random UUID.
GEN_INSTANCE_ID = os.getenv('ENDPOINT_INSTANCE_ID') or str(uuid.uuid4())
ENDPOINT_NAME = os.getenv('ENDPOINT_NAME')
ENDPOINT_ID = os.getenv('ENDPOINT_ID')

INSTANCE_MONITOR_TABLE_NAME = os.getenv('COMFY_INSTANCE_MONITOR_TABLE')
SYNC_TABLE_NAME = os.getenv('COMFY_SYNC_TABLE')

# DynamoDB table handles.
dynamodb = boto3.resource('dynamodb', region_name=REGION)
sync_table = dynamodb.Table(SYNC_TABLE_NAME)
instance_monitor_table = dynamodb.Table(INSTANCE_MONITOR_TABLE_NAME)

# Module logger; LOG_LEVEL overrides the INFO default when set and non-empty.
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv('LOG_LEVEL') or logging.INFO)

ROOT_PATH = '/home/ubuntu/ComfyUI'
sqs_client = boto3.client('sqs', region_name=REGION)

# Idle gap in seconds (30 minutes) compared against last_call_time in execute_proxy.
GC_WAIT_TIME = 1800
|
||
|
||
|
||
def print_env():
    """Log every environment variable as ``KEY: value`` at INFO level.

    NOTE(review): this writes all environment values to the log, including
    any credentials that happen to be set.
    """
    for name in os.environ:
        logger.info(f"{name}: {os.environ[name]}")
|
||
|
||
|
||
@dataclass
class ComfyResponse:
    """Plain record with a status code, a message and an optional body dict."""

    # Numeric status code of the response.
    statusCode: int
    # Human-readable message.
    message: str
    # Payload dictionary, or None.
    body: Optional[dict]
|
||
|
||
|
||
def ok(body: dict):
    """Serialize ``body`` to JSON and wrap it in an HTTP 200 response."""
    payload = json.dumps(body)
    return web.Response(status=200, content_type='application/json', body=payload)
|
||
|
||
|
||
def error(body: dict):
    """Serialize ``body`` to JSON and return it with HTTP status 200.

    Failures are reported inside the body, not through the HTTP status.
    """
    # TODO: status changed from 500 to 200 because the caller needs a response
    # body in every case rather than an exception.
    payload = json.dumps(body)
    return web.Response(status=200, content_type='application/json', body=payload)
|
||
|
||
|
||
def sen_sqs_msg(message_body, prompt_id_key):
    """Send ``message_body`` as JSON to ``QUEUE_URL`` and return the SQS message id.

    ``prompt_id_key`` is used as the ``MessageGroupId`` of the message.
    """
    send_result = sqs_client.send_message(
        QueueUrl=QUEUE_URL,
        MessageBody=json.dumps(message_body),
        MessageGroupId=prompt_id_key,
    )
    return send_result['MessageId']
|
||
|
||
|
||
def sen_finish_sqs_msg(prompt_id_key):
    """Publish a ``finish`` event for ``prompt_id_key`` to SQS.

    Nothing is sent unless the global ``need_sync`` flag is truthy and both
    ``QUEUE_URL`` and ``REGION`` are configured.
    """
    global need_sync
    if not (need_sync and QUEUE_URL and REGION):
        return

    finish_event = {
        'prompt_id': prompt_id_key,
        'event': 'finish',
        'data': {"node": None, "prompt_id": prompt_id_key},
        'sid': None,
    }
    sent_id = sen_sqs_msg(finish_event, prompt_id_key)
    logger.info(f"finish message sent {sent_id}")
|
||
|
||
|
||
async def prepare_comfy_env(sync_item: dict):
    """Bring the local ComfyUI tree in line with one sync request.

    ``sync_item`` must provide ``request_id``, ``prepare_type`` and
    ``request_time``. Depending on ``prepare_type``:

    * ``'default'`` / ``'models'`` / ``'inputs'`` / ``'nodes'`` - sync the
      matching ``{request_id}/...`` prefix from S3 into ``ROOT_PATH``
      (``'default'`` does all three; custom nodes are also un-tarred);
    * ``'custom'`` - sync ``s3_source_path`` into
      ``ROOT_PATH/local_target_path``;
    * ``'other'`` - run ``sync_script`` if it starts with an allow-listed
      prefix, or apply an ``export KEY VALUE`` line to ``os.environ``.

    Afterwards the ``reboot`` global and ``NEED_REBOOT`` are set from
    ``need_reboot``, and ``LAST_SYNC_REQUEST_ID`` / ``LAST_SYNC_REQUEST_TIME``
    are recorded in ``os.environ``.

    Returns True when every attempted sync reported success, and False when a
    sync failed or any exception was raised (the exception is logged).
    """
    global reboot

    # Command prefixes that the 'other' prepare type may run.
    # Scripts starting with 's5cmd' are deliberately not allowed.
    allowed_script_prefixes = (
        "python3 -m pip", "python -m pip", "pip install", "apt",
        "os.environ", "ls", "env", "source", "curl", "wget",
        "print", "cat", "sudo chmod", "chmod",
        "/home/ubuntu/ComfyUI/venv/bin/python",
    )

    try:
        request_id = sync_item['request_id']
        logger.info(f"prepare_environment start sync_item:{sync_item}")
        prepare_type = sync_item['prepare_type']
        rlt = True

        if prepare_type in ['default', 'models']:
            if not sync_s3_files_or_folders_to_local(f'{request_id}/models/*', f'{ROOT_PATH}/models', False):
                rlt = False
        if prepare_type in ['default', 'inputs']:
            if not sync_s3_files_or_folders_to_local(f'{request_id}/input/*', f'{ROOT_PATH}/input', False):
                rlt = False
        if prepare_type in ['default', 'nodes']:
            if not sync_s3_files_or_folders_to_local(f'{request_id}/custom_nodes/*',
                                                     f'{ROOT_PATH}/custom_nodes', True):
                rlt = False

        if prepare_type == 'custom':
            sync_source_path = sync_item['s3_source_path']
            local_target_path = sync_item['local_target_path']
            if not sync_source_path or not local_target_path:
                logger.info("s3_source_path and local_target_path should not be empty")
            elif not sync_s3_files_or_folders_to_local(sync_source_path,
                                                       f'{ROOT_PATH}/{local_target_path}', False):
                rlt = False
        elif prepare_type == 'other':
            sync_script = sync_item['sync_script']
            logger.info("sync_script")
            try:
                if sync_script and sync_script.startswith(allowed_script_prefixes):
                    # SECURITY NOTE(review): only the prefix is checked. The
                    # whole string goes to a shell, so metacharacters after an
                    # allowed prefix (e.g. "ls; <anything>") still execute.
                    # Left unchanged because callers rely on free-form scripts.
                    os.system(sync_script)
                elif sync_script and (sync_script.startswith("export ") and len(sync_script.split(" ")) > 2):
                    # Expected form: "export KEY VALUE" (space separated).
                    script_parts = sync_script.split(" ")
                    sync_script_key = script_parts[1]
                    sync_script_value = script_parts[2]
                    os.environ[sync_script_key] = sync_script_value
                    logger.info(os.environ.get(sync_script_key))
            except Exception as e:
                logger.error(f"Exception while execute sync_scripts : {sync_script} {e}")
                rlt = False

        # Missing, falsy or non-'true' values all mean "no reboot".
        need_reboot = str(sync_item.get('need_reboot')).lower() == 'true'
        reboot = need_reboot
        os.environ['NEED_REBOOT'] = 'true' if need_reboot else 'false'

        logger.info("prepare_environment end")
        os.environ['LAST_SYNC_REQUEST_ID'] = sync_item['request_id']
        os.environ['LAST_SYNC_REQUEST_TIME'] = str(sync_item['request_time'])
        return rlt
    except Exception:
        # Keep the best-effort contract (report failure to the caller) but
        # leave a trace of what went wrong.
        logger.exception("prepare_environment failed")
        return False
|
||
|
||
|
||
def sync_s3_files_or_folders_to_local(s3_path, local_path, need_un_tar):
    """Sync ``s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}`` into ``local_path``.

    With ``need_un_tar`` false, ``s5cmd sync`` runs with ``--delete=true``.
    With ``need_un_tar`` true, the delete flag is omitted and every ``*.tar.gz``
    found directly in ``local_path`` afterwards is extracted there and removed.

    Returns True unless an exception is raised, in which case it is logged and
    False is returned.

    NOTE(review): the exit status of ``s5cmd`` is not inspected, so a failed
    sync still returns True.
    """
    logger.info("sync_s3_models_or_inputs_to_local start")
    remote_uri = f's3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}'
    delete_option = '' if need_un_tar else '--delete=true '
    s5cmd_command = f's5cmd sync {delete_option}"{remote_uri}" "{local_path}/"'

    try:
        logger.info(s5cmd_command)
        os.system(s5cmd_command)
        logger.info(f'Files copied from "{remote_uri}" to "{local_path}/"')

        if need_un_tar:
            archive_names = [entry for entry in os.listdir(local_path) if entry.endswith(".tar.gz")]
            for archive_name in archive_names:
                tar_filepath = os.path.join(local_path, archive_name)
                # NOTE(review): members are extracted without checking their
                # paths, so the archives must come from a trusted bucket.
                with tarfile.open(tar_filepath, "r:gz") as archive:
                    for member in archive.getmembers():
                        archive.extract(member, path=local_path)
                os.remove(tar_filepath)
                logger.info(f'File {tar_filepath} extracted and removed')
        return True
    except Exception as sync_error:
        logger.info(f"Error executing s5cmd command: {sync_error}")
        return False
|
||
|
||
|
||
def sync_local_outputs_to_s3(s3_path, local_path):
    """Upload ``local_path/*`` to ``s3://{BUCKET}/comfy/{s3_path}/`` and delete ``local_path``.

    ``s5cmd sync`` is launched with an argument list and the directory is
    removed with ``shutil.rmtree``. No shell is involved, so spaces or shell
    metacharacters in either path cannot change what gets executed or deleted.

    The exit status of ``s5cmd`` is not inspected: the local directory is
    removed whether or not the upload succeeded. Exceptions (for example
    ``s5cmd`` not being installed) are logged and swallowed.
    """
    import shutil

    logger.info("sync_local_outputs_to_s3 start")
    # The wildcard is passed to s5cmd literally (no shell expansion).
    s5cmd_command = ['s5cmd', 'sync', f'{local_path}/*', f's3://{BUCKET}/comfy/{s3_path}/']
    try:
        logger.info(s5cmd_command)
        subprocess.run(s5cmd_command)
        logger.info(f'Files copied local to "s3://{BUCKET}/comfy/{s3_path}/" to "{local_path}/"')
        # Same effect as ``rm -rf``: a missing directory is not an error.
        shutil.rmtree(local_path, ignore_errors=True)
        logger.info(f'Files removed from local {local_path}')
    except Exception as e:
        logger.info(f"Error executing s5cmd command: {e}")
|
||
|
||
|
||
def sync_local_outputs_to_base64(local_path):
    """Return the files below ``local_path`` as base64 strings, then delete the directory.

    The result maps each file's base name to its base64-encoded content.
    Files with the same name in different sub-directories overwrite each
    other in the result. On any exception the error is logged and an empty
    dict is returned.

    The directory is removed with ``shutil.rmtree`` rather than a shell
    ``rm -rf``, so paths containing spaces or shell metacharacters are
    handled literally.
    """
    import shutil

    logger.info("sync_local_outputs_to_base64 start")
    try:
        result = {}
        for root, _dirs, files in os.walk(local_path):
            for file in files:
                file_path = os.path.join(root, file)
                with open(file_path, "rb") as f:
                    file_content = f.read()
                result[file] = base64.b64encode(file_content).decode('utf-8')
        # Same effect as ``rm -rf``: a missing directory is not an error.
        shutil.rmtree(local_path, ignore_errors=True)
        logger.info(f'Files removed from local {local_path}')
        return result
    except Exception as e:
        logger.info(f"Error executing s5cmd command: {e}")
        return {}
|
||
|
||
|
||
@server.PromptServer.instance.routes.post("/execute_proxy")
async def execute_proxy(request):
    """Handle ``POST /execute_proxy``: run one prompt and upload its outputs to S3.

    Reads from the JSON body: ``prompt``, ``prompt_id``, ``need_sync``
    (all required) and optionally ``out_path``, ``need_prepare`` +
    ``prepare_props``, ``number``, ``front``, ``extra_data``, ``client_id``.

    Flow:
      1. refuse if another call is in progress (global ``executing``);
      2. optionally prepare the environment via ``prepare_comfy_env``;
      3. validate the prompt with ``execution.validate_prompt``;
      4. execute it with a new ``PromptExecutor``;
      5. sync the local output and temp directories to S3;
      6. send a ``finish`` SQS message and return the S3 locations.

    All outcomes, including failures, are returned through ``ok``/``error``,
    both of which answer HTTP 200; the ``status`` field carries the result.
    The ``finally`` block runs a model cleanup / garbage collection when more
    than ``GC_WAIT_TIME`` seconds passed since the previous call.
    """
    json_data = await request.json()
    if 'out_path' in json_data and json_data['out_path'] is not None:
        out_path = json_data['out_path']
    else:
        out_path = None
    logger.info(f"invocations start json_data:{json_data}")
    # need_sync and prompt_id are module globals; sen_finish_sqs_msg reads need_sync.
    global need_sync
    need_sync = json_data["need_sync"]
    global prompt_id
    prompt_id = json_data["prompt_id"]
    try:
        global executing
        if executing is True:
            # Busy: another prompt is running. `executing` is deliberately left
            # untouched on this path.
            # NOTE(review): the message text is the same as the validation
            # failure below although the cause here is a concurrent call.
            resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
                    "message": "the environment is not ready valid[0] is false, need to resync"}
            sen_finish_sqs_msg(prompt_id)
            return error(resp)
        executing = True
        logger.info(
            f'bucket_name: {BUCKET}, region: {REGION}')
        if ('need_prepare' in json_data and json_data['need_prepare']
                and 'prepare_props' in json_data and json_data['prepare_props']):
            sync_already = await prepare_comfy_env(json_data['prepare_props'])
            if not sync_already:
                resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
                        "message": "the environment is not ready with sync"}
                executing = False
                sen_finish_sqs_msg(prompt_id)
                return error(resp)
        server_instance = server.PromptServer.instance
        # Queue number: an explicit "number" overrides the server counter;
        # otherwise the counter is used (negated when "front" is truthy) and
        # then incremented.
        if "number" in json_data:
            number = float(json_data['number'])
            server_instance.number = number
        else:
            number = server_instance.number
            if "front" in json_data:
                if json_data['front']:
                    number = -number
            server_instance.number += 1
        # valid[0] is the success flag, valid[2] the outputs to execute and
        # valid[3] the node errors, as used below.
        valid = execution.validate_prompt(json_data['prompt'])
        logger.info(f"Validating prompt result is {valid}")
        if not valid[0]:
            resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
                    "message": "the environment is not ready valid[0] is false, need to resync"}
            executing = False
            # NOTE(review): `response` is built but never used or returned.
            response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
            sen_finish_sqs_msg(prompt_id)
            return error(resp)
        # if len(valid) == 4 and len(valid[3]) > 0:
        #     logger.info(f"Validating prompt error there is something error because of :valid: {valid}")
        #     resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
        #             "message": f"the valid is error, need to resync or check the workflow :{valid}"}
        #     executing = False
        #     return error(resp)
        extra_data = {}
        client_id = ''
        if "extra_data" in json_data:
            extra_data = json_data["extra_data"]
            if 'client_id' in extra_data and extra_data['client_id']:
                client_id = extra_data['client_id']
        # A top-level client_id takes precedence over the one in extra_data.
        if "client_id" in json_data and json_data["client_id"]:
            extra_data["client_id"] = json_data["client_id"]
            client_id = json_data["client_id"]

        server_instance.client_id = client_id

        prompt_id = json_data['prompt_id']
        server_instance.last_prompt_id = prompt_id
        e = execution.PromptExecutor(server_instance)
        outputs_to_execute = valid[2]
        # NOTE(review): called synchronously inside the async handler, so the
        # handler does not yield until execution returns.
        e.execute(json_data['prompt'], prompt_id, extra_data, outputs_to_execute)

        # S3 keys are namespaced by prompt id; local paths are not.
        s3_out_path = f'output/{prompt_id}/{out_path}' if out_path is not None else f'output/{prompt_id}'
        s3_temp_path = f'temp/{prompt_id}/{out_path}' if out_path is not None else f'temp/{prompt_id}'
        local_out_path = f'{ROOT_PATH}/output/{out_path}' if out_path is not None else f'{ROOT_PATH}/output'
        local_temp_path = f'{ROOT_PATH}/temp/{out_path}' if out_path is not None else f'{ROOT_PATH}/temp'

        logger.info(f"s3_out_path is {s3_out_path} and s3_temp_path is {s3_temp_path} and local_out_path is {local_out_path} and local_temp_path is {local_temp_path}")

        # Each call uploads the directory and then removes it locally.
        sync_local_outputs_to_s3(s3_out_path, local_out_path)
        sync_local_outputs_to_s3(s3_temp_path, local_temp_path)

        response_body = {
            "prompt_id": prompt_id,
            "instance_id": GEN_INSTANCE_ID,
            "status": "success",
            "output_path": f's3://{BUCKET}/comfy/{s3_out_path}',
            "temp_path": f's3://{BUCKET}/comfy/{s3_temp_path}',
        }
        sen_finish_sqs_msg(prompt_id)
        logger.info(f"execute inference response is {response_body}")
        executing = False
        return ok(response_body)
    except Exception as ecp:
        # NOTE(review): unlike the other failure paths, no finish message is
        # sent to SQS here.
        logger.info(f"exception occurred {ecp}")
        resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
                "message": f"exception occurred {ecp}"}
        executing = False
        return error(resp)
    finally:
        # Runs on every exit path, including the early returns above.
        logger.info(f"gc check: {time.time()}")
        try:
            global last_call_time, gc_triggered
            # Reset on every call, so the `if not gc_triggered` test below is
            # always true when reached.
            gc_triggered = False
            if last_call_time is None:
                logger.info(f"gc check last time is NONE")
                last_call_time = time.time()
            else:
                if time.time() - last_call_time > GC_WAIT_TIME:
                    if not gc_triggered:
                        logger.info(f"gc start: {time.time()} - {last_call_time}")
                        # `e` is the PromptExecutor created above; when the
                        # request ended before that point the name is unbound
                        # and the resulting error is caught and logged below.
                        e.reset()
                        comfy.model_management.cleanup_models()
                        gc.collect()
                        comfy.model_management.soft_empty_cache()
                        gc_triggered = True
                        logger.info(f"gc end: {time.time()} - {last_call_time}")
                    last_call_time = time.time()
                else:
                    last_call_time = time.time()
            logger.info(f"gc check end: {time.time()}")
        except Exception as e:
            logger.info(f"gc error: {e}")
|
||
|
||
|
||
def get_last_ddb_sync_record():
    """Return the newest sync-table item for ``ENDPOINT_NAME``, or None.

    Queries ``sync_table`` by ``endpoint_name`` in descending key order with
    a limit of one item. Returns None when nothing (or an empty item) comes
    back.
    """
    query_result = sync_table.query(
        KeyConditionExpression=Key('endpoint_name').eq(ENDPOINT_NAME),
        Limit=1,
        ScanIndexForward=False,
    )
    newest = None
    if 'Items' in query_result and len(query_result['Items']) > 0:
        newest = query_result['Items'][0]

    if not newest:
        logger.info("no latest_sync_record found")
        return None

    logger.info(f"latest_sync_record is:{newest}")
    return newest
|
||
|
||
|
||
def get_latest_ddb_instance_monitor_record():
    """Return this instance's monitor-table item, or None when it does not exist.

    Queries ``instance_monitor_table`` for the pair
    (``ENDPOINT_NAME``, ``GEN_INSTANCE_ID``) and returns the first item.
    """
    query_result = instance_monitor_table.query(
        KeyConditionExpression='endpoint_name = :endpoint_name_val AND gen_instance_id = :gen_instance_id_val',
        ExpressionAttributeValues={
            ':endpoint_name_val': ENDPOINT_NAME,
            ':gen_instance_id_val': GEN_INSTANCE_ID,
        },
    )
    record = None
    if 'Items' in query_result and len(query_result['Items']) > 0:
        record = query_result['Items'][0]

    if not record:
        logger.info("no instance_monitor_record found")
        return None

    logger.info(f"instance_monitor_record is {record}")
    return record
|
||
|
||
|
||
def save_sync_instance_monitor(last_sync_request_id: str, sync_status: str):
    """Create (or overwrite) this instance's monitor-table item and return the put response.

    The item starts with an empty ``sync_list``; ``last_sync_time``,
    ``create_time`` and ``last_heartbeat_time`` are each set to the current
    local time in ISO format.
    """
    def _timestamp():
        # Local wall-clock time, ISO 8601.
        return datetime.datetime.now().isoformat()

    monitor_item = {
        'endpoint_id': ENDPOINT_ID,
        'endpoint_name': ENDPOINT_NAME,
        'gen_instance_id': GEN_INSTANCE_ID,
        'sync_status': sync_status,
        'last_sync_request_id': last_sync_request_id,
        'last_sync_time': _timestamp(),
        'sync_list': [],
        'create_time': _timestamp(),
        'last_heartbeat_time': _timestamp(),
    }
    put_result = instance_monitor_table.put_item(Item=monitor_item)
    logger.info(f"save instance item {put_result}")
    return put_result
|
||
|
||
|
||
def update_sync_instance_monitor(instance_monitor_record):
    """Write the sync fields of ``instance_monitor_record`` back to DynamoDB.

    Updates ``sync_status``, ``last_sync_request_id`` and ``sync_list`` from
    the record, and stamps ``last_sync_time`` and ``last_heartbeat_time`` with
    the current local time. The item is addressed by ``ENDPOINT_NAME`` and
    ``GEN_INSTANCE_ID``. Returns the ``update_item`` response.
    """
    # Update the record.
    new_values = {
        ":new_sync_status": instance_monitor_record['sync_status'],
        ":sync_request_id": instance_monitor_record['last_sync_request_id'],
        ":sync_list": instance_monitor_record['sync_list'],
        ":sync_time": datetime.datetime.now().isoformat(),
        ":heartbeat_time": datetime.datetime.now().isoformat(),
    }
    item_key = {'endpoint_name': ENDPOINT_NAME, 'gen_instance_id': GEN_INSTANCE_ID}

    update_result = instance_monitor_table.update_item(
        Key=item_key,
        UpdateExpression="SET sync_status = :new_sync_status, last_sync_request_id = :sync_request_id, "
                         "sync_list = :sync_list, last_sync_time = :sync_time, last_heartbeat_time = :heartbeat_time",
        ExpressionAttributeValues=new_values,
    )
    logger.info(f"update_sync_instance_monitor :{update_result}")
    return update_result
|
||
|
||
|
||
def sync_instance_monitor_status(need_save: bool):
    """Record a heartbeat for this instance in the monitor table.

    With ``need_save`` true a fresh item is written (status ``'init'``, empty
    request id); otherwise only ``last_heartbeat_time`` of the existing item
    is updated. Errors are logged and swallowed.
    """
    try:
        logger.info(f"sync_instance_monitor_status {datetime.datetime.now()}")
        if need_save:
            save_sync_instance_monitor('', 'init')
            return

        heartbeat_values = {":heartbeat_time": datetime.datetime.now().isoformat()}
        instance_monitor_table.update_item(
            Key={'endpoint_name': ENDPOINT_NAME,
                 'gen_instance_id': GEN_INSTANCE_ID},
            UpdateExpression="SET last_heartbeat_time = :heartbeat_time",
            ExpressionAttributeValues=heartbeat_values,
        )
    except Exception as monitor_error:
        logger.info(f"sync_instance_monitor_status error :{monitor_error}")
|
||
|
||
|
||
@server.PromptServer.instance.routes.post("/reboot")
async def restart(self):
    """Handle ``POST /reboot``: re-execute the interpreter when a reboot was requested.

    The reboot is skipped, with an explanatory message in the response, when
    a prompt is currently executing, when ``NEED_REBOOT`` is set to something
    other than ``'true'``, or when the global ``reboot`` flag is False.
    Otherwise ``sys.stdout.close_log()`` is attempted (failures are only
    logged) and the process is replaced via ``os.execv``.

    ``self`` is the single argument the route passes in; it is only logged.
    """
    global executing, reboot

    logger.debug(f"start to reboot!!!!!!!! {self}")

    if executing is True:
        logger.info(f"other inference doing cannot reboot!!!!!!!!")
        return ok({"message": "other inference doing cannot reboot"})

    # An unset or empty NEED_REBOOT does not block the reboot by itself.
    env_flag = os.environ.get('NEED_REBOOT')
    if env_flag and env_flag.lower() != 'true':
        logger.info("no need to reboot by os")
        return ok({"message": "no need to reboot by os"})

    if reboot is False:
        logger.info("no need to reboot by global constant")
        return ok({"message": "no need to reboot by constant"})

    logger.debug("rebooting !!!!!!!!")
    try:
        sys.stdout.close_log()
    except Exception as close_error:
        logger.info(f"error reboot!!!!!!!! {close_error}")

    relaunch_argv = [sys.executable] + sys.argv
    return os.execv(sys.executable, relaunch_argv)
|
||
|
||
|
||
# Must be invoked synchronously: the ALREADY_SYNC environment variable is the
# guard used to detect that another sync is still in progress.
@server.PromptServer.instance.routes.post("/sync_instance")
async def sync_instance(request):
    """Handle POST /sync_instance: apply the latest sync record to this instance.

    Flow:
      1. Without a configured ``BUCKET``, or while ``ALREADY_SYNC`` is
         ``'false'`` (a sync is in flight), answer ``"syncing"`` immediately.
      2. Mark the sync as in flight (``ALREADY_SYNC = 'false'``).
      3. Load the last sync record. If there is none, initialise the monitor
         record and answer ``"no sync"``.
      4. If the environment already records this request id and request time,
         only refresh the heartbeat and answer ``"no sync env"``.
      5. Otherwise consult the instance-monitor record:
         * no record: run ``prepare_comfy_env`` and save a new monitor record
           on success, or just refresh the heartbeat on failure;
         * record already marks this request as successfully synced: refresh
           the heartbeat and answer ``"no sync ddb"``;
         * else: run ``prepare_comfy_env`` and store the outcome, appending
           the request id to the record's ``sync_list``.
      6. Answer ``"sync"``.

    Args:
        request: The incoming request object; only used for logging.

    Returns:
        ``ok(...)`` with ``status == "success"`` on every handled path, or
        ``error(...)`` with ``status == "fail"`` if an exception was raised.
    """
    if not BUCKET:
        logger.error("No bucket provided ,wait and try again")
        return ok({"status": "success", "message": "syncing"})

    if os.environ.get('ALREADY_SYNC', '').lower() == 'false':
        logger.error("other process doing ,wait and try again")
        return ok({"status": "success", "message": "syncing"})

    os.environ['ALREADY_SYNC'] = 'false'
    logger.info(f"sync_instance start !! {datetime.datetime.now().isoformat()} {request}")
    try:
        last_sync_record = get_last_ddb_sync_record()
        if not last_sync_record:
            logger.info("no last sync record found do not need sync")
            sync_instance_monitor_status(True)
            return ok({"status": "success", "message": "no sync"})

        # Fast path: this process already applied exactly this sync request.
        if ('request_id' in last_sync_record and last_sync_record['request_id']
                and os.environ.get('LAST_SYNC_REQUEST_ID')
                and os.environ.get('LAST_SYNC_REQUEST_ID') == last_sync_record['request_id']
                and os.environ.get('LAST_SYNC_REQUEST_TIME')
                and os.environ.get('LAST_SYNC_REQUEST_TIME') == str(last_sync_record['request_time'])):
            logger.info("last sync record already sync by os check")
            sync_instance_monitor_status(False)
            return ok({"status": "success", "message": "no sync env"})

        instance_monitor_record = get_latest_ddb_instance_monitor_record()
        if not instance_monitor_record:
            sync_already = await prepare_comfy_env(last_sync_record)
            if sync_already:
                logger.info("should init prepare instance_monitor_record")
                # Only a successful prepare creates the monitor record.
                save_sync_instance_monitor(last_sync_record['request_id'], 'success')
            else:
                sync_instance_monitor_status(False)
        else:
            if ('last_sync_request_id' in instance_monitor_record
                    and instance_monitor_record['last_sync_request_id']
                    and instance_monitor_record['last_sync_request_id'] == last_sync_record['request_id']
                    and instance_monitor_record['sync_status']
                    and instance_monitor_record['sync_status'] == 'success'
                    and os.environ.get('LAST_SYNC_REQUEST_TIME')
                    and os.environ.get('LAST_SYNC_REQUEST_TIME') == str(last_sync_record['request_time'])):
                logger.info("last sync record already sync")
                sync_instance_monitor_status(False)
                return ok({"status": "success", "message": "no sync ddb"})

            sync_already = await prepare_comfy_env(last_sync_record)
            instance_monitor_record['sync_status'] = 'success' if sync_already else 'failed'
            instance_monitor_record['last_sync_request_id'] = last_sync_record['request_id']

            sync_list = instance_monitor_record.get('sync_list') or []
            sync_list.append(last_sync_record['request_id'])
            instance_monitor_record['sync_list'] = sync_list

            logger.info("should update prepare instance_monitor_record")
            update_sync_instance_monitor(instance_monitor_record)

        return ok({"status": "success", "message": "sync"})
    except Exception as e:
        # The message needs a placeholder for the exception; without one the
        # logging module fails to format the record and the error is lost.
        logger.exception("exception occurred: %s", e)
        return error({"status": "fail", "message": "sync"})
    finally:
        # Always release the in-flight guard, including when the handler is
        # cancelled (asyncio.CancelledError is not an Exception subclass on
        # Python 3.8+), so later sync requests are not blocked forever.
        os.environ['ALREADY_SYNC'] = 'true'
|
||
|
||
|
||
def validate_prompt_proxy(func):
    """Return a wrapper around *func* that logs before and after each call.

    The wrapper forwards all positional and keyword arguments unchanged and
    returns whatever *func* returns. If *func* raises, the exception
    propagates and the "end" line is not logged.
    """

    def logged_validate(*call_args, **call_kwargs):
        # Hook point: add proxy logic that must run before validation here.
        logger.info("validate_prompt_proxy start...")
        # Invoke the original function.
        outcome = func(*call_args, **call_kwargs)
        # Hook point: add post-validation actions here.
        logger.info("validate_prompt_proxy end...")
        return outcome

    return logged_validate


# Install the logging wrapper over the validate_prompt function in `execution`.
execution.validate_prompt = validate_prompt_proxy(execution.validate_prompt)
|
||
|
||
|
||
def send_sync_proxy(func):
    """Wrap ``PromptServer.send_sync`` so each message is also pushed to SQS.

    The wrapper first calls the original *func* with the untouched arguments.
    Then, when ``need_sync``, ``QUEUE_URL`` and ``REGION`` are all truthy, it
    forwards ``{'prompt_id', 'event', 'data', 'sid'}`` through
    ``sen_sqs_msg``.

    Args:
        func: The original (unbound) ``send_sync`` function.

    Returns:
        The wrapper, which returns whatever *func* returns.
    """

    def wrapper(*args, **kwargs):
        logger.debug(f"Sending sync request!!!!!!! {args}")
        global need_sync
        global prompt_id
        logger.info(f"send_sync_proxy start... {need_sync},{prompt_id} {args}")
        result = func(*args, **kwargs)
        if need_sync and QUEUE_URL and REGION:
            logger.debug(f"send_sync_proxy params... {QUEUE_URL},{REGION},{need_sync},{prompt_id}")
            # The wrapper replaces a method on the class, so args[0] is the
            # server instance. Values not passed positionally are looked up
            # as keywords so a keyword-style call does not raise IndexError
            # after the message has already been sent.
            # NOTE(review): keyword names assume the wrapped signature is
            # send_sync(self, event, data, sid=None) - confirm.
            event = args[1] if len(args) > 1 else kwargs.get('event')
            data = args[2] if len(args) > 2 else kwargs.get('data')
            sid = args[3] if len(args) > 3 else kwargs.get('sid')
            message_body = {'prompt_id': prompt_id, 'event': event, 'data': data, 'sid': sid}
            message_id = sen_sqs_msg(message_body, prompt_id)
            logger.info(f'send_sync_proxy message_id :{message_id} message_body: {message_body}')
        logger.debug("send_sync_proxy end...")
        return result

    return wrapper


# Route every PromptServer.send_sync call through the SQS-forwarding wrapper.
server.PromptServer.send_sync = send_sync_proxy(server.PromptServer.send_sync)
|
||
|
||
|
||
def get_save_imge_path_proxy(func):
    """Wrap *func* so the returned filename prefix carries the current prompt id.

    *func* must return a 5-tuple
    ``(full_output_folder, filename, counter, subfolder, filename_prefix)``.
    The wrapper returns the same tuple with ``"_<prompt_id>"`` appended to the
    last element; the first four elements are passed through untouched.
    """

    def wrapper(*call_args, **call_kwargs):
        global prompt_id

        logger.info(f"get_save_imge_path_proxy args : {call_args} kwargs : {call_kwargs}")
        out_dir, base_name, index, sub_dir, prefix = func(*call_args, **call_kwargs)

        # Tag the prefix with the module-level prompt id.
        tagged_prefix = "_".join((prefix, str(prompt_id)))
        logger.info(f"get_save_imge_path_proxy filename_prefix new : {tagged_prefix}")
        return out_dir, base_name, index, sub_dir, tagged_prefix

    return wrapper


# Replace folder_paths.get_save_image_path with the prompt-id-tagging wrapper.
folder_paths.get_save_image_path = get_save_imge_path_proxy(folder_paths.get_save_image_path)
|