609 lines
26 KiB
Python
609 lines
26 KiB
Python
import base64
|
||
import datetime
|
||
import json
|
||
import logging
|
||
import os
|
||
import sys
|
||
import tarfile
|
||
import time
|
||
import uuid
|
||
import gc
|
||
from dataclasses import dataclass
|
||
from typing import Optional
|
||
|
||
import boto3
|
||
import execution
|
||
import server
|
||
import folder_paths
|
||
from aiohttp import web
|
||
from boto3.dynamodb.conditions import Key
|
||
import comfy
|
||
|
||
global need_sync
|
||
global prompt_id
|
||
global executing
|
||
executing = False
|
||
|
||
global reboot
|
||
reboot = False
|
||
|
||
global last_call_time
|
||
last_call_time = None
|
||
global gc_triggered
|
||
gc_triggered = False
|
||
|
||
REGION = os.environ.get('AWS_REGION')
|
||
BUCKET = os.environ.get('S3_BUCKET_NAME')
|
||
QUEUE_URL = os.environ.get('COMFY_QUEUE_URL')
|
||
|
||
GEN_INSTANCE_ID = os.environ.get('ENDPOINT_INSTANCE_ID') if 'ENDPOINT_INSTANCE_ID' in os.environ and os.environ.get('ENDPOINT_INSTANCE_ID') else str(uuid.uuid4())
|
||
ENDPOINT_NAME = os.environ.get('ENDPOINT_NAME')
|
||
ENDPOINT_ID = os.environ.get('ENDPOINT_ID')
|
||
|
||
INSTANCE_MONITOR_TABLE_NAME = os.environ.get('COMFY_INSTANCE_MONITOR_TABLE')
|
||
SYNC_TABLE_NAME = os.environ.get('COMFY_SYNC_TABLE')
|
||
|
||
dynamodb = boto3.resource('dynamodb', region_name=REGION)
|
||
sync_table = dynamodb.Table(SYNC_TABLE_NAME)
|
||
instance_monitor_table = dynamodb.Table(INSTANCE_MONITOR_TABLE_NAME)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.INFO)
|
||
|
||
ROOT_PATH = '/home/ubuntu/ComfyUI'
|
||
sqs_client = boto3.client('sqs', region_name=REGION)
|
||
|
||
GC_WAIT_TIME = 1800
|
||
|
||
|
||
def print_env():
|
||
for key, value in os.environ.items():
|
||
logger.info(f"{key}: {value}")
|
||
|
||
|
||
@dataclass
|
||
class ComfyResponse:
|
||
statusCode: int
|
||
message: str
|
||
body: Optional[dict]
|
||
|
||
|
||
def ok(body: dict):
|
||
return web.Response(status=200, content_type='application/json', body=json.dumps(body))
|
||
|
||
|
||
def error(body: dict):
|
||
# TODO 500 -》200 because of need resp anyway not exception
|
||
return web.Response(status=200, content_type='application/json', body=json.dumps(body))
|
||
|
||
|
||
def sen_sqs_msg(message_body, prompt_id_key):
|
||
response = sqs_client.send_message(
|
||
QueueUrl=QUEUE_URL,
|
||
MessageBody=json.dumps(message_body),
|
||
MessageGroupId=prompt_id_key
|
||
)
|
||
message_id = response['MessageId']
|
||
return message_id
|
||
|
||
|
||
def sen_finish_sqs_msg(prompt_id_key):
|
||
global need_sync
|
||
# logger.info(f"sen_finish_sqs_msg start... {need_sync},{prompt_id_key}")
|
||
if need_sync and QUEUE_URL and REGION:
|
||
message_body = {'prompt_id': prompt_id_key, 'event': 'finish', 'data': {"node": None, "prompt_id": prompt_id_key},
|
||
'sid': None}
|
||
message_id = sen_sqs_msg(message_body, prompt_id_key)
|
||
logger.info(f"finish message sent {message_id}")
|
||
|
||
|
||
async def prepare_comfy_env(sync_item: dict):
|
||
try:
|
||
request_id = sync_item['request_id']
|
||
logger.info(f"prepare_environment start sync_item:{sync_item}")
|
||
prepare_type = sync_item['prepare_type']
|
||
rlt = True
|
||
if prepare_type in ['default', 'models']:
|
||
sync_models_rlt = sync_s3_files_or_folders_to_local(f'{request_id}/models/*', f'{ROOT_PATH}/models', False)
|
||
if not sync_models_rlt:
|
||
rlt = False
|
||
if prepare_type in ['default', 'inputs']:
|
||
sync_inputs_rlt = sync_s3_files_or_folders_to_local(f'{request_id}/input/*', f'{ROOT_PATH}/input', False)
|
||
if not sync_inputs_rlt:
|
||
rlt = False
|
||
if prepare_type in ['default', 'nodes']:
|
||
sync_nodes_rlt = sync_s3_files_or_folders_to_local(f'{request_id}/custom_nodes/*',
|
||
f'{ROOT_PATH}/custom_nodes', True)
|
||
if not sync_nodes_rlt:
|
||
rlt = False
|
||
if prepare_type == 'custom':
|
||
sync_source_path = sync_item['s3_source_path']
|
||
local_target_path = sync_item['local_target_path']
|
||
if not sync_source_path or not local_target_path:
|
||
logger.info("s3_source_path and local_target_path should not be empty")
|
||
else:
|
||
sync_rlt = sync_s3_files_or_folders_to_local(sync_source_path,
|
||
f'{ROOT_PATH}/{local_target_path}', False)
|
||
if not sync_rlt:
|
||
rlt = False
|
||
elif prepare_type == 'other':
|
||
sync_script = sync_item['sync_script']
|
||
logger.info("sync_script")
|
||
# sync_script.startswith('s5cmd') 不允许
|
||
try:
|
||
if sync_script and (sync_script.startswith("python3 -m pip") or sync_script.startswith("python -m pip")
|
||
or sync_script.startswith("pip install") or sync_script.startswith("apt")
|
||
or sync_script.startswith("os.environ") or sync_script.startswith("ls")
|
||
or sync_script.startswith("env") or sync_script.startswith("source")
|
||
or sync_script.startswith("curl") or sync_script.startswith("wget")
|
||
or sync_script.startswith("print") or sync_script.startswith("cat")
|
||
or sync_script.startswith("sudo chmod") or sync_script.startswith("chmod")
|
||
or sync_script.startswith("/home/ubuntu/ComfyUI/venv/bin/python")):
|
||
os.system(sync_script)
|
||
elif sync_script and (sync_script.startswith("export ") and len(sync_script.split(" ")) > 2):
|
||
sync_script_key = sync_script.split(" ")[1]
|
||
sync_script_value = sync_script.split(" ")[2]
|
||
os.environ[sync_script_key] = sync_script_value
|
||
logger.info(os.environ.get(sync_script_key))
|
||
except Exception as e:
|
||
logger.error(f"Exception while execute sync_scripts : {sync_script}")
|
||
rlt = False
|
||
need_reboot = True if ('need_reboot' in sync_item and sync_item['need_reboot']
|
||
and str(sync_item['need_reboot']).lower() == 'true')else False
|
||
global reboot
|
||
reboot = need_reboot
|
||
if need_reboot:
|
||
os.environ['NEED_REBOOT'] = 'true'
|
||
else:
|
||
os.environ['NEED_REBOOT'] = 'false'
|
||
logger.info("prepare_environment end")
|
||
os.environ['LAST_SYNC_REQUEST_ID'] = sync_item['request_id']
|
||
os.environ['LAST_SYNC_REQUEST_TIME'] = str(sync_item['request_time'])
|
||
return rlt
|
||
except Exception as e:
|
||
return False
|
||
|
||
|
||
def sync_s3_files_or_folders_to_local(s3_path, local_path, need_un_tar):
|
||
logger.info("sync_s3_models_or_inputs_to_local start")
|
||
# s5cmd_command = f'{ROOT_PATH}/tools/s5cmd sync "s3://{bucket_name}/{s3_path}/*" "{local_path}/"'
|
||
if need_un_tar:
|
||
s5cmd_command = f's5cmd sync "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" "{local_path}/"'
|
||
else:
|
||
s5cmd_command = f's5cmd sync --delete=true "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" "{local_path}/"'
|
||
# s5cmd_command = f's5cmd sync --delete=true "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" "{local_path}/"'
|
||
# s5cmd_command = f's5cmd sync "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" "{local_path}/"'
|
||
try:
|
||
logger.info(s5cmd_command)
|
||
os.system(s5cmd_command)
|
||
logger.info(f'Files copied from "s3://{BUCKET}/comfy/{ENDPOINT_NAME}/{s3_path}" to "{local_path}/"')
|
||
if need_un_tar:
|
||
for filename in os.listdir(local_path):
|
||
if filename.endswith(".tar.gz"):
|
||
tar_filepath = os.path.join(local_path, filename)
|
||
# extract_path = os.path.splitext(os.path.splitext(tar_filepath)[0])[0]
|
||
# os.makedirs(extract_path, exist_ok=True)
|
||
# logger.info(f'Extracting extract_path is {extract_path}')
|
||
|
||
with tarfile.open(tar_filepath, "r:gz") as tar:
|
||
for member in tar.getmembers():
|
||
tar.extract(member, path=local_path)
|
||
os.remove(tar_filepath)
|
||
logger.info(f'File {tar_filepath} extracted and removed')
|
||
return True
|
||
except Exception as e:
|
||
logger.info(f"Error executing s5cmd command: {e}")
|
||
return False
|
||
|
||
|
||
def sync_local_outputs_to_s3(s3_path, local_path):
|
||
logger.info("sync_local_outputs_to_s3 start")
|
||
s5cmd_command = f's5cmd sync "{local_path}/*" "s3://{BUCKET}/comfy/{s3_path}/" '
|
||
try:
|
||
logger.info(s5cmd_command)
|
||
os.system(s5cmd_command)
|
||
logger.info(f'Files copied local to "s3://{BUCKET}/comfy/{s3_path}/" to "{local_path}/"')
|
||
clean_cmd = f'rm -rf {local_path}'
|
||
os.system(clean_cmd)
|
||
logger.info(f'Files removed from local {local_path}')
|
||
except Exception as e:
|
||
logger.info(f"Error executing s5cmd command: {e}")
|
||
|
||
|
||
def sync_local_outputs_to_base64(local_path):
|
||
logger.info("sync_local_outputs_to_base64 start")
|
||
try:
|
||
result = {}
|
||
for root, dirs, files in os.walk(local_path):
|
||
for file in files:
|
||
file_path = os.path.join(root, file)
|
||
with open(file_path, "rb") as f:
|
||
file_content = f.read()
|
||
base64_content = base64.b64encode(file_content).decode('utf-8')
|
||
result[file] = base64_content
|
||
clean_cmd = f'rm -rf {local_path}'
|
||
os.system(clean_cmd)
|
||
logger.info(f'Files removed from local {local_path}')
|
||
return result
|
||
except Exception as e:
|
||
logger.info(f"Error executing s5cmd command: {e}")
|
||
return {}
|
||
|
||
|
||
@server.PromptServer.instance.routes.post("/execute_proxy")
|
||
async def execute_proxy(request):
|
||
json_data = await request.json()
|
||
if 'out_path' in json_data and json_data['out_path'] is not None:
|
||
out_path = json_data['out_path']
|
||
else:
|
||
out_path = None
|
||
logger.info(f"invocations start json_data:{json_data}")
|
||
global need_sync
|
||
need_sync = json_data["need_sync"]
|
||
global prompt_id
|
||
prompt_id = json_data["prompt_id"]
|
||
try:
|
||
global executing
|
||
if executing is True:
|
||
resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
|
||
"message": "the environment is not ready valid[0] is false, need to resync"}
|
||
sen_finish_sqs_msg(prompt_id)
|
||
return error(resp)
|
||
executing = True
|
||
logger.info(
|
||
f'bucket_name: {BUCKET}, region: {REGION}')
|
||
if ('need_prepare' in json_data and json_data['need_prepare']
|
||
and 'prepare_props' in json_data and json_data['prepare_props']):
|
||
sync_already = await prepare_comfy_env(json_data['prepare_props'])
|
||
if not sync_already:
|
||
resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
|
||
"message": "the environment is not ready with sync"}
|
||
executing = False
|
||
sen_finish_sqs_msg(prompt_id)
|
||
return error(resp)
|
||
server_instance = server.PromptServer.instance
|
||
if "number" in json_data:
|
||
number = float(json_data['number'])
|
||
server_instance.number = number
|
||
else:
|
||
number = server_instance.number
|
||
if "front" in json_data:
|
||
if json_data['front']:
|
||
number = -number
|
||
server_instance.number += 1
|
||
valid = execution.validate_prompt(json_data['prompt'])
|
||
logger.info(f"Validating prompt result is {valid}")
|
||
if not valid[0]:
|
||
resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
|
||
"message": "the environment is not ready valid[0] is false, need to resync"}
|
||
executing = False
|
||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||
sen_finish_sqs_msg(prompt_id)
|
||
return error(resp)
|
||
# if len(valid) == 4 and len(valid[3]) > 0:
|
||
# logger.info(f"Validating prompt error there is something error because of :valid: {valid}")
|
||
# resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
|
||
# "message": f"the valid is error, need to resync or check the workflow :{valid}"}
|
||
# executing = False
|
||
# return error(resp)
|
||
extra_data = {}
|
||
client_id = ''
|
||
if "extra_data" in json_data:
|
||
extra_data = json_data["extra_data"]
|
||
if 'client_id' in extra_data and extra_data['client_id']:
|
||
client_id = extra_data['client_id']
|
||
if "client_id" in json_data and json_data["client_id"]:
|
||
extra_data["client_id"] = json_data["client_id"]
|
||
client_id = json_data["client_id"]
|
||
|
||
server_instance.client_id = client_id
|
||
|
||
prompt_id = json_data['prompt_id']
|
||
server_instance.last_prompt_id = prompt_id
|
||
e = execution.PromptExecutor(server_instance)
|
||
outputs_to_execute = valid[2]
|
||
e.execute(json_data['prompt'], prompt_id, extra_data, outputs_to_execute)
|
||
|
||
s3_out_path = f'output/{prompt_id}/{out_path}' if out_path is not None else f'output/{prompt_id}'
|
||
s3_temp_path = f'temp/{prompt_id}/{out_path}' if out_path is not None else f'temp/{prompt_id}'
|
||
local_out_path = f'{ROOT_PATH}/output/{out_path}' if out_path is not None else f'{ROOT_PATH}/output'
|
||
local_temp_path = f'{ROOT_PATH}/temp/{out_path}' if out_path is not None else f'{ROOT_PATH}/temp'
|
||
|
||
logger.info(f"s3_out_path is {s3_out_path} and s3_temp_path is {s3_temp_path} and local_out_path is {local_out_path} and local_temp_path is {local_temp_path}")
|
||
|
||
sync_local_outputs_to_s3(s3_out_path, local_out_path)
|
||
sync_local_outputs_to_s3(s3_temp_path, local_temp_path)
|
||
|
||
response_body = {
|
||
"prompt_id": prompt_id,
|
||
"instance_id": GEN_INSTANCE_ID,
|
||
"status": "success",
|
||
"output_path": f's3://{BUCKET}/comfy/{s3_out_path}',
|
||
"temp_path": f's3://{BUCKET}/comfy/{s3_temp_path}',
|
||
}
|
||
sen_finish_sqs_msg(prompt_id)
|
||
logger.info(f"execute inference response is {response_body}")
|
||
executing = False
|
||
return ok(response_body)
|
||
except Exception as ecp:
|
||
logger.info(f"exception occurred {ecp}")
|
||
resp = {"prompt_id": prompt_id, "instance_id": GEN_INSTANCE_ID, "status": "fail",
|
||
"message": f"exception occurred {ecp}"}
|
||
executing = False
|
||
return error(resp)
|
||
finally:
|
||
logger.info(f"gc check: {time.time()}")
|
||
try:
|
||
global last_call_time, gc_triggered
|
||
gc_triggered = False
|
||
if last_call_time is None:
|
||
logger.info(f"gc check last time is NONE")
|
||
last_call_time = time.time()
|
||
else:
|
||
if time.time() - last_call_time > GC_WAIT_TIME:
|
||
if not gc_triggered:
|
||
logger.info(f"gc start: {time.time()} - {last_call_time}")
|
||
e.reset()
|
||
comfy.model_management.cleanup_models()
|
||
gc.collect()
|
||
comfy.model_management.soft_empty_cache()
|
||
gc_triggered = True
|
||
logger.info(f"gc end: {time.time()} - {last_call_time}")
|
||
last_call_time = time.time()
|
||
else:
|
||
last_call_time = time.time()
|
||
logger.info(f"gc check end: {time.time()}")
|
||
except Exception as e:
|
||
logger.info(f"gc error: {e}")
|
||
|
||
|
||
def get_last_ddb_sync_record():
|
||
sync_response = sync_table.query(
|
||
KeyConditionExpression=Key('endpoint_name').eq(ENDPOINT_NAME),
|
||
Limit=1,
|
||
ScanIndexForward=False
|
||
)
|
||
latest_sync_record = sync_response['Items'][0] if ('Items' in sync_response
|
||
and len(sync_response['Items']) > 0) else None
|
||
if latest_sync_record:
|
||
logger.info(f"latest_sync_record is:{latest_sync_record}")
|
||
return latest_sync_record
|
||
|
||
logger.info("no latest_sync_record found")
|
||
return None
|
||
|
||
|
||
def get_latest_ddb_instance_monitor_record():
|
||
key_condition_expression = ('endpoint_name = :endpoint_name_val '
|
||
'AND gen_instance_id = :gen_instance_id_val')
|
||
expression_attribute_values = {
|
||
':endpoint_name_val': ENDPOINT_NAME,
|
||
':gen_instance_id_val': GEN_INSTANCE_ID
|
||
}
|
||
instance_monitor_response = instance_monitor_table.query(
|
||
KeyConditionExpression=key_condition_expression,
|
||
ExpressionAttributeValues=expression_attribute_values
|
||
)
|
||
instance_monitor_record = instance_monitor_response['Items'][0] \
|
||
if ('Items' in instance_monitor_response and len(instance_monitor_response['Items']) > 0) else None
|
||
|
||
if instance_monitor_record:
|
||
logger.info(f"instance_monitor_record is {instance_monitor_record}")
|
||
return instance_monitor_record
|
||
|
||
logger.info("no instance_monitor_record found")
|
||
return None
|
||
|
||
|
||
def save_sync_instance_monitor(last_sync_request_id: str, sync_status: str):
|
||
item = {
|
||
'endpoint_id': ENDPOINT_ID,
|
||
'endpoint_name': ENDPOINT_NAME,
|
||
'gen_instance_id': GEN_INSTANCE_ID,
|
||
'sync_status': sync_status,
|
||
'last_sync_request_id': last_sync_request_id,
|
||
'last_sync_time': datetime.datetime.now().isoformat(),
|
||
'sync_list': [],
|
||
'create_time': datetime.datetime.now().isoformat(),
|
||
'last_heartbeat_time': datetime.datetime.now().isoformat()
|
||
}
|
||
save_resp = instance_monitor_table.put_item(Item=item)
|
||
logger.info(f"save instance item {save_resp}")
|
||
return save_resp
|
||
|
||
|
||
def update_sync_instance_monitor(instance_monitor_record):
|
||
# 更新记录
|
||
update_expression = ("SET sync_status = :new_sync_status, last_sync_request_id = :sync_request_id, "
|
||
"sync_list = :sync_list, last_sync_time = :sync_time, last_heartbeat_time = :heartbeat_time")
|
||
expression_attribute_values = {
|
||
":new_sync_status": instance_monitor_record['sync_status'],
|
||
":sync_request_id": instance_monitor_record['last_sync_request_id'],
|
||
":sync_list": instance_monitor_record['sync_list'],
|
||
":sync_time": datetime.datetime.now().isoformat(),
|
||
":heartbeat_time": datetime.datetime.now().isoformat(),
|
||
}
|
||
|
||
response = instance_monitor_table.update_item(
|
||
Key={'endpoint_name': ENDPOINT_NAME,
|
||
'gen_instance_id': GEN_INSTANCE_ID},
|
||
UpdateExpression=update_expression,
|
||
ExpressionAttributeValues=expression_attribute_values
|
||
)
|
||
logger.info(f"update_sync_instance_monitor :{response}")
|
||
return response
|
||
|
||
|
||
def sync_instance_monitor_status(need_save: bool):
|
||
try:
|
||
logger.info(f"sync_instance_monitor_status {datetime.datetime.now()}")
|
||
if need_save:
|
||
save_sync_instance_monitor('', 'init')
|
||
else:
|
||
update_expression = ("SET last_heartbeat_time = :heartbeat_time")
|
||
expression_attribute_values = {
|
||
":heartbeat_time": datetime.datetime.now().isoformat(),
|
||
}
|
||
instance_monitor_table.update_item(
|
||
Key={'endpoint_name': ENDPOINT_NAME,
|
||
'gen_instance_id': GEN_INSTANCE_ID},
|
||
UpdateExpression=update_expression,
|
||
ExpressionAttributeValues=expression_attribute_values
|
||
)
|
||
except Exception as e:
|
||
logger.info(f"sync_instance_monitor_status error :{e}")
|
||
|
||
|
||
@server.PromptServer.instance.routes.post("/reboot")
|
||
async def restart(self):
|
||
logger.debug(f"start to reboot!!!!!!!! {self}")
|
||
global executing
|
||
if executing is True:
|
||
logger.info(f"other inference doing cannot reboot!!!!!!!!")
|
||
return ok({"message": "other inference doing cannot reboot"})
|
||
need_reboot = os.environ.get('NEED_REBOOT')
|
||
if need_reboot and need_reboot.lower() != 'true':
|
||
logger.info("no need to reboot by os")
|
||
return ok({"message": "no need to reboot by os"})
|
||
global reboot
|
||
if reboot is False:
|
||
logger.info("no need to reboot by global constant")
|
||
return ok({"message": "no need to reboot by constant"})
|
||
|
||
logger.debug("rebooting !!!!!!!!")
|
||
try:
|
||
sys.stdout.close_log()
|
||
except Exception as e:
|
||
logger.info(f"error reboot!!!!!!!! {e}")
|
||
pass
|
||
return os.execv(sys.executable, [sys.executable] + sys.argv)
|
||
|
||
|
||
# must be sync invoke and use the env to check
|
||
@server.PromptServer.instance.routes.post("/sync_instance")
|
||
async def sync_instance(request):
|
||
if not BUCKET:
|
||
logger.error("No bucket provided ,wait and try again")
|
||
resp = {"status": "success", "message": "syncing"}
|
||
return ok(resp)
|
||
|
||
if 'ALREADY_SYNC' in os.environ and os.environ.get('ALREADY_SYNC').lower() == 'false':
|
||
resp = {"status": "success", "message": "syncing"}
|
||
logger.error("other process doing ,wait and try again")
|
||
return ok(resp)
|
||
|
||
os.environ['ALREADY_SYNC'] = 'false'
|
||
logger.info(f"sync_instance start !! {datetime.datetime.now().isoformat()} {request}")
|
||
try:
|
||
last_sync_record = get_last_ddb_sync_record()
|
||
if not last_sync_record:
|
||
logger.info("no last sync record found do not need sync")
|
||
sync_instance_monitor_status(True)
|
||
resp = {"status": "success", "message": "no sync"}
|
||
os.environ['ALREADY_SYNC'] = 'true'
|
||
return ok(resp)
|
||
|
||
if ('request_id' in last_sync_record and last_sync_record['request_id']
|
||
and os.environ.get('LAST_SYNC_REQUEST_ID')
|
||
and os.environ.get('LAST_SYNC_REQUEST_ID') == last_sync_record['request_id']
|
||
and os.environ.get('LAST_SYNC_REQUEST_TIME')
|
||
and os.environ.get('LAST_SYNC_REQUEST_TIME') == str(last_sync_record['request_time'])):
|
||
logger.info("last sync record already sync by os check")
|
||
sync_instance_monitor_status(False)
|
||
resp = {"status": "success", "message": "no sync env"}
|
||
os.environ['ALREADY_SYNC'] = 'true'
|
||
return ok(resp)
|
||
|
||
instance_monitor_record = get_latest_ddb_instance_monitor_record()
|
||
if not instance_monitor_record:
|
||
sync_already = await prepare_comfy_env(last_sync_record)
|
||
if sync_already:
|
||
logger.info("should init prepare instance_monitor_record")
|
||
sync_status = 'success' if sync_already else 'failed'
|
||
save_sync_instance_monitor(last_sync_record['request_id'], sync_status)
|
||
else:
|
||
sync_instance_monitor_status(False)
|
||
else:
|
||
if ('last_sync_request_id' in instance_monitor_record
|
||
and instance_monitor_record['last_sync_request_id']
|
||
and instance_monitor_record['last_sync_request_id'] == last_sync_record['request_id']
|
||
and instance_monitor_record['sync_status']
|
||
and instance_monitor_record['sync_status'] == 'success'
|
||
and os.environ.get('LAST_SYNC_REQUEST_TIME')
|
||
and os.environ.get('LAST_SYNC_REQUEST_TIME') == str(last_sync_record['request_time'])):
|
||
logger.info("last sync record already sync")
|
||
sync_instance_monitor_status(False)
|
||
resp = {"status": "success", "message": "no sync ddb"}
|
||
os.environ['ALREADY_SYNC'] = 'true'
|
||
return ok(resp)
|
||
|
||
sync_already = await prepare_comfy_env(last_sync_record)
|
||
instance_monitor_record['sync_status'] = 'success' if sync_already else 'failed'
|
||
instance_monitor_record['last_sync_request_id'] = last_sync_record['request_id']
|
||
sync_list = instance_monitor_record['sync_list'] if ('sync_list' in instance_monitor_record
|
||
and instance_monitor_record['sync_list']) else []
|
||
sync_list.append(last_sync_record['request_id'])
|
||
|
||
instance_monitor_record['sync_list'] = sync_list
|
||
logger.info("should update prepare instance_monitor_record")
|
||
update_sync_instance_monitor(instance_monitor_record)
|
||
os.environ['ALREADY_SYNC'] = 'true'
|
||
resp = {"status": "success", "message": "sync"}
|
||
return ok(resp)
|
||
except Exception as e:
|
||
logger.info("exception occurred", e)
|
||
os.environ['ALREADY_SYNC'] = 'true'
|
||
resp = {"status": "fail", "message": "sync"}
|
||
return error(resp)
|
||
|
||
|
||
def validate_prompt_proxy(func):
|
||
def wrapper(*args, **kwargs):
|
||
logger.info("validate_prompt_proxy start...")
|
||
result = func(*args, **kwargs)
|
||
logger.info("validate_prompt_proxy end...")
|
||
return result
|
||
|
||
return wrapper
|
||
|
||
|
||
execution.validate_prompt = validate_prompt_proxy(execution.validate_prompt)
|
||
|
||
|
||
def send_sync_proxy(func):
|
||
def wrapper(*args, **kwargs):
|
||
logger.debug(f"Sending sync request!!!!!!! {args}")
|
||
global need_sync
|
||
global prompt_id
|
||
logger.info(f"send_sync_proxy start... {need_sync},{prompt_id} {args}")
|
||
func(*args, **kwargs)
|
||
if need_sync and QUEUE_URL and REGION:
|
||
logger.debug(f"send_sync_proxy params... {QUEUE_URL},{REGION},{need_sync},{prompt_id}")
|
||
event = args[1]
|
||
data = args[2]
|
||
sid = args[3] if len(args) == 4 else None
|
||
message_body = {'prompt_id': prompt_id, 'event': event, 'data': data, 'sid': sid}
|
||
message_id = sen_sqs_msg(message_body, prompt_id)
|
||
logger.info(f'send_sync_proxy message_id :{message_id} message_body: {message_body}')
|
||
logger.debug(f"send_sync_proxy end...")
|
||
|
||
return wrapper
|
||
|
||
|
||
server.PromptServer.send_sync = send_sync_proxy(server.PromptServer.send_sync)
|
||
|
||
|
||
def get_save_imge_path_proxy(func):
|
||
def wrapper(*args, **kwargs):
|
||
logger.info(f"get_save_imge_path_proxy args : {args} kwargs : {kwargs}")
|
||
full_output_folder, filename, counter, subfolder, filename_prefix = func(*args, **kwargs)
|
||
global prompt_id
|
||
filename_prefix_new = filename_prefix + "_" + str(prompt_id)
|
||
logger.info(f"get_save_imge_path_proxy filename_prefix new : {filename_prefix_new}")
|
||
return full_output_folder, filename, counter, subfolder, filename_prefix_new
|
||
|
||
return wrapper
|
||
|
||
|
||
folder_paths.get_save_image_path = get_save_imge_path_proxy(folder_paths.get_save_image_path)
|