2246 lines
109 KiB
Python
Executable File
2246 lines
109 KiB
Python
Executable File
import concurrent.futures
|
||
import re
|
||
import signal
|
||
import threading
|
||
|
||
import boto3
|
||
import requests
|
||
from aiohttp import web
|
||
|
||
import folder_paths
|
||
import server
|
||
from execution import PromptExecutor
|
||
import execution
|
||
import comfy
|
||
|
||
from watchdog.observers import Observer
|
||
from watchdog.events import FileSystemEventHandler
|
||
import subprocess
|
||
from dotenv import load_dotenv
|
||
|
||
import fcntl
|
||
import hashlib
|
||
|
||
import base64
|
||
import datetime
|
||
import json
|
||
import logging
|
||
import os
|
||
import sys
|
||
import tarfile
|
||
import time
|
||
import uuid
|
||
import gc
|
||
from dataclasses import dataclass
|
||
from typing import Optional
|
||
|
||
from boto3.dynamodb.conditions import Key
|
||
|
||
# --- Module configuration: environment, sync directories, logging ---

# Env-var name that, when set to the string 'True', disables proxying to AWS.
DISABLE_AWS_PROXY = 'DISABLE_AWS_PROXY'

# De-duplication list of sync messages already forwarded to the local UI.
sync_msg_list = []
# Maps a websocket client_id to the release/workflow name it selected.
client_release_map = {}
lock_status = False

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

is_on_sagemaker = os.getenv('ON_SAGEMAKER') == 'true'
is_on_ec2 = os.getenv('ON_EC2') == 'true'

if is_on_ec2:
    env_path = '/etc/environment'
    if 'ENV_FILE_PATH' in os.environ and os.environ.get('ENV_FILE_PATH'):
        env_path = os.environ.get('ENV_FILE_PATH')
    # BUG FIX: previously '/etc/environment' was always loaded, silently
    # ignoring a configured ENV_FILE_PATH.
    load_dotenv(env_path)
    logger.info(f"env_path{env_path}")

env_keys = ['ENV_FILE_PATH', 'COMFY_INPUT_PATH', 'COMFY_MODEL_PATH', 'COMFY_NODE_PATH', 'COMFY_API_URL',
            'COMFY_API_TOKEN', 'COMFY_ENDPOINT', 'COMFY_NEED_SYNC', 'COMFY_NEED_PREPARE', 'COMFY_BUCKET_NAME',
            'MAX_WAIT_TIME', 'MSG_MAX_WAIT_TIME', 'THREAD_MAX_WAIT_TIME', DISABLE_AWS_PROXY, 'DISABLE_AUTO_SYNC']

# Log the recognized configuration keys that are actually present.
for item in os.environ.keys():
    if item in env_keys:
        logger.info(f'evn key: {item} {os.environ.get(item)}')

# Local ComfyUI directories mirrored to S3: input (DIR3), models (DIR1),
# custom_nodes (DIR2). Overridable via COMFY_*_PATH below.
DIR3 = f"/container/workflows/{os.getenv('WORKFLOW_NAME')}/ComfyUI/input"
DIR1 = f"/container/workflows/{os.getenv('WORKFLOW_NAME')}/ComfyUI/models"
DIR2 = f"/container/workflows/{os.getenv('WORKFLOW_NAME')}/ComfyUI/custom_nodes"

if 'COMFY_INPUT_PATH' in os.environ and os.environ.get('COMFY_INPUT_PATH'):
    DIR3 = os.environ.get('COMFY_INPUT_PATH')
if 'COMFY_MODEL_PATH' in os.environ and os.environ.get('COMFY_MODEL_PATH'):
    DIR1 = os.environ.get('COMFY_MODEL_PATH')
if 'COMFY_NODE_PATH' in os.environ and os.environ.get('COMFY_NODE_PATH'):
    DIR2 = os.environ.get('COMFY_NODE_PATH')

api_url = os.environ.get('COMFY_API_URL')
api_token = os.environ.get('COMFY_API_TOKEN')
# BUG FIX: env values are strings, so a configured 'false' was truthy before;
# parse explicitly (default remains enabled for sync, disabled for prepare).
comfy_need_sync = str(os.environ.get('COMFY_NEED_SYNC', True)).lower() != 'false'
comfy_need_prepare = str(os.environ.get('COMFY_NEED_PREPARE', False)).lower() == 'true'
bucket_name = os.environ.get('COMFY_BUCKET_NAME')
# BUG FIX: these are compared and decremented as numbers later
# (e.g. "i = thread_max_wait_time; while i > 0"), so coerce to int —
# previously a configured value was a str and "str > int" raised TypeError.
thread_max_wait_time = int(os.environ.get('THREAD_MAX_WAIT_TIME', 60))
max_wait_time = int(os.environ.get('MAX_WAIT_TIME', 86400))
msg_max_wait_time = int(os.environ.get('MSG_MAX_WAIT_TIME', 86400))
is_master_process = os.getenv('MASTER_PROCESS') == 'true'
program_name = os.getenv('PROGRAM_NAME')
# File suffixes that never trigger an auto-sync (editor temp/backup files).
no_need_sync_files = ['.autosave', '.cache', '.autosave1', '~', '.swp']

# 'executed' messages retained so results can be re-sent to clients.
need_resend_msg_result = []
PREPARE_ID = 'default'
# 'additional' reuses the fixed PREPARE_ID; any other mode versions by timestamp.
PREPARE_MODE = 'additional'

if not api_url:
    raise ValueError("API_URL environment variables must be set.")

if not api_token:
    raise ValueError("API_TOKEN environment variables must be set.")

# Common headers for every call to the cloud API.
headers = {"x-api-key": api_token, "Content-Type": "application/json", "username": "api"}
|
||
|
||
|
||
def send_msg_to_all_sockets(event: str, msg: dict):
    """Queue (event, msg) for every connected websocket client, thread-safely."""
    prompt_server = server.PromptServer.instance
    for client_id in prompt_server.sockets:
        # put_nowait must run on the server's event loop thread.
        prompt_server.loop.call_soon_threadsafe(
            prompt_server.messages.put_nowait, (event, msg, client_id))
|
||
|
||
def get_endpoint_name_by_workflow_name(name: str, endpoint_type: str = 'async'):
    """Build the cloud endpoint name for a workflow/release name."""
    return "-".join(("comfy", endpoint_type, name))
|
||
|
||
|
||
def save_images_locally(response_json, local_folder):
    """Download every image listed in an execute response into local_folder/<prompt_id>/."""
    try:
        payload = response_json.get("data", {})
        prompt_id = payload.get("prompt_id")
        media_map = payload.get("image_video_data", {})

        if not prompt_id or not media_map:
            logger.info("Missing prompt_id or image_video_data in the response.")
            return

        target_folder = os.path.join(local_folder, prompt_id)
        os.makedirs(target_folder, exist_ok=True)

        for image_name, image_url in media_map.items():
            resp = requests.get(image_url)
            if resp.status_code != 200:
                logger.info(
                    f"Failed to download image '{image_name}' from {image_url}. Status code: {resp.status_code}")
                continue
            image_path = os.path.join(target_folder, image_name)
            with open(image_path, "wb") as image_file:
                image_file.write(resp.content)
            logger.info(f"Image '{image_name}' saved to {image_path}")

    except Exception as e:
        # Best-effort: log and swallow so a download failure never kills the caller.
        logger.info(f"Error saving images locally: {e}")
|
||
|
||
|
||
def calculate_file_hash(file_path):
    """Return the SHA-256 hex digest of the file at file_path."""
    hasher = hashlib.sha256()
    with open(file_path, 'rb') as fh:
        # Stream in 64 KB chunks so large files never load fully into memory.
        for chunk in iter(lambda: fh.read(65536), b''):
            hasher.update(chunk)
    return hasher.hexdigest()
|
||
|
||
|
||
def save_files(prefix, execute, key, target_dir, need_prefix):
    """Download every URL under execute['data'][key] into target_dir.

    With need_prefix set, each file is written twice: once as
    '<prefix>_<name>' and once as plain '<name>' (overwriting any
    existing copy); otherwise only the plain name is written.
    """
    if key not in execute['data']:
        return
    for url in execute['data'][key]:
        loca_file = get_file_name(url)
        response = requests.get(url)
        # Create the destination directory on first use.
        os.makedirs(target_dir, exist_ok=True)
        logger.info(f"Saving file {loca_file} to {target_dir}")
        # Placeholder object the cloud writes when there is no real output.
        if loca_file.endswith("output_images_will_be_put_here"):
            continue
        targets = [f"{target_dir}/{loca_file}"]
        if need_prefix:
            targets.insert(0, f"{target_dir}/{prefix}_{loca_file}")
        for destination in targets:
            with open(destination, 'wb') as f:
                f.write(response.content)
|
||
|
||
|
||
def get_file_name(url: str):
    """Extract the bare file name from a URL, stripping any query string."""
    tail = url.rsplit('/', 1)[-1]
    return tail.partition('?')[0]
|
||
|
||
|
||
def send_service_msg(server_use, msg):
    """Forward one service message dict ('event', 'data', optional 'sid') to the UI."""
    event = msg.get('event')
    data = msg.get('data')
    # Missing 'sid' means broadcast (None), matching dict.get's default.
    server_use.send_sync(event, data, msg.get('sid'))
|
||
|
||
|
||
def handle_sync_messages(server_use, msg_array):
    """Replay batches of cloud-side sync messages to the local UI.

    Skips payloads already seen (sync_msg_list), remembers 'executed'
    batches for re-sending, and returns True once a 'finish' event is seen.
    """
    global sync_msg_list, need_resend_msg_result
    finished = False
    for msg in msg_array:
        for item_msg in msg:
            event = item_msg.get('event')
            data = item_msg.get('data')
            sid = item_msg.get('sid') if 'sid' in item_msg else None
            # De-dup: each payload is forwarded to the UI at most once.
            if data in sync_msg_list:
                continue
            sync_msg_list.append(data)
            if event == 'finish':
                finished = True
            elif event == 'executed':
                need_resend_msg_result.append(msg)
            server_use.send_sync(event, data, sid)
    return finished
|
||
|
||
|
||
# Wrapper installed over PromptExecutor.execute: instead of running the prompt
# locally, POST it to the cloud API, relay progress messages back to the local
# UI, and download the resulting temp/output files.
def execute_proxy(func):
    def wrapper(*args, **kwargs):

        def send_error_msg(executor, prompt_id, msg):
            # Surface an error in the UI via ComfyUI's "execution_error" event.
            mes = {
                "prompt_id": prompt_id,
                "node_id": "",
                "node_type": "on cloud",
                "executed": [],
                "exception_message": msg,
                "exception_type": "",
                "traceback": [],
                "current_inputs": "",
                "current_outputs": "",
            }
            executor.add_message("execution_error", mes, broadcast=True)

        # Escape hatch: run the original local implementation when proxying is off.
        if is_master_process and 'True' == os.environ.get(DISABLE_AWS_PROXY):
            logger.info("disabled aws proxy, use local")
            return func(*args, **kwargs)
        logger.info(f"enable aws proxy, use aws")
        # PromptExecutor.execute positional args: (self, prompt, prompt_id, extra_data, ...)
        executor = args[0]
        server_use = executor.server
        prompt = args[1]
        prompt_id = args[2]
        extra_data = args[3]

        client_id = extra_data['client_id'] if 'client_id' in extra_data else None
        if not client_id:
            send_error_msg(executor, prompt_id,
                           f"Something went wrong when execute,please check your client_id and try again")
            return web.Response()
        global client_release_map
        # Release selected by this client via /map_release, else the process default.
        workflow_name = client_release_map.get(client_id) if client_release_map.get(client_id) else os.getenv('WORKFLOW_NAME')
        if not workflow_name or workflow_name == 'default':
            send_error_msg(executor, prompt_id, f"Please choose a release env before you execute prompt")
            return web.Response()
        # if not is_master_process:
        #     send_error_msg(executor, prompt_id, f"Please choose a release env before you execute prompt")
        #     return web.Response()
        # elif not get_endpoint_name_by_workflow_name(workflow_name):
        #     send_error_msg(executor, prompt_id, f"Please check your endpoint:{get_endpoint_name_by_workflow_name(workflow_name)} before you execute prompt")
        #     return web.Response()
        # else:
        #     comfy_endpoint = get_endpoint_name_by_workflow_name(workflow_name)
        logger.info(f"use endpoint:{get_endpoint_name_by_workflow_name(workflow_name)} workflow:{workflow_name} api: {api_url}to generate")

        payload = {
            "number": str(server.PromptServer.instance.number),
            "prompt": prompt,
            "prompt_id": prompt_id,
            "extra_data": extra_data,
            "endpoint_name": get_endpoint_name_by_workflow_name(workflow_name),
            "need_prepare": comfy_need_prepare,
            "need_sync": comfy_need_sync,
            "multi_async": False,
            "workflow_name": workflow_name,
        }

        def send_post_request(url, params):
            logger.debug(f"sending post request {url} , params {params}")
            get_response = requests.post(url, json=params, headers=headers)
            return get_response

        def send_get_request(url):
            get_response = requests.get(url, headers=headers)
            return get_response

        def check_if_sync_is_already(url):
            # True when the remote endpoint reports a successful 'prepare'.
            get_response = send_get_request(url)
            prepare_response = get_response.json()
            if (prepare_response['statusCode'] == 200 and 'data' in prepare_response and prepare_response['data']
                    and prepare_response['data']['prepareSuccess']):
                logger.info(f"sync available")
                return True
            else:
                logger.info(f"no sync available for {url} response {prepare_response}")
                return False

        logger.info(f"payload is: {payload}")
        # is_synced = check_if_sync_is_already(f"{api_url}/prepare/{get_endpoint_name_by_workflow_name(workflow_name)}")
        # if not is_synced:
        #     logger.debug(f"is_synced is {is_synced} stop cloud prompt")
        #     send_error_msg(executor, prompt_id,
        #                    "Your local environment has not compleated to synchronized on cloud already. Please wait for a moment or click the 'Synchronize' button .")
        #     return

        with concurrent.futures.ThreadPoolExecutor() as executorThread:
            # Fire off the remote execution; optionally poll sync messages in parallel.
            execute_future = executorThread.submit(send_post_request, f"{api_url}/executes", payload)

            save_already = False
            if comfy_need_sync:
                msg_future = executorThread.submit(send_get_request,
                                                   f"{api_url}/sync/{prompt_id}")
                done, _ = concurrent.futures.wait([execute_future, msg_future],
                                                  return_when=concurrent.futures.ALL_COMPLETED)
                already_synced = False
                global sync_msg_list
                # Fresh de-dup list per prompt.
                sync_msg_list = []
                for future in done:
                    if future == msg_future:
                        msg_response = future.result()
                        logger.info(f"get syc msg: {msg_response.json()}")
                        if msg_response.status_code == 200:
                            if 'data' not in msg_response.json() or not msg_response.json().get("data"):
                                logger.error("there is no response from sync msg by thread ")
                                time.sleep(1)
                            else:
                                logger.debug(msg_response.json())
                                already_synced = handle_sync_messages(server_use, msg_response.json().get("data"))
                    elif future == execute_future:
                        execute_resp = future.result()
                        logger.info(f"get execute status: {execute_resp.status_code}")
                        if execute_resp.status_code == 200 or execute_resp.status_code == 201 or execute_resp.status_code == 202:
                            # Poll execute status until done/failed or the budget runs out.
                            i = thread_max_wait_time
                            while i > 0:
                                images_response = send_get_request(f"{api_url}/executes/{prompt_id}")
                                response = images_response.json()
                                logger.info(f"get execute images: {images_response.status_code}")
                                if images_response.status_code == 404:
                                    # Not registered yet on the cloud side — keep waiting.
                                    logger.info("no images found already ,waiting sagemaker thread result .....")
                                    time.sleep(3)
                                    i = i - 2
                                elif response['data']['status'] == 'failed':
                                    logger.error(
                                        f"there is no response on sagemaker from execute thread result !!!!!!!! ")
                                    # no need to send msg anymore
                                    already_synced = True
                                    break
                                elif response['data']['status'] != 'Completed' and response['data'][
                                        'status'] != 'success':
                                    logger.info(
                                        f"no images found already ,waiting sagemaker thread result, current status is {response['data']['status']}")
                                    time.sleep(2)
                                    i = i - 1
                                elif 'data' not in response or not response['data'] or 'status' not in response[
                                        'data'] or not response['data']['status']:
                                    # NOTE(review): earlier branches already index response['data'],
                                    # so this guard seems unreachable in practice — confirm ordering.
                                    logger.error(
                                        f"there is no response from execute thread result !!!!!!!! {response}")
                                    # no need to send msg anymore
                                    already_synced = True
                                    break
                                else:
                                    # Completed/success: download temp and output files locally,
                                    # both to ./temp//./output and the configured ComfyUI dirs.
                                    if ('temp_files' in images_response.json()['data'] and len(
                                            images_response.json()['data']['temp_files']) > 0) or ((
                                            'output_files' in images_response.json()['data'] and len(
                                            images_response.json()['data']['output_files']) > 0)):
                                        logger.info(f"save images to default")
                                        save_files(prompt_id, images_response.json(), 'temp_files', './temp', False)
                                        save_files(prompt_id, images_response.json(), 'output_files', './output',
                                                   True)
                                        output_dir = folder_paths.get_output_directory()
                                        temp_dir = folder_paths.get_temp_directory()
                                        logger.info(f"save images to {output_dir} and {temp_dir}")
                                        save_files(prompt_id, images_response.json(), 'temp_files', temp_dir, False)
                                        save_files(prompt_id, images_response.json(), 'output_files', output_dir,
                                                   True)

                                    else:
                                        send_error_msg(executor, prompt_id,
                                                       "There may be some errors when executing the prompt on the cloud. Please check the SageMaker logs.")
                                        # no need to send msg anymore
                                        already_synced = True
                                    logger.debug(images_response.json())
                                    save_already = True
                                    break
                        else:
                            logger.error(f"get execute error: {execute_resp}")
                            # no need to send msg anymore
                            already_synced = True
                            break
                        logger.debug(execute_resp.json())

                # Keep draining sync messages until 'finish' arrives or we time out.
                m = msg_max_wait_time
                while not already_synced:
                    msg_response = send_get_request(f"{api_url}/sync/{prompt_id}")
                    # logger.info(msg_response.json())
                    if msg_response.status_code == 200:
                        if m <= 0:
                            logger.error("there is no response from sync msg by timeout")
                            already_synced = True
                        elif 'data' not in msg_response.json() or not msg_response.json().get("data"):
                            logger.error("there is no response from sync msg")
                            time.sleep(1)
                            m = m - 1
                        else:
                            logger.debug(msg_response.json())
                            already_synced = handle_sync_messages(server_use, msg_response.json().get("data"))
                            logger.info(f"already_synced is :{already_synced}")
                            time.sleep(1)
                            m = m - 1
            logger.info(f"check if images are already synced {save_already}")

            if not save_already:
                # Fallback path (also the only path when comfy_need_sync is off):
                # block on the execute result and download files once completed.
                logger.info("check if images are not already synced, please wait")
                execute_resp = execute_future.result()
                logger.debug(f"execute result :{execute_resp.json()}")
                if execute_resp.status_code == 200 or execute_resp.status_code == 201 or execute_resp.status_code == 202:
                    i = max_wait_time
                    while i > 0:
                        images_response = send_get_request(f"{api_url}/executes/{prompt_id}")
                        response = images_response.json()
                        logger.debug(response)
                        if images_response.status_code == 404:
                            logger.info(f"{i} no images found already ,waiting sagemaker result .....")
                            i = i - 2
                            time.sleep(3)
                        elif response['data']['status'] == 'failed':
                            logger.error(
                                f"there is no response on sagemaker from execute result !!!!!!!! ")
                            if 'message' in response['data'] and response['data']['message']:
                                send_error_msg(executor, prompt_id,
                                               f"There may be some errors when valid or execute the prompt on the cloud. Please check the SageMaker logs. errors: {response['data']['message']}")
                                break
                            else:
                                logger.error(f"valid error on sagemaker :{response['data']}")
                                send_error_msg(executor, prompt_id,
                                               f"There may be some errors when valid or execute the prompt on the cloud. errors")
                                break
                        elif response['data']['status'] != 'Completed' and response['data']['status'] != 'success':
                            logger.info(
                                f"{i} images not already ,waiting sagemaker result .....{response['data']['status']}")
                            i = i - 1
                            time.sleep(3)
                        elif 'data' not in response or not response['data'] or 'status' not in response[
                                'data'] or not response['data']['status']:
                            # NOTE(review): like the thread branch above, this guard follows
                            # branches that already indexed response['data'] — confirm ordering.
                            logger.info(f"{i} there is no response from sync executes {response}")
                            send_error_msg(executor, prompt_id,
                                           f"There may be some errors when executing the prompt on the cloud. No images or videos generated. {response['message']}")
                            break
                        elif response['data']['status'] == 'Completed' or response['data']['status'] == 'success':
                            if ('temp_files' in images_response.json()['data'] and len(
                                    images_response.json()['data']['temp_files']) > 0) or ((
                                    'output_files' in images_response.json()['data'] and len(
                                    images_response.json()['data']['output_files']) > 0)):
                                save_files(prompt_id, images_response.json(), 'temp_files', './temp', False)
                                save_files(prompt_id, images_response.json(), 'output_files', './output', True)

                                output_dir = folder_paths.get_output_directory()
                                temp_dir = folder_paths.get_temp_directory()
                                logger.info(f"save images to {output_dir} and {temp_dir}")
                                save_files(prompt_id, images_response.json(), 'temp_files', temp_dir, False)
                                save_files(prompt_id, images_response.json(), 'output_files', output_dir,
                                           True)
                                break
                            else:
                                send_error_msg(executor, prompt_id,
                                               "There may be some errors when executing the prompt on the cloud. Please check the SageMaker logs.")
                                break
                        else:
                            send_error_msg(executor, prompt_id,
                                           "You have some errors when execute prompt on cloud . Please check your sagemaker logs.")
                            break
                else:
                    logger.error(f"get execute error: {execute_resp.json()}")
                    send_error_msg(executor, prompt_id, "Please valid your prompt and try again." if not (execute_resp.json() and execute_resp.json().get("message")) else execute_resp.json().get("message"))
            logger.info("execute finished")
            # Redundant with the 'with' block's implicit shutdown, kept as-is.
            executorThread.shutdown()

    return wrapper


# Install the proxy in place of the original local executor.
PromptExecutor.execute = execute_proxy(PromptExecutor.execute)
|
||
|
||
|
||
def send_sync_proxy(func):
    """Decorator that logs every PromptServer.send_sync call before delegating."""

    def logged_send_sync(*args, **kwargs):
        logger.info(f"Sending sync request----- {args}")
        return func(*args, **kwargs)

    return logged_send_sync


# Wrap the server's send_sync so all outgoing UI messages are logged.
server.PromptServer.send_sync = send_sync_proxy(server.PromptServer.send_sync)
|
||
|
||
|
||
def compress_and_upload(comfy_endpoint, folder_path, prepare_version):
    """Tar each first-level subfolder of folder_path and push it to S3 via s5cmd."""
    immediate_subdirs = next(os.walk(folder_path))[1]
    for subdir in immediate_subdirs:
        subdir_path = os.path.join(folder_path, subdir)
        tar_filename = f"{subdir}.tar.gz"
        logger.info(f"Compressing the {tar_filename}")
        with tarfile.open(tar_filename, "w:gz") as tar:
            tar.add(subdir_path, arcname=os.path.basename(subdir_path))
        s5cmd_syn_node_command = f's5cmd --log=error cp {tar_filename} "s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/custom_nodes/"'
        logger.info(s5cmd_syn_node_command)
        os.system(s5cmd_syn_node_command)
        # Remove the local archive once uploaded.
        logger.info(f"rm {tar_filename}")
        os.remove(tar_filename)
|
||
|
||
# for root, dirs, files in os.walk(folder_path):
|
||
# for directory in dirs:
|
||
# dir_path = os.path.join(root, directory)
|
||
# logger.info(f"Compressing the {dir_path}")
|
||
# tar_filename = f"{directory}.tar.gz"
|
||
# tar_filepath = os.path.join(root, tar_filename)
|
||
# with tarfile.open(tar_filepath, "w:gz") as tar:
|
||
# tar.add(dir_path, arcname=os.path.basename(dir_path))
|
||
# s5cmd_syn_node_command = f's5cmd --log=error cp {tar_filepath} "s3://{bucket_name}/comfy/{comfy_endpoint}/{timestamp}/custom_nodes/"'
|
||
# logger.info(s5cmd_syn_node_command)
|
||
# os.system(s5cmd_syn_node_command)
|
||
# logger.info(f"rm {tar_filepath}")
|
||
# os.remove(tar_filepath)
|
||
|
||
|
||
def sync_default_files(comfy_endpoint, prepare_type):
    """Force a full sync of inputs and/or models to S3, then trigger a remote 'prepare'.

    :param comfy_endpoint: target endpoint name (e.g. 'comfy-async-<workflow>').
    :param prepare_type: 'default' (inputs + models), 'inputs', or 'models'.
    :return: stdout of the prepare API call, or None on error.
    """
    try:
        timestamp = str(int(time.time() * 1000))
        # 'additional' mode reuses the fixed PREPARE_ID; otherwise version by timestamp.
        prepare_version = PREPARE_ID if PREPARE_MODE == 'additional' else timestamp
        need_prepare = True
        need_reboot = False
        if prepare_type in ['default', 'inputs']:
            logger.info(f" sync input files")
            s5cmd_syn_input_command = f'aws s3 sync --delete {DIR3}/ "s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/input/"'
            logger.info(f"sync input files start {s5cmd_syn_input_command}")
            os.system(s5cmd_syn_input_command)
        if prepare_type in ['default', 'models']:
            logger.info(f" sync models files")
            s5cmd_syn_model_command = f'aws s3 sync --delete {DIR1}/ "s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/models/"'
            logger.info(f"sync models files start {s5cmd_syn_model_command}")
            os.system(s5cmd_syn_model_command)
        logger.info(f"Files changed in:: {need_prepare} {prepare_type} {DIR2} {DIR1} {DIR3}")

        # BUG FIX: 'api_url + "prepare"' dropped the path separator unless the
        # configured URL happened to end with '/'; normalize for both cases
        # (other call sites use f"{api_url}/executes" etc.).
        url = api_url.rstrip('/') + "/prepare"
        logger.info(f"URL:{url}")
        data = {"endpoint_name": comfy_endpoint, "need_reboot": need_reboot, "prepare_id": prepare_version,
                "prepare_type": prepare_type}
        logger.info(f"prepare params Data: {json.dumps(data, indent=4)}")
        result = subprocess.run(["curl", "--location", "--request", "POST", url, "--header",
                                 f"x-api-key: {api_token}", "--data-raw", json.dumps(data)],
                                capture_output=True, text=True)
        logger.info(result.stdout)

        return result.stdout
    except Exception as e:
        # Best-effort: log (at error level so failures are visible) and return None.
        logger.error(f"sync_files error {e}")
        return None
|
||
|
||
|
||
def sync_files(filepath, is_folder, is_auto):
    # Sync a changed path (custom_nodes / input / models) to S3 and, when
    # anything was uploaded, POST a 'prepare' request so the endpoint reloads.
    # is_auto=True means we were called from a filesystem watcher and should
    # defer while the path is still being written. Returns the prepare call's
    # stdout, or None when nothing was synced.
    comfy_endpoint = os.getenv("COMFY_ENDPOINT")
    try:
        directory = os.path.dirname(filepath)
        logger.info(f"Directory changed in: {directory} {filepath}")
        if not directory:
            logger.info("root path no need to sync files by duplicate opt")
            return None
        timestamp = str(int(time.time() * 1000))
        logger.info(f"Files changed in: {filepath} time is:{timestamp}")
        need_prepare = False
        prepare_type = 'inputs'
        need_reboot = False
        # Skip editor temp/backup files entirely.
        for ignore_item in no_need_sync_files:
            if filepath.endswith(ignore_item):
                logger.info(f"no need to sync files by ignore files {filepath} ends by {ignore_item}")
                return None
        prepare_version = PREPARE_ID if PREPARE_MODE == 'additional' else timestamp
        # Dispatch on which mirrored directory the changed path belongs to.
        if (str(directory).endswith(f"{DIR2}" if DIR2.startswith("/") else f"/{DIR2}")
                or str(filepath) == DIR2 or str(filepath) == f'./{DIR2}' or f"{DIR2}/" in filepath):
            logger.info(f" sync custom nodes files: {filepath}")
            s5cmd_syn_node_command = f's5cmd --log=error sync --delete=true --exclude="*comfy_local_proxy.py" {DIR2}/ "s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/custom_nodes/"'

            # custom_nodes folder changed — if it is still changing, sync later.
            if is_auto and not is_folder_unlocked(directory):
                logger.info("sync custom_nodes files is changing ,waiting.... ")
                return None
            logger.info("sync custom_nodes files start")
            logger.info(s5cmd_syn_node_command)
            os.system(s5cmd_syn_node_command)
            need_prepare = True
            need_reboot = True
            prepare_type = 'nodes'
        elif (str(directory).endswith(f"{DIR3}" if DIR3.startswith("/") else f"/{DIR3}")
                or str(filepath) == DIR3 or str(filepath) == f'./{DIR3}' or f"{DIR3}/" in filepath):
            logger.info(f" sync input files: {filepath}")
            s5cmd_syn_input_command = f'aws s3 sync --delete {DIR3}/ "s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/input/"'

            # Wait until the file/folder has finished being written before syncing.
            if is_auto:
                if bool(is_folder):
                    can_sync = is_folder_unlocked(filepath)
                else:
                    can_sync = is_file_unlocked(filepath)
                if not can_sync:
                    logger.info("sync input files is changing ,waiting.... ")
                    return None
            logger.info("sync input files start")
            logger.info(s5cmd_syn_input_command)
            os.system(s5cmd_syn_input_command)
            need_prepare = True
            prepare_type = 'inputs'
        elif (str(directory).endswith(f"{DIR1}" if DIR1.startswith("/") else f"/{DIR1}")
                or str(filepath) == DIR1 or str(filepath) == f'./{DIR1}' or f"{DIR1}/" in filepath):
            logger.info(f" sync models files: {filepath}")
            s5cmd_syn_model_command = f'aws s3 sync --delete {DIR1}/ "s3://{bucket_name}/comfy/{comfy_endpoint}/{prepare_version}/models/"'

            # Wait until the file/folder has finished being written before syncing.
            if is_auto:
                if bool(is_folder):
                    can_sync = is_folder_unlocked(filepath)
                else:
                    can_sync = is_file_unlocked(filepath)
                if not can_sync:
                    logger.info("sync input models is changing ,waiting.... ")
                    return None

            logger.info("sync models files start")
            logger.info(s5cmd_syn_model_command)
            os.system(s5cmd_syn_model_command)
            need_prepare = True
            prepare_type = 'models'
        timestamp_sync = str(int(time.time() * 1000))
        logger.info(f"Files changed in:: {need_prepare} {str(directory)} {DIR2} {DIR1} {DIR3}, time is:{timestamp_sync}")
        if need_prepare:
            # NOTE(review): api_url + "prepare" yields '<url>prepare' unless
            # COMFY_API_URL ends with '/'; other call sites use f"{api_url}/..."
            # — confirm the expected COMFY_API_URL format.
            url = api_url + "prepare"
            logger.info(f"URL:{url}")
            data = {"endpoint_name": comfy_endpoint, "need_reboot": need_reboot, "prepare_id": prepare_version,
                    "prepare_type": prepare_type}
            logger.info(f"prepare params Data: {json.dumps(data, indent=4)}")
            result = subprocess.run(["curl", "--location", "--request", "POST", url, "--header",
                                     f"x-api-key: {api_token}", "--data-raw", json.dumps(data)],
                                    capture_output=True, text=True)
            logger.info(result.stdout)
            timestamp_prepare = str(int(time.time() * 1000))
            logger.info(f"finish prepare in : {timestamp_prepare}")
            return result.stdout
        return None
    except Exception as e:
        # Best-effort: watcher callbacks must never raise.
        logger.info(f"sync_files error {e}")
        return None
|
||
|
||
|
||
def is_folder_unlocked(directory):
    """Return True when directory shows no filesystem activity for ~1-2 seconds.

    Spins up a temporary watchdog observer on the tree; used to avoid syncing
    a folder that is still being written to.
    """
    event_handler = MyHandlerWithCheck()
    observer = Observer()
    observer.schedule(event_handler, directory, recursive=True)
    observer.start()
    time.sleep(1)
    result = False
    try:
        if event_handler.file_changed:
            logger.info(f"folder {directory} is still changing..")
            # Give it one more second to settle before deciding.
            event_handler.file_changed = False
            time.sleep(1)
            if event_handler.file_changed:
                logger.info(f"folder {directory} is still still changing..")
            else:
                logger.info(f"folder {directory} changing stopped")
                result = True
        else:
            # BUG FIX: message previously said "not stopped" although no
            # change was observed and the folder is considered unlocked.
            logger.info(f"folder {directory} no changes detected")
            result = True
    except (KeyboardInterrupt, Exception) as e:
        logger.info(f"folder {directory} changed exception {e}")
    finally:
        # BUG FIX: always stop AND join the observer so its worker thread is
        # not leaked on every call.
        observer.stop()
        observer.join()
    return result
|
||
|
||
|
||
def is_file_unlocked(file_path):
    """Heuristically check that file_path is no longer being written to.

    Compares size/mtime across a 1-second window, then attempts a
    non-blocking exclusive flock. Returns True when the file looks stable
    and lockable, False otherwise (including on any error).
    """
    try:
        initial_size = os.path.getsize(file_path)
        initial_mtime = os.path.getmtime(file_path)
        time.sleep(1)

        current_size = os.path.getsize(file_path)
        current_mtime = os.path.getmtime(file_path)
        if current_size != initial_size or current_mtime != initial_mtime:
            logger.info(f"unlock file error {file_path} is changing")
            return False

        with open(file_path, 'r') as f:
            # Raises OSError (EWOULDBLOCK) if another process holds the lock.
            fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
            return True
    except Exception as e:
        # IOError/OSError are Exception subclasses; the old
        # (IOError, OSError, Exception) tuple was redundant.
        logger.info(f"unlock file error {file_path} is writing")
        logger.error(e)
        return False
|
||
|
||
|
||
class MyHandlerWithCheck(FileSystemEventHandler):
    """Watchdog handler that records whether any filesystem event occurred."""

    def __init__(self):
        # Set to True by any create/modify/delete event under the watched path.
        self.file_changed = False

    def _mark_changed(self, event):
        # Shared reaction for all event types.
        logger.info(f"custom_node folder is changing {event.src_path}")
        self.file_changed = True

    def on_modified(self, event):
        self._mark_changed(event)

    def on_deleted(self, event):
        self._mark_changed(event)

    def on_created(self, event):
        self._mark_changed(event)
|
||
|
||
|
||
class MyHandlerWithSync(FileSystemEventHandler):
    """Watchdog handler that triggers an S3 sync on every filesystem event."""

    def _log_and_sync(self, action, event):
        # Shared reaction: log the event, then sync the changed path.
        logger.info(f"{datetime.datetime.now()} files {action} ,start to sync {event}")
        sync_files(event.src_path, event.is_directory, True)

    def on_modified(self, event):
        self._log_and_sync("modified", event)

    def on_created(self, event):
        self._log_and_sync("added", event)

    def on_deleted(self, event):
        self._log_and_sync("deleted", event)
|
||
|
||
|
||
stop_event = threading.Event()
|
||
|
||
|
||
def check_and_sync():
    """Watch DIR3 (the workflow's ComfyUI input dir) and sync every change.

    Runs in a background thread started at import time.  Exits when
    stop_event is set (see signal_handler) or on KeyboardInterrupt; the
    observer is always stopped and joined on the way out.
    """
    logger.info("check_and_sync start")
    event_handler = MyHandlerWithSync()
    observer = Observer()
    try:
        observer.schedule(event_handler, DIR3, recursive=True)
        observer.start()
        # Fix: the original looped `while True`, so signal_handler's
        # stop_event.set() could never actually stop this thread.
        while not stop_event.is_set():
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("sync Shutting down please restart ComfyUI")
    finally:
        # Always tear the observer down, whatever ended the loop.
        observer.stop()
        observer.join()
|
||
|
||
|
||
def signal_handler(sig, frame):
    # SIGINT/SIGTERM hook: log once and flag background threads (via
    # stop_event) that they should shut down.
    logger.info("Received termination signal. Exiting...")
    stop_event.set()
|
||
|
||
|
||
# When auto sync is enabled (DISABLE_AUTO_SYNC=false), install the signal
# hooks and start the filesystem-watcher thread at import time.
if os.environ.get('DISABLE_AUTO_SYNC') == 'false':
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Background watcher; syncs DIR3 changes as they happen.
    check_sync_thread = threading.Thread(target=check_and_sync)
    check_sync_thread.start()
|
||
|
||
|
||
@server.PromptServer.instance.routes.post('/map_release')
async def map_release(request):
    """Record which release version a client uses and point COMFY_ENDPOINT at it."""
    logger.info(f"start to map_release {request}")
    json_data = await request.json()

    client_id = json_data.get('clientId') if json_data else None
    release_version = json_data.get('releaseVersion') if json_data else None
    if not client_id or not release_version:
        return web.Response(status=500, content_type='application/json', body=json.dumps({"result": False}))

    global client_release_map
    client_release_map[client_id] = release_version
    # NOTE: keep this assignment here — the original marks it "don't move,
    # used for sync atomicity".
    os.environ['COMFY_ENDPOINT'] = get_endpoint_name_by_workflow_name(release_version)

    logger.info(f"client_release_map :{client_release_map}")
    return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/reboot")
async def reboot(self):
    """Restart the whole service via supervisord; re-exec this process as a fallback.

    Refused while the action lock is held or on non-master instances.

    Fixes vs the original: the handler was named `restart`, colliding with a
    second `restart` handler later in this module (silently rebinding the
    module-level name), and its local variable `server` shadowed the imported
    `server` module within the function body.
    """
    if is_action_lock():
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(
                                {"result": False, "message": "action is not allowed during workflow release/restore"}))

    if not is_master_process:
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": False, "message": "only master can restart"}))
    logger.info(f"start to reboot {self}")
    try:
        from xmlrpc.client import ServerProxy
        # `rpc` (was `server`): avoid shadowing the imported `server` module.
        rpc = ServerProxy('http://localhost:9001/RPC2')
        rpc.supervisor.restart()
        return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
    except Exception as e:
        logger.info(f"error reboot {e}")
    # supervisord unreachable: re-exec the current process as a last resort.
    return os.execv(sys.executable, [sys.executable] + sys.argv)
|
||
|
||
|
||
@server.PromptServer.instance.routes.post("/check_prepare")
async def check_prepare(request):
    """Ask the middleware API whether endpoint prepare finished for the current workflow.

    Returns {"result": true} when /prepare/<endpoint> reports prepareSuccess,
    otherwise a 500 with {"result": false}.

    Fix vs the original: the exception handler was copy-pasted from /reboot and
    re-exec'd the entire process (os.execv) on any error — a failed status
    check must not restart comfy; it now returns a 500 response instead.
    """
    logger.info(f"start to check_prepare {request}")
    try:
        json_data = await request.json()
        workflow_name = os.getenv('WORKFLOW_NAME')
        comfy_endpoint = get_endpoint_name_by_workflow_name(workflow_name)
        get_response = requests.get(f"{api_url}/prepare/{comfy_endpoint}", headers=headers)
        response = get_response.json()
        logger.info(f"check sync response is {response}")
        if get_response.status_code == 200 and response['data']['prepareSuccess']:
            return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
        logger.info(f"check sync response is {response} {response['data']['prepareSuccess']}")
        return web.Response(status=500, content_type='application/json', body=json.dumps({"result": False}))
    except Exception as e:
        logger.info(f"error check_prepare {e}")
        return web.Response(status=500, content_type='application/json', body=json.dumps({"result": False}))
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/gc")
async def gc(self):
    """Reset the prompt executor, unload models and empty the CUDA cache.

    Fixes vs the original:
      * the handler's own name rebinds the module-level `gc`, so the body's
        `gc.collect()` resolved to this function and raised AttributeError —
        the module is now imported locally under another name;
      * on success the original fell through to os.execv and restarted the
        whole process every time; it now returns a success response;
      * dropped the unused local `gc_triggered`.
    """
    import gc as gc_module  # this route handler shadows the global `gc` module

    if is_action_lock():
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(
                                {"result": False, "message": "action is not allowed during workflow release/restore"}))

    logger.info(f"start to gc {self}")
    try:
        logger.info(f"gc start: {time.time()}")
        server_instance = server.PromptServer.instance
        executor = execution.PromptExecutor(server_instance)
        executor.reset()
        comfy.model_management.cleanup_models()
        gc_module.collect()
        comfy.model_management.soft_empty_cache()
        logger.info(f"gc end: {time.time()}")
        return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
    except Exception as e:
        logger.info(f"error gc {e}")
    # Cleanup failed: fall back to re-exec'ing the process.
    return os.execv(sys.executable, [sys.executable] + sys.argv)
|
||
|
||
|
||
def is_action_lock():
    """Return True while a workflow release/restore is in progress.

    Fast path: the in-process `lock_status` flag.  Slow path: the lock file
    written by action_lock(); the lock is considered held while the workflow
    named in the file does not yet exist on the API side.
    """
    global lock_status
    if lock_status:
        return True
    lock_file = '/container/sync_lock'  # plain string; the f-prefix was spurious
    if os.path.exists(lock_file):
        with open(lock_file, 'r') as f:
            content = f.read()
            if content and not check_workflow_exists(content):
                return True
    return False
|
||
|
||
|
||
def action_lock(name: str):
    """Mark a release/restore as in progress.

    Sets the in-process flag, tells every websocket client to lock its UI,
    and persists the workflow name so other processes (and restarts) see
    the lock via is_action_lock().
    """
    global lock_status
    lock_status = True
    send_msg_to_all_sockets("ui_lock", {"lock": True})
    lock_file = '/container/sync_lock'  # plain string; the f-prefix was spurious
    with open(lock_file, 'w') as f:
        f.write(name)
|
||
|
||
|
||
def action_unlock():
    """Clear the release/restore lock: reset the flag, unlock every client's
    UI, and truncate the on-disk lock file."""
    global lock_status
    lock_status = False
    send_msg_to_all_sockets("ui_lock", {"lock": False})
    lock_file = '/container/sync_lock'  # plain string; the f-prefix was spurious
    with open(lock_file, 'w') as f:
        f.write("")
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/restart")
async def restart(self):
    """Restart the service unless a release/restore currently holds the action lock."""
    if not is_action_lock():
        return restart_response()
    body = json.dumps(
        {"result": False, "message": "action is not allowed during workflow release/restore"})
    return web.Response(status=200, content_type='application/json', body=body)
|
||
|
||
# only for restart comfy not docker
|
||
@server.PromptServer.instance.routes.get("/restart_comfy")
|
||
async def restart_comfy(self):
|
||
if is_action_lock():
|
||
return web.Response(status=200, content_type='application/json',
|
||
body=json.dumps(
|
||
{"result": False,
|
||
"message": "action is not allowed during workflow release/restore"}))
|
||
|
||
thread = threading.Thread(target=restart_comfy_commands)
|
||
thread.start()
|
||
return web.Response(status=200, content_type='application/json',
|
||
body=json.dumps({"result": True, "message": "comfy will be restart in 5 seconds, "
|
||
"it's may take a few seconds"}))
|
||
|
||
|
||
@server.PromptServer.instance.routes.post("/sync_env")
async def sync_env(request):
    """Kick off a background sync of default files for the requested workflow.

    JSON body may carry `prepare_type` (default 'inputs') and `workflow_name`
    (default: the WORKFLOW_NAME env var).
    """
    logger.info(f"start to sync_env {request}")
    try:
        json_data = (await request.json()) or {}
        prepare_type = json_data.get('prepare_type', 'inputs')
        workflow_name = json_data.get('workflow_name', os.getenv('WORKFLOW_NAME'))
        comfy_endpoint = get_endpoint_name_by_workflow_name(workflow_name)
        threading.Thread(target=sync_default_files, args=(comfy_endpoint, prepare_type)).start()
        return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
    except Exception as e:
        logger.info(f"error sync_env {e}")
        return web.Response(status=500, content_type='application/json', body=json.dumps({"result": False}))
|
||
|
||
|
||
@server.PromptServer.instance.routes.post("/change_env")
async def change_env(request):
    """Flip the DISABLE_AWS_PROXY environment flag from the UI."""
    logger.info(f"start to change_env {request}")
    json_data = await request.json()
    if json_data.get(DISABLE_AWS_PROXY) is not None:
        new_value = str(json_data[DISABLE_AWS_PROXY])
        logger.info(
            f"origin evn key DISABLE_AWS_PROXY is :{os.environ.get(DISABLE_AWS_PROXY)} {new_value}")
        os.environ[DISABLE_AWS_PROXY] = new_value
        logger.info(f"now evn key DISABLE_AWS_PROXY is :{os.environ.get(DISABLE_AWS_PROXY)}")
    return web.Response(status=200, content_type='application/json', body=json.dumps({"result": True}))
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/get_env")
async def get_env(request):
    """Return the current DISABLE_AWS_PROXY setting (defaults to 'True')."""
    payload = {"env": os.environ.get(DISABLE_AWS_PROXY, 'True')}
    return web.Response(status=200, content_type='application/json', body=json.dumps(payload))
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/get_env_new/{id}")
async def get_env_new(request):
    """Return the value of the environment variable named by the URL `id` segment."""
    logger.info(f"start to get_env_new {request}")
    env_key = request.match_info.get("id", None)
    logger.info("env_key is :" + str(env_key))
    payload = {"env": os.getenv(env_key)}
    return web.Response(status=200, content_type='application/json', body=json.dumps(payload))
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/check_is_master")
async def check_is_master(request):
    """Report whether this process is the master instance."""
    body = json.dumps({"master": is_master_process})
    return web.Response(status=200, content_type='application/json', body=body)
|
||
|
||
|
||
def get_cloud_workflows(workflow_name):
    """Fetch the released workflows from the API and return the payload_json
    of the one named *workflow_name*, or None on miss/error."""
    response = requests.get(f"{api_url}/workflows", headers=headers, params={"limit": 1000})
    if response.status_code != 200:
        return None

    data = response.json()['data']
    workflows = data.get('workflows') if data else None
    if not workflows:
        return None

    # First (and presumably only) workflow with a matching name.
    return next((w['payload_json'] for w in workflows if w['name'] == workflow_name), None)
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/get_env_template/{id}")
async def get_env_template(request):
    """Return the stored prompt template for a workflow.

    The URL `id` segment overrides WORKFLOW_NAME; 'default' has no cloud
    template and yields a 500, as does a workflow the API does not know.
    """
    logger.info(f"start to get_env_template {request}")
    template_id = request.match_info.get("id", None)
    logger.info("template_id is :" + str(template_id))
    workflow_name = template_id or os.getenv('WORKFLOW_NAME')
    if workflow_name == 'default':
        logger.info(f"workflow_name is {workflow_name}")
        return web.Response(status=500, content_type='application/json', body=None)

    prompt_json = get_cloud_workflows(workflow_name)
    logger.debug(f"workflow_name is {workflow_name} and prompt_json is: {prompt_json}")
    if not prompt_json:
        logger.info(f"get_cloud_workflows none")
        return web.Response(status=500, content_type='application/json', body=None)

    return web.Response(status=200, content_type='application/json', body=json.dumps(prompt_json))
|
||
|
||
|
||
def get_directory_size(directory):
    """Total size in bytes of all regular files under *directory*, recursively.

    Symbolic links are skipped; a missing directory yields 0 (os.walk simply
    produces nothing).
    """
    total = 0
    for root, _dirs, names in os.walk(directory):
        paths = (os.path.join(root, n) for n in names)
        total += sum(os.path.getsize(p) for p in paths if not os.path.islink(p))
    return total
|
||
|
||
|
||
def dir_size(source_path: str):
    """Directory size in GiB, rounded to two decimals, returned as a string."""
    gib = get_directory_size(source_path) / (1024 ** 3)
    return str(round(gib, 2))
|
||
|
||
|
||
def async_release_workflow(workflow_name, payload_json):
    """Background job: snapshot the current workflow dir to S3 and register
    the release with the middleware API.

    Runs in its own thread (see the /workflows POST handler), so exceptions
    here never reach the route's except block.  Fix vs the original: the
    action lock is now released in a `finally` — previously any subprocess or
    API failure left the UI locked forever.
    """
    start_time = time.time()
    action_lock(workflow_name)
    try:
        base_image = os.getenv('BASE_IMAGE')
        subprocess.check_output(f"echo {workflow_name} > /container/image_target_name", shell=True)
        subprocess.check_output(f"echo {base_image} > /container/image_base", shell=True)

        cur_workflow_name = os.getenv('WORKFLOW_NAME')
        source_path = f"/container/workflows/{cur_workflow_name}"
        print(f"source_path is {source_path}")

        # Mirror the workflow dir to S3, excluding caches, logs and outputs.
        s5cmd_sync_command = (f's5cmd sync '
                              f'--delete=true '
                              f'--exclude="*comfy.tar" '
                              f'--exclude="*.log" '
                              f'--exclude="*__pycache__*" '
                              f'--exclude="*/ComfyUI/output/*" '
                              f'--exclude="*/custom_nodes/ComfyUI-Manager/*" '
                              f'"{source_path}/*" '
                              f'"s3://{bucket_name}/comfy/workflows/{workflow_name}/"')

        s5cmd_lock_command = (f'echo "lock" > lock && '
                              f's5cmd sync lock s3://{bucket_name}/comfy/workflows/{workflow_name}/lock')

        logger.info(f"sync workflows files start {s5cmd_sync_command}")

        subprocess.check_output(s5cmd_sync_command, shell=True)
        subprocess.check_output(s5cmd_lock_command, shell=True)

        cost_time = time.time() - start_time
        image_hash = os.getenv('IMAGE_HASH')
        image_uri = f"{image_hash}:{workflow_name}"

        if isinstance(payload_json, dict):
            payload_json = json.dumps(payload_json)

        data = {
            "payload_json": payload_json,
            "image_uri": image_uri,
            "name": workflow_name,
            "size": dir_size(source_path),
        }
        get_response = requests.post(f"{api_url}/workflows", headers=headers, data=json.dumps(data))
        logger.info(f"release workflow response is {get_response.json()}")
        print(f"release workflow cost time is {cost_time}")
    finally:
        # Always release the UI lock, even if the S3 sync or API call failed.
        action_unlock()
|
||
|
||
|
||
@server.PromptServer.instance.routes.get("/lock")
async def get_lock_status(request):
    """Expose the current action-lock state to the UI."""
    payload = {"result": True, "lock": is_action_lock()}
    return web.Response(status=200, content_type='application/json', body=json.dumps(payload))
|
||
|
||
|
||
def async_release_env(workflow_name, payload_json, init_count: int, instance_type, auto_scale: bool, min_count: int, max_count: int):
    """Background job: snapshot the current workflow dir to S3, register the
    release, then request an async endpoint for it.

    Runs in its own thread (see the /release handler), so exceptions here
    never reach the route's except block.  Fix vs the original: the action
    lock is now released in a `finally` — previously any subprocess or API
    failure left the UI locked forever.
    """
    start_time = time.time()
    action_lock(workflow_name)
    try:
        base_image = os.getenv('BASE_IMAGE')
        subprocess.check_output(f"echo {workflow_name} > /container/image_target_name", shell=True)
        subprocess.check_output(f"echo {base_image} > /container/image_base", shell=True)

        cur_workflow_name = os.getenv('WORKFLOW_NAME')
        source_path = f"/container/workflows/{cur_workflow_name}"
        print(f"source_path is {source_path}")

        # Mirror the workflow dir to S3, excluding caches, logs and outputs.
        s5cmd_sync_command = (f's5cmd sync '
                              f'--delete=true '
                              f'--exclude="*comfy.tar" '
                              f'--exclude="*.log" '
                              f'--exclude="*__pycache__*" '
                              f'--exclude="*/ComfyUI/output/*" '
                              f'--exclude="*/custom_nodes/ComfyUI-Manager/*" '
                              f'"{source_path}/*" '
                              f'"s3://{bucket_name}/comfy/workflows/{workflow_name}/"')

        s5cmd_lock_command = (f'echo "lock" > lock && '
                              f's5cmd sync lock s3://{bucket_name}/comfy/workflows/{workflow_name}/lock')

        logger.info(f"sync workflows files start {s5cmd_sync_command}")

        subprocess.check_output(s5cmd_sync_command, shell=True)
        subprocess.check_output(s5cmd_lock_command, shell=True)

        cost_time = time.time() - start_time
        image_hash = os.getenv('IMAGE_HASH')
        image_uri = f"{image_hash}:{workflow_name}"

        if isinstance(payload_json, dict):
            payload_json = json.dumps(payload_json)

        data = {
            "payload_json": payload_json,
            "image_uri": image_uri,
            "name": workflow_name,
            "size": dir_size(source_path),
        }
        get_response = requests.post(f"{api_url}/workflows", headers=headers, data=json.dumps(data))
        logger.info(f"release workflow response is {get_response.json()}")

        # TODO check the workflow response before creating the endpoint
        endpoint_data = {
            'workflow_name': workflow_name,
            'endpoint_name': '',
            'service_type': 'comfy',
            'endpoint_type': 'Async',
            'instance_type': instance_type,
            'initial_instance_count': init_count,
            'min_instance_number': min_count,
            'max_instance_number': max_count,
            'autoscaling_enabled': auto_scale,
            'assign_to_roles': ['ec2'],
        }
        endpoint_response = requests.post(f"{api_url}/endpoints", headers=headers, data=json.dumps(endpoint_data))
        logger.info(f"release env endpoint response is {endpoint_response}")
        print(f"release workflow cost time is {cost_time}")
    finally:
        # Always release the UI lock, even on failure.
        action_unlock()
|
||
|
||
|
||
@server.PromptServer.instance.routes.post("/release")
async def release_env(request):
    """Validate a release request and run async_release_env in a background thread.

    Expects JSON: name (required, ^[A-Za-z][A-Za-z0-9_]*$), optional
    payload_json, initCount, instanceType, autoScale, and — when autoScale is
    truthy — minCount / maxCount.  Count fields are expected as digit strings
    (str.isdigit is called on them).
    """
    if is_action_lock():
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(
                                {"result": False,
                                 "message": "action is not allowed during workflow release/restore"}))

    if not is_master_process:
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": False, "message": "only master can release workflow"}))

    logger.info(f"start to release workflow {request}")
    try:
        json_data = await request.json()
        if 'name' not in json_data or not json_data['name']:
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"name is required"}))

        workflow_name = json_data['name']
        # 'default' and 'local' are reserved workflow names.
        if workflow_name == 'default' or workflow_name == 'local':
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"{workflow_name} is not allowed"}))

        if not re.match(r'^[A-Za-z][A-Za-z0-9_]*$', workflow_name):
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"{workflow_name} is invalid name"}))

        payload_json = ''

        if 'payload_json' in json_data:
            payload_json = json_data['payload_json']

        if check_workflow_exists(workflow_name):
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"{workflow_name} already exists"}))

        # NOTE(review): counts are validated as digit *strings*; a numeric
        # JSON value would raise AttributeError on .isdigit() and fall into
        # the generic 500 handler below — confirm the client always sends
        # strings.
        if ('initCount' not in json_data or not json_data['initCount']
                or not json_data['initCount'].isdigit() or int(json_data['initCount']) <= 0):
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"initCount is required"}))
        # NOTE(review): a falsy autoScale (e.g. false) is rejected here, so
        # the `if autoScale` branch below is effectively always taken.
        if ('autoScale' not in json_data or not json_data['autoScale']):
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"autoScale is required"}))
        if 'autoScale' in json_data and json_data['autoScale']:
            if ('minCount' not in json_data or not json_data['minCount']
                    or not json_data['minCount'].isdigit() or int(json_data['minCount']) <= 0):
                return web.Response(status=200, content_type='application/json',
                                    body=json.dumps({"result": False, "message": f"minCount is required"}))
            if ('maxCount' not in json_data or not json_data['maxCount']
                    or not json_data['maxCount'].isdigit() or int(json_data['maxCount']) <= 0):
                return web.Response(status=200, content_type='application/json',
                                    body=json.dumps({"result": False, "message": f"maxCount is required"}))

        # Heavy S3 sync + API registration runs off the event loop.
        # NOTE(review): instanceType is read unvalidated — a missing key
        # raises KeyError and lands in the 500 handler.
        thread = threading.Thread(target=async_release_env, args=(workflow_name, payload_json, int(json_data['initCount']), json_data['instanceType'], bool(json_data['autoScale']), int(json_data['minCount']), int(json_data['maxCount'])))
        thread.start()

        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "message": "Pending to release workflow, "
                                                                        "it's may take a few minutes"}))
    except Exception as e:
        logger.info(e)
        action_unlock()
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": 'Release workflow failed'}))
|
||
|
||
@server.PromptServer.instance.routes.post("/workflows")
async def release_workflow(request):
    """Validate a workflow-release request and run async_release_workflow in the background."""

    def _reject(message):
        # All validation failures reply 200 with result=False, per the UI contract.
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": False, "message": message}))

    if is_action_lock():
        return _reject("action is not allowed during workflow release/restore")

    if not is_master_process:
        return _reject("only master can release workflow")

    logger.info(f"start to release workflow {request}")
    try:
        json_data = await request.json()
        workflow_name = json_data.get('name')
        if not workflow_name:
            return _reject("name is required")
        if workflow_name in ('default', 'local'):
            return _reject(f"{workflow_name} is not allowed")
        if not re.match(r'^[A-Za-z][A-Za-z0-9_]*$', workflow_name):
            return _reject(f"{workflow_name} is invalid name")

        payload_json = json_data.get('payload_json', '')

        if check_workflow_exists(workflow_name):
            return _reject(f"{workflow_name} already exists")

        # Heavy S3 sync + API registration runs off the event loop.
        threading.Thread(target=async_release_workflow, args=(workflow_name, payload_json)).start()

        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "message": "Pending to release workflow, "
                                                                        "it's may take a few minutes"}))
    except Exception as e:
        logger.info(e)
        action_unlock()
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": 'Release workflow failed'}))
|
||
|
||
@server.PromptServer.instance.routes.delete("/workflows")
async def delete_workflow(request):
    """Delete a released workflow by name.

    Refuses when: the action lock is held, this is not the master, the name
    is reserved ('default'/'local'), it is the workflow this instance runs,
    or any local comfy port file says another comfy process is using it.
    On success the API is asked to delete it and the local copy is removed.
    """
    if is_action_lock():
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(
                                {"result": False, "message": "action is not allowed during workflow release/restore"}))

    if not is_master_process:
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": False, "message": "only master can delete workflows"}))

    logger.info(f"start to delete workflows {request}")
    try:
        json_data = await request.json()
        if 'name' not in json_data or not json_data['name']:
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"name is required"}))
        name = json_data['name']

        # Reserved names can never be deleted.
        if name == 'default' or name == 'local':
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"{name} is not allowed"}))

        if os.getenv('WORKFLOW_NAME') == name:
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": "can not delete current workflow"}))

        # Each running comfy instance records its workflow in
        # /container/comfy_<port>; scan ports 10000-10029 to be sure no
        # instance on this host is still using the workflow.
        i = 0
        start_n = 10000
        while i < 30:
            port = start_n + i
            file_path = f"/container/comfy_{port}"
            i = i + 1
            if os.path.exists(file_path):
                with open(file_path, 'r') as f:
                    content = f.read()
                    content = content.strip()
                    if content == name:
                        return web.Response(status=200, content_type='application/json',
                                            body=json.dumps({"result": False,
                                                             "message": f"can not delete workflow "
                                                                        f"because it is in use by {port}"}))

        data = {
            "workflow_name_list": [name],
        }
        response = requests.delete(f"{api_url}/workflows", headers=headers, data=json.dumps(data))
        resp = response.json()
        # The API acknowledges deletion with 202 Accepted.
        if response.status_code != 202:
            return web.Response(status=200,
                                content_type='application/json',
                                body=json.dumps({"result": False, "message": resp['message']}))

        # Remove the local working copy as well.
        os.system(f"rm -rf /container/workflows/{name}")

        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "message": "Workflow will be deleted soon"}))
    except Exception as e:
        logger.info(e)
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": 'Delete workflow failed'}))
|
||
|
||
@server.PromptServer.instance.routes.put("/workflows")
async def switch_workflow(request):
    """Switch this instance to another released workflow.

    Writes the target name to WORKFLOW_NAME_FILE and kills comfy shortly
    after; the supervisor restarts it under the new workflow.

    Fix vs the original: it contained a second, unreachable is_action_lock()
    guard immediately after the first (the first already returns), which has
    been removed as dead code.
    """
    if is_action_lock():
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(
                                {"result": False, "message": "action is not allowed during workflow release/restore"}))

    # A concurrent switch holds this file lock.
    if os.path.exists("/container/s5cmd_lock"):
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(
                                {"result": False, "message": "switch is not allowed during other's switch, "
                                                             "please try again later"}))

    try:
        json_data = await request.json()
        if 'name' not in json_data or not json_data['name']:
            raise ValueError("name is required")

        workflow_name = json_data['name']

        if workflow_name == os.getenv('WORKFLOW_NAME'):
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": "workflow is already in use"}))

        if workflow_name == 'default' and not is_master_process:
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": "slave can not use default workflow "
                                                                             "after initial"}))

        # 'default' lives on disk only; everything else must exist on the API side.
        if workflow_name != 'default':
            if not check_workflow_exists(workflow_name):
                return web.Response(status=200, content_type='application/json',
                                    body=json.dumps({"result": False, "message": f"{workflow_name} not exists"}))

        name_file = os.getenv('WORKFLOW_NAME_FILE')

        # NOTE: keep this assignment before the name-file write — the original
        # marks it "don't move, used for sync atomicity".
        os.environ['COMFY_ENDPOINT'] = get_endpoint_name_by_workflow_name(workflow_name)

        subprocess.check_output(f"echo {workflow_name} > {name_file}", shell=True)

        # Kill comfy after a short delay so this response can be delivered.
        thread = threading.Thread(target=kill_after_seconds)
        thread.start()
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "message": "Comfy will be switch in 2 seconds, "
                                                                        "it's may take a few minutes"}))
    except Exception as e:
        logger.info(e)
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": 'Switch workflow failed'}))
|
||
|
||
@server.PromptServer.instance.routes.get("/workflows")
async def get_workflows(request):
    """List released workflows — plus the on-disk 'default' entry for the
    master — marking which one this instance currently runs.

    Fix vs the original: the accumulator was named `list`, shadowing the
    builtin; renamed to `workflow_list`.
    """
    try:
        workflow_name = os.getenv('WORKFLOW_NAME')

        response = requests.get(f"{api_url}/workflows", headers=headers, params={"limit": 1000})
        if response.status_code != 200:
            return web.Response(status=500, content_type='application/json',
                                body=json.dumps({"result": False, "message": 'Get workflows failed'}))

        data = response.json()['data']
        workflows = data['workflows']

        workflow_list = []

        if is_master_process:
            # Only the master exposes the local default workflow.
            workflow_list.append({
                "name": 'default',
                "size": dir_size(f"/container/workflows/default"),
                "status": 'Enabled',
                "payload_json": '',
                "in_use": 'default' == workflow_name
            })

        for workflow in workflows:
            workflow_list.append({
                "name": workflow['name'],
                "size": workflow['size'],
                "status": workflow['status'],
                "payload_json": workflow['payload_json'],
                "in_use": workflow['name'] == workflow_name
            })

        data['workflows'] = workflow_list

        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "data": data}))
    except Exception as e:
        logger.info(e)
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": 'list workflows failed'}))
|
||
|
||
@server.PromptServer.instance.routes.get("/schemas")
async def get_schemas(request):
    """List schemas from the middleware API with limit/exclusive_start_key paging.

    Non-master instances only see schemas bound to a workflow.

    Fix vs the original: the accumulator was named `list`, shadowing the
    builtin; renamed to `schema_list`.  Query parsing now uses `.get`.
    """
    try:
        limit = 10
        if request.query.get('limit'):
            limit = int(request.query['limit'])

        # requests drops None-valued params, matching the original behavior.
        exclusive_start_key = request.query.get('exclusive_start_key') or None

        params = {
            "limit": limit,
            "exclusive_start_key": exclusive_start_key,
        }

        response = requests.get(f"{api_url}/schemas", headers=headers, params=params)

        if response.status_code != 200:
            resp = response.json()
            return web.Response(status=500, content_type='application/json',
                                body=json.dumps({"result": False, "message": f'List schemas failed: {resp["message"]}'}))

        data = response.json()['data']
        schemas = data['schemas']

        schema_list = []

        for schema in schemas:
            # Slaves only see workflow-bound schemas.
            if not is_master_process and not schema['workflow']:
                continue

            schema_list.append({
                "name": schema['name'],
                "workflow": schema['workflow'],
                "payload": schema['payload'],
                "create_time": schema['create_time'],
            })

        data['schemas'] = schema_list

        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "data": data}))
    except Exception as e:
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": f'List schemas failed: {e}'}))
|
||
|
||
@server.PromptServer.instance.routes.post("/schemas")
async def create_schema(request):
    """Create a schema via the middleware API (master only)."""
    if not is_master_process:
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": False, "message": "only master can create schema"}))

    try:
        json_data = await request.json()
        name = json_data.get('name')
        if not name:
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"name is required"}))

        payload = json_data.get('payload')
        if not payload:
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"payload is required"}))

        # Empty workflow binding means the schema is global (master-only view).
        workflow_name = json_data.get('workflow') or ''

        data = {
            "payload": payload,
            "name": name,
            "workflow": workflow_name,
        }
        post_response = requests.post(f"{api_url}/schemas", headers=headers, data=json.dumps(data))
        response = post_response.json()
        if post_response.status_code != 200:
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": response['message']}))

        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "message": "Created schema"}))
    except Exception as e:
        logger.info(e)
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": 'Create schema failed'}))
|
||
|
||
@server.PromptServer.instance.routes.delete("/schemas")
async def delete_schemas(request):
    """Delete the named schemas via the middleware API (master only)."""
    if not is_master_process:
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": False, "message": "only master can delete schemas"}))

    try:
        json_data = await request.json()
        schema_name_list = json_data.get('schema_name_list')
        if not schema_name_list:
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": f"schema_name_list is required"}))

        delete_response = requests.delete(f"{api_url}/schemas", headers=headers,
                                          data=json.dumps({"schema_name_list": schema_name_list}))

        # The API acknowledges deletion with 204 No Content.
        if delete_response.status_code != 204:
            resp = delete_response.json()
            return web.Response(status=200,
                                content_type='application/json',
                                body=json.dumps({"result": False, "message": resp['message']}))

        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "message": "Schema deleted"}))
    except Exception as e:
        logger.info(e)
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": 'Delete schemas failed'}))
|
||
|
||
@server.PromptServer.instance.routes.put("/schemas")
async def update_schema(request):
    """Update an existing schema via the middleware API (master process only)."""
    if not is_master_process:
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": False, "message": "only master can update schema"}))

    try:
        payload_json = await request.json()

        # Required fields: a missing or falsy value is rejected the same way.
        if not payload_json.get('name'):
            raise ValueError("name is required")

        if not payload_json.get('payload'):
            raise ValueError("payload is required")

        # Optional workflow association; empty string when absent or falsy.
        workflow_name = payload_json.get('workflow') or ''

        schema_name = payload_json['name']
        schema_payload = payload_json['payload']

        request_body = {
            "workflow": workflow_name,
            "payload": schema_payload,
        }
        api_resp = requests.put(f"{api_url}/schemas/{schema_name}", headers=headers,
                                data=json.dumps(request_body))
        if api_resp.status_code != 200:
            failure = api_resp.json()
            return web.Response(status=200, content_type='application/json',
                                body=json.dumps({"result": False, "message": failure['message']}))

        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": True, "message": "Updated schema"}))
    except Exception as e:
        logger.info(e)
        return web.Response(status=500, content_type='application/json',
                            body=json.dumps({"result": False, "message": f'Update schema failed: {e}'}))
|
||
|
||
def check_workflow_exists(name: str):
    """Return True when the middleware API reports that workflow *name* exists."""
    status = requests.get(f"{api_url}/workflows/{name}", headers=headers).status_code
    return status == 200
|
||
|
||
|
||
def restart_docker_commands():
    """Kill the container's python3 processes after a 5-second grace period.

    Refuses to run while a workflow release/restore holds the action lock.
    Intended to be used as a thread target (see restart_response).
    """
    if is_action_lock():
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(
                                {"result": False, "message": "action is not allowed during release workflow"}))

    # time.sleep avoids forking a child process just to wait (was: subprocess sleep 5)
    time.sleep(5)
    subprocess.run(["pkill", "-f", "python3"])
|
||
|
||
|
||
def restart_comfy_commands():
    """Re-exec the current ComfyUI process in place (action-lock aware)."""
    if is_action_lock():
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(
                                {"result": False, "message": "action is not allowed during release workflow"}))
    try:
        # close the custom log wrapper before exec replaces the process image
        sys.stdout.close_log()
    except Exception as close_err:
        logger.info(f"error restart {close_err}")
    return os.execv(sys.executable, [sys.executable] + sys.argv)
|
||
|
||
|
||
def restart_response():
    """Kick off a delayed restart in the background and reply immediately."""
    threading.Thread(target=restart_docker_commands).start()
    return web.Response(status=200, content_type='application/json',
                        body=json.dumps({"result": True, "message": "comfy will be restart in 5 seconds"}))
|
||
|
||
def kill_after_seconds():
    """Kill all python3 processes after a 2-second delay."""
    # time.sleep avoids forking a child process just to wait (was: subprocess sleep 2)
    time.sleep(2)
    subprocess.run(["pkill", "-f", "python3"])
|
||
|
||
def restore_workflow():
    """Wipe the 'default' workflow directory, then restart the process.

    Holds the action lock while the directory is removed so concurrent
    release/restore requests are rejected, then kills python3 (the process
    supervisor is expected to restart ComfyUI).
    """
    action_lock("restore")
    # time.sleep avoids forking a child process just to wait (was: subprocess sleep 2)
    time.sleep(2)
    os.system("rm -rf /container/workflows/default")
    action_unlock()
    subprocess.run(["pkill", "-f", "python3"])
|
||
|
||
@server.PromptServer.instance.routes.post("/restore")
async def release_rebuild_workflow(request):
    """HTTP entry point that restores the 'default' workflow in the background.

    Guard clauses: rejected while an action lock is held, on non-default
    workflows, and on non-master processes.
    """
    if is_action_lock():
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps(
                                {"result": False, "message": "action is not allowed during workflow release/restore"}))

    if os.getenv('WORKFLOW_NAME') != 'default':
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": False, "message": "only default workflow can be restored"}))

    if not is_master_process:
        return web.Response(status=200, content_type='application/json',
                            body=json.dumps({"result": False, "message": "only master can restore comfy"}))

    worker = threading.Thread(target=restore_workflow)
    worker.start()
    return web.Response(status=200, content_type='application/json',
                        body=json.dumps({"result": True, "message": "Comfy will be start restore in 2 seconds, "
                                                                    "it's may take a few minutes"}))
|
||
|
||
# SageMaker-hosted mode: per-invocation state, AWS clients, and DynamoDB
# tables used by the execute/sync endpoints defined below.
if is_on_sagemaker:

    # NOTE(review): `global` at module scope is a no-op; these declarations only
    # document which names the handlers below mutate via their own `global`.
    global need_sync
    global prompt_id
    global executing
    # single-flight flag: True while execute_proxy is running a prompt
    executing = False

    # set by prepare_comfy_env when a sync record requests a process reboot
    global reboot
    reboot = False

    # timestamps driving the idle-GC in execute_proxy's finally block
    global last_call_time
    last_call_time = None
    global gc_triggered
    gc_triggered = False

    REGION = os.environ.get('AWS_REGION')
    BUCKET = os.environ.get('S3_BUCKET_NAME')
    QUEUE_URL = os.environ.get('COMFY_QUEUE_URL')

    # stable per-instance id: reuse ENDPOINT_INSTANCE_ID when provided, else random
    GEN_INSTANCE_ID = os.environ.get('ENDPOINT_INSTANCE_ID') if 'ENDPOINT_INSTANCE_ID' in os.environ and os.environ.get(
        'ENDPOINT_INSTANCE_ID') else str(uuid.uuid4())
    ENDPOINT_NAME = os.environ.get('ENDPOINT_NAME')
    ENDPOINT_ID = os.environ.get('ENDPOINT_ID')

    INSTANCE_MONITOR_TABLE_NAME = os.environ.get('COMFY_INSTANCE_MONITOR_TABLE')
    SYNC_TABLE_NAME = os.environ.get('COMFY_SYNC_TABLE')

    dynamodb = boto3.resource('dynamodb', region_name=REGION)
    sync_table = dynamodb.Table(SYNC_TABLE_NAME)
    instance_monitor_table = dynamodb.Table(INSTANCE_MONITOR_TABLE_NAME)

    # re-bind the module logger; LOG_LEVEL env var overrides the INFO default
    logger = logging.getLogger(__name__)
    logger.setLevel(os.environ.get('LOG_LEVEL') or logging.INFO)

    ROOT_PATH = '/home/ubuntu/ComfyUI'
    sqs_client = boto3.client('sqs', region_name=REGION)

    # idle window (seconds) before the finally-block GC in execute_proxy fires
    GC_WAIT_TIME = 1800
|
||
|
||
|
||
def print_env():
    """Log every environment variable (name and value) at INFO level."""
    for env_name, env_value in os.environ.items():
        logger.info(f"{env_name}: {env_value}")
|
||
|
||
|
||
@dataclass
class ComfyResponse:
    """Container for an inference HTTP response.

    NOTE(review): appears unused by the handlers in this section, which build
    aiohttp responses directly via ok()/error() — confirm before removing.
    """
    # HTTP-style status code
    statusCode: int
    # human-readable status message
    message: str
    # optional JSON-serializable payload
    body: Optional[dict]
|
||
|
||
|
||
def ok(body: dict):
    """Serialize *body* as JSON into a 200 application/json response."""
    payload = json.dumps(body)
    return web.Response(status=200, content_type='application/json', body=payload)
|
||
|
||
|
||
def error(body: dict):
    """Serialize an error *body* as JSON.

    Deliberately answers 200 rather than 500 so the caller always receives a
    parseable JSON body instead of a transport-level failure.
    """
    payload = json.dumps(body)
    return web.Response(status=200, content_type='application/json', body=payload)
|
||
|
||
|
||
def sen_sqs_msg(message_body, prompt_id_key):
    """Publish *message_body* to the SQS queue, grouped by prompt id.

    Returns the SQS MessageId of the sent message.
    """
    send_result = sqs_client.send_message(
        QueueUrl=QUEUE_URL,
        MessageBody=json.dumps(message_body),
        MessageGroupId=prompt_id_key,
    )
    return send_result['MessageId']
|
||
|
||
|
||
def sen_finish_sqs_msg(prompt_id_key):
    """Send the terminal 'finish' event for a prompt when syncing is enabled.

    No-op unless need_sync is set and SQS (queue url + region) is configured.
    """
    global need_sync
    if not (need_sync and QUEUE_URL and REGION):
        return
    finish_event = {
        'prompt_id': prompt_id_key,
        'event': 'finish',
        'data': {"node": None, "prompt_id": prompt_id_key},
        'sid': None,
    }
    sent_id = sen_sqs_msg(finish_event, prompt_id_key)
    logger.info(f"finish message sent {sent_id}")
|
||
|
||
|
||
async def prepare_comfy_env(sync_item: dict):
    """Synchronize this instance's environment from a DDB sync record.

    *sync_item*'s 'prepare_type' selects what to pull from S3:
    'default'/'models'/'inputs' mirror weights and inputs, 'nodes' pulls and
    unpacks custom-node tarballs, 'custom' copies an arbitrary S3 path, and
    'other' runs a whitelisted shell snippet or an 'export KEY VALUE'.

    Side effects: sets NEED_REBOOT / LAST_SYNC_REQUEST_* in os.environ and
    the module-level `reboot` flag. Returns True when every requested sync
    step succeeded, False otherwise.
    """
    try:
        request_id = sync_item['request_id']
        logger.info(f"prepare_environment start sync_item:{sync_item}")
        prepare_type = sync_item['prepare_type']
        rlt = True
        # model weights for 'default' or 'models' requests
        if prepare_type in ['default', 'models']:
            sync_models_rlt = sync_s3_files_or_folders_to_local(f'{request_id}/models/*', f'{ROOT_PATH}/models',
                                                                False)
            if not sync_models_rlt:
                rlt = False
        # input assets for 'default' or 'inputs' requests
        if prepare_type in ['default', 'inputs']:
            sync_inputs_rlt = sync_s3_files_or_folders_to_local(f'{request_id}/input/*', f'{ROOT_PATH}/input',
                                                                False)
            if not sync_inputs_rlt:
                rlt = False
        # custom nodes arrive as tarballs and must be unpacked (need_un_tar=True)
        if prepare_type in ['nodes']:
            sync_nodes_rlt = sync_s3_files_or_folders_to_local(f'{request_id}/custom_nodes/*',
                                                               f'{ROOT_PATH}/custom_nodes', True)
            if not sync_nodes_rlt:
                rlt = False
        if prepare_type == 'custom':
            sync_source_path = sync_item['s3_source_path']
            local_target_path = sync_item['local_target_path']
            if not sync_source_path or not local_target_path:
                logger.info("s3_source_path and local_target_path should not be empty")
            else:
                sync_rlt = sync_s3_files_or_folders_to_local(sync_source_path,
                                                             f'{ROOT_PATH}/{local_target_path}', False)
                if not sync_rlt:
                    rlt = False
        elif prepare_type == 'other':
            sync_script = sync_item['sync_script']
            logger.info(f"sync_script {sync_script}")
            # scripts such as 's5cmd ...' are deliberately not allowed here;
            # only a small read-only command whitelist plus 'export KEY VALUE'
            try:
                if sync_script and (
                        sync_script.startswith("cat") or sync_script.startswith("os.environ")
                        or sync_script.startswith("print") or sync_script.startswith("ls ")
                        or sync_script.startswith("du ")
                ):
                    os.system(sync_script)
                elif sync_script and (sync_script.startswith("export ") and len(sync_script.split(" ")) > 2):
                    sync_script_key = sync_script.split(" ")[1]
                    sync_script_value = sync_script.split(" ")[2]
                    os.environ[sync_script_key] = sync_script_value
                    logger.info(os.environ.get(sync_script_key))
            except Exception as e:
                # include the exception so the failure cause is not lost
                logger.error(f"Exception while execute sync_scripts : {sync_script} {e}")
                rlt = False
        need_reboot = True if ('need_reboot' in sync_item and sync_item['need_reboot']
                               and str(sync_item['need_reboot']).lower() == 'true') else False
        global reboot
        reboot = need_reboot
        if need_reboot:
            os.environ['NEED_REBOOT'] = 'true'
        else:
            os.environ['NEED_REBOOT'] = 'false'
        logger.info("prepare_environment end")
        # record which sync request this instance last applied (see sync_instance)
        os.environ['LAST_SYNC_REQUEST_ID'] = sync_item['request_id']
        os.environ['LAST_SYNC_REQUEST_TIME'] = str(sync_item['request_time'])
        return rlt
    except Exception as e:
        # fix: the original swallowed the exception silently; log it so a
        # failed sync is diagnosable from the instance logs
        logger.error(f"prepare_comfy_env failed: {e}")
        return False
|
||
|
||
|
||
def sync_s3_files_or_folders_to_local(s3_path, local_path, need_un_tar):
    """Mirror an S3 prefix into *local_path* with s5cmd; optionally unpack tarballs.

    When *need_un_tar* is True the copy is additive and every ``*.tar.gz``
    found directly in *local_path* is extracted in place and then deleted.
    Otherwise s5cmd runs with ``--delete=true`` so local files absent from S3
    are pruned. Returns True on success, False when the copy/extract raised.
    """
    logger.info("sync_s3_models_or_inputs_to_local start")
    if need_un_tar:
        s5cmd_command = f's5cmd sync "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" "{local_path}/"'
    else:
        s5cmd_command = f's5cmd sync --delete=true "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" "{local_path}/"'
    try:
        logger.info(s5cmd_command)
        os.system(s5cmd_command)
        logger.info(f'Files copied from "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" to "{local_path}/"')
        if need_un_tar:
            for filename in os.listdir(local_path):
                if filename.endswith(".tar.gz"):
                    tar_filepath = os.path.join(local_path, filename)
                    # NOTE(review): members are extracted without path
                    # sanitization — a crafted archive could escape
                    # local_path (tarfile path-traversal); consider the
                    # extraction filter added in newer Python.
                    with tarfile.open(tar_filepath, "r:gz") as tar:
                        for member in tar.getmembers():
                            tar.extract(member, path=local_path)
                    os.remove(tar_filepath)
                    logger.info(f'File {tar_filepath} extracted and removed')
        return True
    except Exception as e:
        # fix: log failures at ERROR (was INFO) so they stand out
        logger.error(f"Error executing s5cmd command: {e}")
        return False
|
||
|
||
|
||
def sync_local_outputs_to_s3(s3_path, local_path):
    """Upload everything under *local_path* to S3, then delete the local copy.

    Best-effort: failures are logged and swallowed.
    """
    logger.info("sync_local_outputs_to_s3 start")
    s5cmd_command = f's5cmd sync "{local_path}/*" "s3://{BUCKET}/comfy/{s3_path}/" '
    try:
        logger.info(s5cmd_command)
        os.system(s5cmd_command)
        logger.info(f'Files copied local to "s3://{BUCKET}/comfy/{s3_path}/" to "{local_path}/"')
        # clear the local folder once the upload command has run
        os.system(f'rm -rf {local_path}')
        logger.info(f'Files removed from local {local_path}')
    except Exception as e:
        logger.info(f"Error executing s5cmd command: {e}")
|
||
|
||
|
||
def sync_local_outputs_to_base64(local_path):
    """Collect every file under *local_path* as base64, then delete the directory.

    Returns a dict mapping bare file name to base64-encoded content; files with
    the same name in different subdirectories overwrite each other. Returns an
    empty dict on error.
    """
    logger.info("sync_local_outputs_to_base64 start")
    try:
        result = {}
        for root, dirs, files in os.walk(local_path):
            for file in files:
                file_path = os.path.join(root, file)
                with open(file_path, "rb") as f:
                    file_content = f.read()
                    base64_content = base64.b64encode(file_content).decode('utf-8')
                    result[file] = base64_content
        clean_cmd = f'rm -rf {local_path}'
        os.system(clean_cmd)
        logger.info(f'Files removed from local {local_path}')
        return result
    except Exception as e:
        # fix: the old message blamed s5cmd, which this function never runs
        logger.error(f"Error encoding local outputs to base64: {e}")
        return {}
|
||
|
||
|
||
@server.PromptServer.instance.routes.post("/execute_proxy")
async def execute_proxy(request):
    """Run a prompt synchronously on this instance and mirror outputs to S3.

    Body fields read here: prompt_id, prompt, need_sync, plus optional
    out_path, need_prepare/prepare_props, number, front, extra_data,
    client_id. Always answers HTTP 200; the JSON 'status' field carries
    success/fail (see ok()/error()). Mutates module globals need_sync,
    prompt_id, executing, last_call_time, gc_triggered.
    """
    logger.info("start to execute_proxy inside")
    json_data = await request.json()
    # optional sub-folder under output/ and temp/ for this prompt's artifacts
    if 'out_path' in json_data and json_data['out_path'] is not None:
        out_path = json_data['out_path']
    else:
        out_path = None
    logger.info(f"invocations start json_data:{json_data}")
    global need_sync
    need_sync = json_data["need_sync"]
    global prompt_id
    prompt_id = json_data["prompt_id"]
    try:
        # single-flight guard: reject overlapping executions on this instance
        global executing
        if executing is True:
            # NOTE(review): this message looks copy-pasted from the validation
            # failure below; it really means "another prompt is running"
            resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
                    "message": "the environment is not ready valid[0] is false, need to resync"}
            sen_finish_sqs_msg(prompt_id)
            return error(resp)
        executing = True
        logger.info(
            f'bucket_name: {BUCKET}, region: {REGION}')
        # optional environment sync (models/inputs/nodes) before executing
        if ('need_prepare' in json_data and json_data['need_prepare']
                and 'prepare_props' in json_data and json_data['prepare_props']):
            sync_already = await prepare_comfy_env(json_data['prepare_props'])
            if not sync_already:
                resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
                        "message": "the environment is not ready with sync"}
                executing = False
                sen_finish_sqs_msg(prompt_id)
                return error(resp)
        server_instance = server.PromptServer.instance
        # queue-priority bookkeeping mirrored from ComfyUI's own /prompt route
        if "number" in json_data:
            number = float(json_data['number'])
            server_instance.number = number
        else:
            number = server_instance.number
        if "front" in json_data:
            if json_data['front']:
                number = -number
        server_instance.number += 1
        valid = execution.validate_prompt(json_data['prompt'])
        logger.info(f"Validating prompt result is {valid}")
        if not valid[0]:
            resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
                    "message": "the environment is not ready valid[0] is false, need to resync"}
            executing = False
            # NOTE(review): `response` is built but never returned or used
            response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
            sen_finish_sqs_msg(prompt_id)
            return error(resp)
        # resolve the websocket client id from extra_data or the top level
        extra_data = {}
        client_id = ''
        if "extra_data" in json_data:
            extra_data = json_data["extra_data"]
            if 'client_id' in extra_data and extra_data['client_id']:
                client_id = extra_data['client_id']
        if "client_id" in json_data and json_data["client_id"]:
            extra_data["client_id"] = json_data["client_id"]
            client_id = json_data["client_id"]

        server_instance.client_id = client_id

        prompt_id = json_data['prompt_id']
        server_instance.last_prompt_id = prompt_id
        # run the validated prompt in-process; valid[2] lists output nodes
        e = execution.PromptExecutor(server_instance)
        outputs_to_execute = valid[2]
        e.execute(json_data['prompt'], prompt_id, extra_data, outputs_to_execute)

        # upload generated artifacts to S3 and clear the local folders
        s3_out_path = f'output/{prompt_id}/{out_path}' if out_path is not None else f'output/{prompt_id}'
        s3_temp_path = f'temp/{prompt_id}/{out_path}' if out_path is not None else f'temp/{prompt_id}'
        local_out_path = f'{ROOT_PATH}/output/{out_path}' if out_path is not None else f'{ROOT_PATH}/output'
        local_temp_path = f'{ROOT_PATH}/temp/{out_path}' if out_path is not None else f'{ROOT_PATH}/temp'

        logger.info(
            f"s3_out_path is {s3_out_path} and s3_temp_path is {s3_temp_path} and local_out_path is {local_out_path} and local_temp_path is {local_temp_path}")

        sync_local_outputs_to_s3(s3_out_path, local_out_path)
        sync_local_outputs_to_s3(s3_temp_path, local_temp_path)

        response_body = {
            "prompt_id": prompt_id,
            "instance_id": GEN_INSTANCE_ID,
            "status": "success",
            "output_path": f's3://{BUCKET}/comfy/{s3_out_path}',
            "temp_path": f's3://{BUCKET}/comfy/{s3_temp_path}',
        }
        sen_finish_sqs_msg(prompt_id)
        logger.info(f"execute inference response is {response_body}")
        executing = False
        return ok(response_body)
    except Exception as ecp:
        logger.info(f"exception occurred {ecp}")
        resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
                "message": f"exception occurred {ecp}"}
        executing = False
        return error(resp)
    finally:
        # idle-GC: once GC_WAIT_TIME seconds pass without a new call, free
        # models and VRAM on the next invocation's way out
        logger.info(f"gc check: {time.time()}")
        try:
            global last_call_time, gc_triggered
            gc_triggered = False
            if last_call_time is None:
                logger.info(f"gc check last time is NONE")
                last_call_time = time.time()
            else:
                if time.time() - last_call_time > GC_WAIT_TIME:
                    if not gc_triggered:
                        logger.info(f"gc start: {time.time()} - {last_call_time}")
                        # NOTE(review): `e` is the PromptExecutor bound in the
                        # try body; on early-return/exception paths it is
                        # unbound and this raises NameError, which is caught
                        # and logged as "gc error" below
                        e.reset()
                        comfy.model_management.cleanup_models()
                        gc.collect()
                        comfy.model_management.soft_empty_cache()
                        gc_triggered = True
                        logger.info(f"gc end: {time.time()} - {last_call_time}")
                    last_call_time = time.time()
                else:
                    last_call_time = time.time()
            logger.info(f"gc check end: {time.time()}")
        except Exception as e:
            logger.info(f"gc error: {e}")
|
||
|
||
|
||
def get_last_ddb_sync_record():
    """Return the newest sync record for this endpoint, or None when absent."""
    query_result = sync_table.query(
        KeyConditionExpression=Key('endpoint_name').eq(ENDPOINT_NAME),
        Limit=1,
        ScanIndexForward=False,
    )
    items = query_result.get('Items') or []
    if items:
        logger.info(f"latest_sync_record is:{items[0]}")
        return items[0]

    logger.info("no latest_sync_record found")
    return None
|
||
|
||
|
||
def get_latest_ddb_instance_monitor_record():
    """Fetch the monitor row for this (endpoint, instance) pair, or None."""
    query_result = instance_monitor_table.query(
        KeyConditionExpression=('endpoint_name = :endpoint_name_val '
                                'AND gen_instance_id = :gen_instance_id_val'),
        ExpressionAttributeValues={
            ':endpoint_name_val': ENDPOINT_NAME,
            ':gen_instance_id_val': GEN_INSTANCE_ID,
        },
    )
    items = query_result.get('Items') or []
    if items:
        logger.info(f"instance_monitor_record is {items[0]}")
        return items[0]

    logger.info("no instance_monitor_record found")
    return None
|
||
|
||
|
||
def save_sync_instance_monitor(last_sync_request_id: str, sync_status: str):
    """Create the monitor row for this instance and return the put_item response.

    A freshly created row uses a single timestamp for create/sync/heartbeat
    times so they are self-consistent (the original called datetime.now()
    three times, yielding microsecond-skewed values).
    """
    now_iso = datetime.datetime.now().isoformat()
    item = {
        'endpoint_id': ENDPOINT_ID,
        'endpoint_name': ENDPOINT_NAME,
        'gen_instance_id': GEN_INSTANCE_ID,
        'sync_status': sync_status,
        'last_sync_request_id': last_sync_request_id,
        'last_sync_time': now_iso,
        'sync_list': [],
        'create_time': now_iso,
        'last_heartbeat_time': now_iso
    }
    save_resp = instance_monitor_table.put_item(Item=item)
    logger.info(f"save instance item {save_resp}")
    return save_resp
|
||
|
||
|
||
def update_sync_instance_monitor(instance_monitor_record):
    """Persist sync status/request-id/list plus fresh timestamps for this instance."""
    response = instance_monitor_table.update_item(
        Key={
            'endpoint_name': ENDPOINT_NAME,
            'gen_instance_id': GEN_INSTANCE_ID,
        },
        UpdateExpression=("SET sync_status = :new_sync_status, last_sync_request_id = :sync_request_id, "
                          "sync_list = :sync_list, last_sync_time = :sync_time, last_heartbeat_time = :heartbeat_time"),
        ExpressionAttributeValues={
            ":new_sync_status": instance_monitor_record['sync_status'],
            ":sync_request_id": instance_monitor_record['last_sync_request_id'],
            ":sync_list": instance_monitor_record['sync_list'],
            ":sync_time": datetime.datetime.now().isoformat(),
            ":heartbeat_time": datetime.datetime.now().isoformat(),
        },
    )
    logger.info(f"update_sync_instance_monitor :{response}")
    return response
|
||
|
||
|
||
def sync_instance_monitor_status(need_save: bool):
    """Create the monitor row (*need_save*) or just refresh its heartbeat.

    Best-effort: DynamoDB errors are logged and swallowed.
    """
    try:
        logger.info(f"sync_instance_monitor_status {datetime.datetime.now()}")
        if need_save:
            save_sync_instance_monitor('', 'init')
            return
        instance_monitor_table.update_item(
            Key={'endpoint_name': ENDPOINT_NAME,
                 'gen_instance_id': GEN_INSTANCE_ID},
            UpdateExpression="SET last_heartbeat_time = :heartbeat_time",
            ExpressionAttributeValues={
                ":heartbeat_time": datetime.datetime.now().isoformat(),
            },
        )
    except Exception as e:
        logger.info(f"sync_instance_monitor_status error :{e}")
|
||
|
||
|
||
@server.PromptServer.instance.routes.post("/reboot")
async def restart(self):
    """Re-exec the process, but only when a reboot was actually requested.

    Vetoed when an inference is in flight, when NEED_REBOOT explicitly says
    otherwise, or when the module-level `reboot` flag is unset.
    """
    logger.debug(f"start to reboot!!!!!!!! {self}")
    global executing
    if executing is True:
        logger.info(f"other inference doing cannot reboot!!!!!!!!")
        return ok({"message": "other inference doing cannot reboot"})

    # an explicit non-'true' NEED_REBOOT value vetoes; unset falls through
    need_reboot = os.environ.get('NEED_REBOOT')
    if need_reboot and need_reboot.lower() != 'true':
        logger.info("no need to reboot by os")
        return ok({"message": "no need to reboot by os"})

    global reboot
    if reboot is False:
        logger.info("no need to reboot by global constant")
        return ok({"message": "no need to reboot by constant"})

    logger.debug("rebooting !!!!!!!!")
    try:
        # close the custom log wrapper before exec replaces the process image
        sys.stdout.close_log()
    except Exception as close_err:
        logger.info(f"error reboot!!!!!!!! {close_err}")
    return os.execv(sys.executable, [sys.executable] + sys.argv)
|
||
|
||
|
||
# must be sync invoke and use the env to check
|
||
@server.PromptServer.instance.routes.post("/sync_instance")
|
||
async def sync_instance(request):
|
||
if not BUCKET:
|
||
logger.error("No bucket provided ,wait and try again")
|
||
resp = {"status": "success", "message": "syncing"}
|
||
return ok(resp)
|
||
|
||
if 'ALREADY_SYNC' in os.environ and os.environ.get('ALREADY_SYNC').lower() == 'false':
|
||
resp = {"status": "success", "message": "syncing"}
|
||
logger.error("other process doing ,wait and try again")
|
||
return ok(resp)
|
||
|
||
os.environ['ALREADY_SYNC'] = 'false'
|
||
logger.info(f"sync_instance start !! {datetime.datetime.now().isoformat()} {request}")
|
||
try:
|
||
last_sync_record = get_last_ddb_sync_record()
|
||
if not last_sync_record:
|
||
logger.info("no last sync record found do not need sync")
|
||
sync_instance_monitor_status(True)
|
||
resp = {"status": "success", "message": "no sync"}
|
||
os.environ['ALREADY_SYNC'] = 'true'
|
||
return ok(resp)
|
||
|
||
if ('request_id' in last_sync_record and last_sync_record['request_id']
|
||
and os.environ.get('LAST_SYNC_REQUEST_ID')
|
||
and os.environ.get('LAST_SYNC_REQUEST_ID') == last_sync_record['request_id']
|
||
and os.environ.get('LAST_SYNC_REQUEST_TIME')
|
||
and os.environ.get('LAST_SYNC_REQUEST_TIME') == str(last_sync_record['request_time'])):
|
||
logger.info("last sync record already sync by os check")
|
||
sync_instance_monitor_status(False)
|
||
resp = {"status": "success", "message": "no sync env"}
|
||
os.environ['ALREADY_SYNC'] = 'true'
|
||
return ok(resp)
|
||
|
||
instance_monitor_record = get_latest_ddb_instance_monitor_record()
|
||
if not instance_monitor_record:
|
||
sync_already = await prepare_comfy_env(last_sync_record)
|
||
if sync_already:
|
||
logger.info("should init prepare instance_monitor_record")
|
||
sync_status = 'success' if sync_already else 'failed'
|
||
save_sync_instance_monitor(last_sync_record['request_id'], sync_status)
|
||
else:
|
||
sync_instance_monitor_status(False)
|
||
else:
|
||
if ('last_sync_request_id' in instance_monitor_record
|
||
and instance_monitor_record['last_sync_request_id']
|
||
and instance_monitor_record['last_sync_request_id'] == last_sync_record['request_id']
|
||
and instance_monitor_record['sync_status']
|
||
and instance_monitor_record['sync_status'] == 'success'
|
||
and os.environ.get('LAST_SYNC_REQUEST_TIME')
|
||
and os.environ.get('LAST_SYNC_REQUEST_TIME') == str(last_sync_record['request_time'])):
|
||
logger.info("last sync record already sync")
|
||
sync_instance_monitor_status(False)
|
||
resp = {"status": "success", "message": "no sync ddb"}
|
||
os.environ['ALREADY_SYNC'] = 'true'
|
||
return ok(resp)
|
||
|
||
sync_already = await prepare_comfy_env(last_sync_record)
|
||
instance_monitor_record['sync_status'] = 'success' if sync_already else 'failed'
|
||
instance_monitor_record['last_sync_request_id'] = last_sync_record['request_id']
|
||
sync_list = instance_monitor_record['sync_list'] if ('sync_list' in instance_monitor_record
|
||
and instance_monitor_record['sync_list']) else []
|
||
sync_list.append(last_sync_record['request_id'])
|
||
|
||
instance_monitor_record['sync_list'] = sync_list
|
||
logger.info("should update prepare instance_monitor_record")
|
||
update_sync_instance_monitor(instance_monitor_record)
|
||
os.environ['ALREADY_SYNC'] = 'true'
|
||
resp = {"status": "success", "message": "sync"}
|
||
return ok(resp)
|
||
except Exception as e:
|
||
logger.info("exception occurred", e)
|
||
os.environ['ALREADY_SYNC'] = 'true'
|
||
resp = {"status": "fail", "message": "sync"}
|
||
return error(resp)
|
||
|
||
|
||
def validate_prompt_proxy(func):
    """Decorator that logs entry/exit around execution.validate_prompt."""
    from functools import wraps

    @wraps(func)  # fix: preserve the wrapped function's name/docstring for introspection
    def wrapper(*args, **kwargs):
        logger.info("validate_prompt_proxy start...")
        result = func(*args, **kwargs)
        logger.info("validate_prompt_proxy end...")
        return result

    return wrapper
|
||
|
||
|
||
# Monkey-patch ComfyUI's prompt validation so every call is logged by the proxy.
execution.validate_prompt = validate_prompt_proxy(execution.validate_prompt)
|
||
|
||
|
||
def send_sync_proxy(func):
    """Decorator for PromptServer.send_sync that mirrors events to SQS.

    After the original send_sync runs, the (event, data, sid) triple is
    forwarded to the queue when need_sync is set and SQS is configured.
    """
    from functools import wraps

    @wraps(func)  # fix: preserve the wrapped method's metadata
    def wrapper(*args, **kwargs):
        logger.debug(f"Sending sync request!!!!!!! {args}")
        global need_sync
        global prompt_id
        logger.info(f"send_sync_proxy start... {need_sync},{prompt_id} {args}")
        # fix: propagate the wrapped call's return value (original dropped it)
        result = func(*args, **kwargs)
        if need_sync and QUEUE_URL and REGION:
            logger.debug(f"send_sync_proxy params... {QUEUE_URL},{REGION},{need_sync},{prompt_id}")
            event = args[1]
            data = args[2]
            sid = args[3] if len(args) == 4 else None
            message_body = {'prompt_id': prompt_id, 'event': event, 'data': data, 'sid': sid}
            message_id = sen_sqs_msg(message_body, prompt_id)
            logger.info(f'send_sync_proxy message_id :{message_id} message_body: {message_body}')
        logger.debug(f"send_sync_proxy end...")
        return result

    return wrapper
|
||
|
||
|
||
# Monkey-patch so every websocket event ComfyUI emits is also mirrored to SQS.
server.PromptServer.send_sync = send_sync_proxy(server.PromptServer.send_sync)
|
||
|
||
|
||
def get_save_imge_path_proxy(func):
    """Decorator that suffixes the save-image filename prefix with the prompt id.

    Keeps output file names unique per prompt so successive invocations don't
    collide in the shared output folder.
    """
    from functools import wraps

    @wraps(func)  # fix: preserve folder_paths.get_save_image_path metadata
    def wrapper(*args, **kwargs):
        logger.info(f"get_save_imge_path_proxy args : {args} kwargs : {kwargs}")
        full_output_folder, filename, counter, subfolder, filename_prefix = func(*args, **kwargs)
        global prompt_id
        filename_prefix_new = filename_prefix + "_" + str(prompt_id)
        logger.info(f"get_save_imge_path_proxy filename_prefix new : {filename_prefix_new}")
        return full_output_folder, filename, counter, subfolder, filename_prefix_new

    return wrapper
|
||
|
||
|
||
# Monkey-patch: install the prompt-id filename suffixing on ComfyUI's save-path helper.
folder_paths.get_save_image_path = get_save_imge_path_proxy(folder_paths.get_save_image_path)
|