464 lines
16 KiB
Python
464 lines
16 KiB
Python
import asyncio
|
|
import datetime
|
|
import logging
|
|
import os
|
|
import socket
|
|
import subprocess
|
|
import sys
|
|
import threading
|
|
import time
|
|
from multiprocessing import Process
|
|
from threading import Lock
|
|
|
|
import boto3
|
|
import httpx
|
|
import requests
|
|
import uvicorn
|
|
from fastapi import APIRouter, FastAPI, Request, HTTPException
|
|
|
|
# --- Server / networking configuration ---
TIMEOUT_KEEP_ALIVE = 30  # uvicorn keep-alive timeout (seconds)
SAGEMAKER_PORT = 8080  # port SageMaker routes /invocations and /ping to
LOCALHOST = '0.0.0.0'  # bind address for servers (all interfaces, despite the name)
PHY_LOCALHOST = '127.0.0.1'  # loopback address used for calls to local ComfyUI servers

# --- Polling intervals / HTTP client limits ---
SLEEP_TIME = 60  # seconds between background sync/reboot passes
TIME_OUT_TIME = 86400  # 24h: httpx timeout and max wait budget for a free app
MAX_KEEPALIVE_CONNECTIONS = 100  # httpx keep-alive pool size
MAX_CONNECTIONS = 1500  # httpx total connection pool size

app = FastAPI()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Port range supplied by the SageMaker container contract, e.g. "8081-8090";
# each GPU's ComfyUI server is bound to start_port + device_id.
# NOTE(review): raises AttributeError at import time if
# SAGEMAKER_SAFE_PORT_RANGE is unset — presumably guaranteed by the
# deployment environment; confirm before reusing this module elsewhere.
sagemaker_safe_port_range = os.getenv('SAGEMAKER_SAFE_PORT_RANGE')
start_port = int(sagemaker_safe_port_range.split('-')[0])

# One ComfyApp per GPU, populated by start_comfy_servers().
available_apps = []
is_multi_gpu = False
cloudwatch = boto3.client('cloudwatch')

endpoint_name = os.getenv('ENDPOINT_NAME')
endpoint_instance_id = os.getenv('ENDPOINT_INSTANCE_ID', 'default')
app_cwd = os.getenv('APP_CWD', '/home/ubuntu/ComfyUI')

# DynamoDB table tracking per-prompt execution state.
ddb_client = boto3.resource('dynamodb')
inference_table = ddb_client.Table('ComfyExecuteTable')
|
|
|
|
|
|
class Api:
    """Wires the SageMaker-required routes onto a FastAPI application."""

    def __init__(self, app: FastAPI, queue_lock: Lock):
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        # SageMaker inference containers must expose POST /invocations
        # and GET /ping.
        self.add_api_route("/invocations", invocations, methods=["POST"])
        self.add_api_route("/ping", ping, methods=["GET"], response_model={})

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register *endpoint* under *path* on the wrapped app."""
        return self.app.add_api_route(path, endpoint, **kwargs)

    def launch(self, server_name, port):
        """Attach the router and serve the app with uvicorn (blocks forever)."""
        self.app.include_router(self.router)
        uvicorn.run(
            self.app,
            host=server_name,
            port=port,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        )
|
|
|
|
|
|
class ComfyApp:
    """Manages one ComfyUI server subprocess pinned to a single CUDA device.

    Tracks the subprocess, its busy flag, and the threads that forward its
    stdout/stderr to this process's stdout, tagged with the app name and the
    prompt id currently recorded in /tmp/gpu<device_id>.
    """

    def __init__(self, host, port, device_id):
        self.host = host
        self.port = port
        self.device_id = device_id
        self.process = None
        self.busy = False
        self.cwd = app_cwd
        self.name = f"{endpoint_instance_id}-gpus-{device_id}"
        self.stdout_thread = None
        self.stderr_thread = None

    def _current_prompt_id(self):
        """Return the prompt id recorded for this GPU in /tmp, or '' if none."""
        marker = f"/tmp/gpu{self.device_id}"
        if not os.path.exists(marker):
            return ""
        # Fix: the original shadowed the path variable with the file handle
        # (`with open(file) as file`), which obscured intent.
        with open(marker, "r") as f:
            return f.read().strip()

    def _handle_output(self, pipe, _):
        """Forward subprocess output lines to stdout with a name/prompt prefix.

        Blank lines are dropped. Runs until *pipe* reaches EOF, then closes it.
        """
        with pipe:
            for line in iter(pipe.readline, ''):
                if not line.strip():
                    continue
                cur_prompt_id = self._current_prompt_id()
                if cur_prompt_id:
                    sys.stdout.write(f"{self.name}-prompt-{cur_prompt_id}: {line}")
                else:
                    sys.stdout.write(f"{self.name}: {line}")

    def start(self):
        """Launch the ComfyUI subprocess and the threads pumping its output."""
        cmd = [
            "python", "main.py",
            "--listen", self.host,
            "--port", str(self.port),
            "--output-directory", f"{self.cwd}/output/{self.device_id}/",
            "--temp-directory", f"{self.cwd}/temp/{self.device_id}/",
            "--cuda-device", str(self.device_id),
            "--cuda-malloc",
        ]

        logger.info(f"Starting comfy app on {self.port}")
        logger.info(f"Command: {cmd}")

        self.process = subprocess.Popen(
            cmd,
            cwd=self.cwd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        # Signals /ping that at least one backend has been initialized.
        os.environ['ALREADY_INIT'] = 'true'

        self.stdout_thread = threading.Thread(target=self._handle_output, args=(self.process.stdout, "STDOUT"))
        self.stderr_thread = threading.Thread(target=self._handle_output, args=(self.process.stderr, "STDERR"))

        self.stdout_thread.start()
        self.stderr_thread.start()

    def restart(self):
        """Terminate a live subprocess, join its output threads, and relaunch."""
        logger.info("Comfy app process is going to restart")
        if self.process and self.process.poll() is None:
            # Mark unhealthy so /ping fails while the backend is down.
            os.environ['ALREADY_INIT'] = 'false'
            self.process.terminate()
            self.process.wait()
            self.stdout_thread.join()
            self.stderr_thread.join()
        self.start()

    def is_port_ready(self):
        """Return True if the ComfyUI server accepts TCP connections on its port."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1)
            # Use the shared loopback constant for consistency with the HTTP
            # calls elsewhere in this module (was hardcoded '127.0.0.1').
            return sock.connect_ex((PHY_LOCALHOST, self.port)) == 0

    def set_prompt(self, request_obj=None):
        """Record (or clear, when *request_obj* is None) this GPU's prompt id."""
        if request_obj and 'prompt_id' in request_obj:
            prompt_id = request_obj['prompt_id']
        else:
            prompt_id = ""

        logger.debug(f"set_prompt '{prompt_id}' on device {self.device_id}")
        with open(f"/tmp/gpu{self.device_id}", "w") as f:
            f.write(str(prompt_id))
|
|
|
|
|
|
def update_execute_job_table(prompt_id, key, value):
    """Set attribute *key* to *value* on the ComfyExecuteTable item for *prompt_id*.

    The ConditionExpression makes the update fail if the item does not exist;
    any botocore error is logged and re-raised to the caller.
    """
    logger.debug(f"Update job with prompt_id: {prompt_id}, key: {key}, value: {value}")
    try:
        inference_table.update_item(
            Key={
                "prompt_id": prompt_id,
            },
            # '#k' aliases the attribute name so reserved words work as keys.
            UpdateExpression="set #k = :r",
            ExpressionAttributeNames={'#k': key},
            ExpressionAttributeValues={':r': value},
            ConditionExpression="attribute_exists(prompt_id)",
            ReturnValues="UPDATED_NEW"
        )
    except Exception as e:
        logger.error(f"Update execute job table error: {e}")
        # Bare raise preserves the original traceback (was `raise e`).
        raise
|
|
|
|
|
|
async def send_request(request_obj, comfy_app: ComfyApp, need_async: bool):
    """Proxy one inference request to a ComfyUI backend and return the wrapped result.

    Marks *comfy_app* busy and records the prompt id for the duration; both are
    always cleared in ``finally``.  *need_async* is retained for interface
    compatibility — all requests now go through the async httpx client.

    Raises:
        HTTPException: with the backend's own status code on a non-200
            response, or 500 on any other failure.
    """
    try:
        record_metric(comfy_app, request_obj)
        logger.info(f"Starting on {comfy_app.port} {need_async} {request_obj}")

        comfy_app.busy = True
        comfy_app.set_prompt(request_obj)

        # Tell the backend which port/output slot this request belongs to.
        request_obj['port'] = comfy_app.port
        request_obj['out_path'] = comfy_app.device_id

        start_time = datetime.datetime.now().isoformat()
        update_execute_job_table(prompt_id=request_obj['prompt_id'], key="start_time", value=start_time)

        logger.info(f"Invocations start req: {request_obj}, url: {PHY_LOCALHOST}:{comfy_app.port}/execute_proxy")
        async with httpx.AsyncClient(timeout=TIME_OUT_TIME,
                                     limits=httpx.Limits(max_keepalive_connections=MAX_KEEPALIVE_CONNECTIONS,
                                                         max_connections=MAX_CONNECTIONS)) as client:
            response = await client.post(f"http://{PHY_LOCALHOST}:{comfy_app.port}/execute_proxy", json=request_obj)

        if response.status_code != 200:
            raise HTTPException(status_code=response.status_code,
                                detail=f"COMFY service returned an error: {response.text}")
        return wrap_response(start_time, response, comfy_app, request_obj)
    except HTTPException:
        # Fix: propagate backend errors with their real status code instead of
        # letting the generic handler below collapse them into a 500.
        raise
    except Exception as e:
        logger.error(f"send_request error {e}")
        raise HTTPException(status_code=500, detail=f"COMFY service not available for internal multi reqs {e}")
    finally:
        # Always release the app, on success and on every error path.
        comfy_app.busy = False
        comfy_app.set_prompt()
|
|
|
|
|
|
async def invocations(request: Request):
    """SageMaker /invocations handler for a JSON list of request objects.

    Multi-GPU: fans requests out concurrently, one backend per request,
    waiting (in 60s steps, up to TIME_OUT_TIME) for a free backend.
    Single-GPU: processes requests sequentially.

    Returns a list of per-request results; on any error an empty list is
    returned (best-effort contract with the SageMaker front end).
    """
    global is_multi_gpu
    try:
        req = await request.json()
        if is_multi_gpu:
            gpu_nums = get_gpu_count()
            logger.info(f"Number of GPUs: {gpu_nums}")
            logger.info(f"Starting multi invocation {req}")

            tasks = []
            for request_obj in req:
                comfy_app = check_available_app(True)
                max_retries = TIME_OUT_TIME
                while comfy_app is None:
                    if max_retries > 0:
                        max_retries = max_retries - 60
                        # Fix: time.sleep() inside a coroutine blocked the
                        # event loop, freezing every in-flight request.
                        await asyncio.sleep(60)
                        comfy_app = check_available_app(True)
                    else:
                        raise HTTPException(status_code=500, detail="COMFY service not available for multi reqs")
                tasks.append(send_request(request_obj, comfy_app, True))
            logger.info("all tasks completed send, waiting result")
            results = await asyncio.gather(*tasks)
            logger.info(f'Finished invocations {results}')
            return results
        else:
            result = []
            logger.info(f"Starting single invocation request is: {req}")
            for request_obj in req:
                comfy_app = check_available_app(True)
                if comfy_app is None:
                    raise HTTPException(status_code=500, detail="COMFY service not available for single reqs")
                response = await send_request(request_obj, comfy_app, False)
                result.append(response)
            logger.info(f"Finished invocations result: {result}")
            return result
    except Exception as e:
        logger.error(f"invocations error of {e}")
        return []
|
|
|
|
|
|
def ping():
    """SageMaker health check.

    Healthy only when the container is initialized, at least one backend is
    reachable, and that backend's /queue endpoint answers 200; otherwise 500.
    """
    flag = os.environ.get('ALREADY_INIT')
    if flag and flag.lower() == 'false':
        raise HTTPException(status_code=500)

    comfy_app = check_available_app(False)
    if comfy_app is None:
        raise HTTPException(status_code=500)

    logger.debug(f"check status start url:{PHY_LOCALHOST}:{comfy_app.port}/queue")
    queue_url = f"http://{PHY_LOCALHOST}:{comfy_app.port}/queue"
    if requests.get(queue_url).status_code != 200:
        raise HTTPException(status_code=500)

    return {'status': 'Healthy'}
|
|
|
|
|
|
def wrap_response(start_time, response, comfy_app: ComfyApp, request_obj):
    """Augment the backend's JSON payload with timing and endpoint metadata."""
    data = response.json()
    data.update(
        start_time=start_time,
        endpoint_name=os.getenv('ENDPOINT_NAME'),
        endpoint_instance_id=os.getenv('ENDPOINT_INSTANCE_ID'),
        device_id=comfy_app.device_id,
    )

    # Echo the workflow name back only when the caller supplied a truthy one.
    if request_obj.get('workflow'):
        data['workflow'] = request_obj['workflow']

    return data
|
|
|
|
|
|
def record_metric(comfy_app: ComfyApp, request_obj):
    """Publish per-request CloudWatch count metrics under the 'ESD' namespace.

    Emits InferenceTotal (endpoint-level and per-GPU) and
    InferenceEndpointReceived (service-level, endpoint-level, and — when the
    request names one — per-workflow), each as a single Count of 1.
    """
    # One timestamp for the whole batch keeps datapoints aligned; the
    # original called utcnow() separately for every dict.
    # NOTE(review): utcnow() is naive and deprecated in newer Pythons; kept
    # here to preserve the existing CloudWatch timestamp behavior.
    timestamp = datetime.datetime.utcnow()

    def _count_metric(name, dimensions):
        # Every metric here is a unit count sharing the batch timestamp.
        return {
            'MetricName': name,
            'Dimensions': dimensions,
            'Timestamp': timestamp,
            'Value': 1,
            'Unit': 'Count'
        }

    data = [
        _count_metric('InferenceTotal', [
            {'Name': 'Endpoint', 'Value': endpoint_name},
            {'Name': 'Instance', 'Value': endpoint_instance_id},
        ]),
        _count_metric('InferenceTotal', [
            {'Name': 'Endpoint', 'Value': endpoint_name},
            {'Name': 'Instance', 'Value': endpoint_instance_id},
            {'Name': 'InstanceGPU', 'Value': f"GPU{comfy_app.device_id}"},
        ]),
        _count_metric('InferenceEndpointReceived', [
            {'Name': 'Service', 'Value': 'Comfy'},
        ]),
        _count_metric('InferenceEndpointReceived', [
            {'Name': 'Endpoint', 'Value': endpoint_name},
        ]),
    ]

    if 'workflow' in request_obj and request_obj['workflow']:
        data.append(_count_metric('InferenceEndpointReceived', [
            {'Name': 'Workflow', 'Value': request_obj['workflow']},
        ]))

    response = cloudwatch.put_metric_data(
        Namespace='ESD',
        MetricData=data
    )
    logger.info(f"record_metric response: {response}")
|
|
|
|
|
|
def get_gpu_count():
    """Return the number of GPUs reported by ``nvidia-smi -L``, or 0 on failure."""
    try:
        result = subprocess.run(['nvidia-smi', '-L'], capture_output=True, text=True, check=True)
        # Count non-empty lines instead of raw '\n' characters so a missing
        # trailing newline cannot drop the last GPU from the count.
        return len([line for line in result.stdout.splitlines() if line.strip()])
    except subprocess.CalledProcessError as e:
        # Fix: the original passed *e* as a stray positional argument with no
        # %s placeholder, which the logging module cannot format.
        logger.info("Failed to run nvidia-smi: %s", e)
        return 0
    except Exception as e:
        logger.info("An error occurred: %s", e)
        return 0
|
|
|
|
|
|
def start_comfy_servers():
    """Launch one ComfyUI backend per detected GPU on consecutive ports."""
    global is_multi_gpu
    gpu_nums = get_gpu_count()
    # More than one GPU switches /invocations into fan-out mode.
    is_multi_gpu = gpu_nums > 1
    logger.info(f"is_multi_gpu is {is_multi_gpu}")
    for gpu_num in range(gpu_nums):
        logger.info(f"start comfy server by device_id: {gpu_num}")
        comfy_app = ComfyApp(host=LOCALHOST, port=start_port + gpu_num, device_id=gpu_num)
        comfy_app.start()
        available_apps.append(comfy_app)
|
|
|
|
|
|
def get_available_app(need_check_busy: bool):
    """Return the first reachable backend, optionally requiring it to be idle.

    When *need_check_busy* is true, the returned app is marked busy before it
    is handed back; returns None when no backend qualifies.
    """
    global available_apps
    if available_apps is None:
        return None
    for candidate in available_apps:
        logger.debug(f"get available apps {candidate.device_id} {candidate.busy}")
        if not candidate.is_port_ready():
            continue
        if not need_check_busy:
            return candidate
        if not candidate.busy:
            candidate.busy = True
            return candidate
    return None
|
|
|
|
|
|
def check_available_app(need_check_busy: bool):
    """Poll get_available_app, retrying up to 3 times one second apart.

    Returns the first app found, or None after the retries are exhausted.
    """
    attempts = 0
    comfy_app = get_available_app(need_check_busy)
    while comfy_app is None:
        comfy_app = get_available_app(need_check_busy)
        if comfy_app is not None:
            break
        time.sleep(1)
        attempts += 1
        if attempts >= 3:
            logger.info(f"There is no available comfy_app for {attempts} attempts.")
            break
    if comfy_app is None:
        logger.info(f"There is no available comfy_app! Ignoring this request")
        return None
    return comfy_app
|
|
|
|
|
|
def check_sync():
    """Background watchdog: every SLEEP_TIME seconds, ask one healthy backend
    to sync, then POST /reboot to every idle backend.  Runs forever in a
    daemonless thread; all errors are logged and the loop continues.
    """
    logger.debug("start check_sync!")
    while True:
        try:
            # Any reachable (not necessarily idle) app can serve the sync call.
            comfy_app = check_available_app(False)
            if comfy_app is None:
                raise HTTPException(status_code=500,
                                    detail=f"COMFY service returned an error: no avaliable app")
            logger.info("start check_sync! checking function-------")
            response = requests.post(f"http://{PHY_LOCALHOST}:{comfy_app.port}/sync_instance")
            logger.info(f"sync response:{response.json()} time : {datetime.datetime.now()}")

            global available_apps
            # Reboot only idle apps so in-flight requests are not interrupted.
            for item in available_apps:
                if item and item.port and not item.busy:
                    logger.info(f"start check_reboot! {item.port}")
                    requests.post(f"http://{PHY_LOCALHOST}:{item.port}/reboot")
                    logger.debug(f"reboot response time : {datetime.datetime.now()}")
                else:
                    logger.info(f"not start check_reboot! {item}")
            time.sleep(SLEEP_TIME)
        except Exception as e:
            # Best-effort loop: log and retry after the same delay.
            logger.info(f"check_and_reboot error:{e}")
            time.sleep(SLEEP_TIME)
|
|
|
|
|
|
if __name__ == "__main__":
    # Lock handed to Api; the visible handlers do not use it themselves.
    queue_lock = threading.Lock()
    api = Api(app, queue_lock)
    # Launch one ComfyUI subprocess per GPU before accepting traffic.
    start_comfy_servers()

    # The FastAPI front end runs in a separate process; the sync/reboot
    # watchdog runs as a (non-daemon) thread in this process, so this
    # process stays alive even after the API process exits.
    api_process = Process(target=api.launch, args=(LOCALHOST, SAGEMAKER_PORT))

    check_sync_thread = threading.Thread(target=check_sync)

    api_process.start()
    check_sync_thread.start()

    # Block until the API server process terminates.
    api_process.join()
|