337 lines
9.7 KiB
Python
Executable File
337 lines
9.7 KiB
Python
Executable File
import asyncio
|
|
import datetime
|
|
import json
|
|
import logging
|
|
import os
|
|
import signal
|
|
import socket
|
|
import subprocess
|
|
import sys
|
|
import threading
|
|
import time
|
|
from typing import List
|
|
|
|
import aiohttp
|
|
import boto3
|
|
import requests
|
|
import uvicorn
|
|
from fastapi import FastAPI, Request
|
|
from fastapi import Response, status
|
|
|
|
sagemaker = boto3.client('sagemaker')
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger("Controller")
|
|
logger.setLevel(logging.INFO)
|
|
app = FastAPI()
|
|
SLEEP_TIME = 30
|
|
service_type = os.getenv('SERVICE_TYPE', 'sd')
|
|
endpoint_name = os.getenv('ENDPOINT_NAME')
|
|
sagemaker_safe_port_range = os.getenv('SAGEMAKER_SAFE_PORT_RANGE')
|
|
start_port = int(sagemaker_safe_port_range.split('-')[0])
|
|
should_exit = 0
|
|
|
|
|
|
class App:
|
|
def __init__(self, device_id):
|
|
self.host = "127.0.0.1"
|
|
self.device_id = device_id
|
|
self.port = start_port + device_id
|
|
self.name = f"{service_type}-gpu{device_id}"
|
|
self.process = None
|
|
self.busy = False
|
|
self.stdout_thread = None
|
|
self.stderr_thread = None
|
|
self.cmd = None
|
|
self.cwd = None
|
|
|
|
def start(self):
|
|
self.cwd = '/home/ubuntu/stable-diffusion-webui'
|
|
self.cmd = [
|
|
"python", "launch.py",
|
|
"--listen",
|
|
"--port", str(self.port),
|
|
"--device-id", str(self.device_id),
|
|
"--enable-insecure-extension-access",
|
|
"--api",
|
|
"--api-log",
|
|
"--log-startup",
|
|
"--xformers",
|
|
"--no-half-vae",
|
|
"--no-download-sd-model",
|
|
"--no-hashing",
|
|
"--nowebui",
|
|
"--skip-torch-cuda-test",
|
|
"--skip-load-model-at-start",
|
|
"--disable-safe-unpickle",
|
|
"--skip-prepare-environment",
|
|
"--skip-python-version-check",
|
|
"--skip-install",
|
|
"--skip-version-check",
|
|
"--disable-nan-check",
|
|
]
|
|
|
|
if service_type == 'comfy':
|
|
self.cwd = '/home/ubuntu/ComfyUI'
|
|
self.cmd = [
|
|
"python", "main.py",
|
|
"--listen", self.host,
|
|
"--port", str(self.port),
|
|
"--output-directory", f"/home/ubuntu/ComfyUI/output/{self.device_id}/",
|
|
"--temp-directory", f"/home/ubuntu/ComfyUI/temp/{self.device_id}/",
|
|
"--cuda-device", str(self.device_id),
|
|
]
|
|
|
|
logger.info("Launching app on device %s, port: %s, command: %s", self.device_id, self.port, self.cmd)
|
|
|
|
self.process = subprocess.Popen(
|
|
self.cmd,
|
|
cwd=self.cwd,
|
|
stdout=subprocess.PIPE,
|
|
stderr=subprocess.PIPE,
|
|
text=True
|
|
)
|
|
|
|
self.stdout_thread = threading.Thread(target=self._handle_output, args=(self.process.stdout, "STDOUT"))
|
|
self.stderr_thread = threading.Thread(target=self._handle_output, args=(self.process.stderr, "STDERR"))
|
|
|
|
self.stdout_thread.start()
|
|
self.stderr_thread.start()
|
|
|
|
def _handle_output(self, pipe, _):
|
|
with pipe:
|
|
for line in iter(pipe.readline, ''):
|
|
if line.strip():
|
|
sys.stdout.write(f"{self.name}: {line}")
|
|
|
|
def stop(self):
|
|
if self.process:
|
|
self.process.terminate()
|
|
self.process.wait()
|
|
self.stdout_thread.join()
|
|
self.stderr_thread.join()
|
|
|
|
def __del__(self):
|
|
self.stop()
|
|
|
|
def restart(self):
|
|
logger.info("app process is going to restart")
|
|
self.stop()
|
|
self.start()
|
|
|
|
def is_port_ready(self):
|
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
|
sock.settimeout(1)
|
|
result = sock.connect_ex(('127.0.0.1', self.port))
|
|
return result == 0
|
|
|
|
async def invocations(self, payload, infer_id=None):
|
|
|
|
try:
|
|
self.busy = True
|
|
self.name = f"{service_type}-gpu{self.device_id}-{infer_id}"
|
|
|
|
payload['port'] = self.port
|
|
payload['out_path'] = self.device_id
|
|
|
|
url = f"http://127.0.0.1:{self.port}/invocations"
|
|
timeout = aiohttp.ClientTimeout(total=300)
|
|
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.post(url, json=payload, timeout=timeout) as response:
|
|
if response.status != 200:
|
|
result = json.dumps({
|
|
"status_code": response.status,
|
|
"detail": f"service returned an error: {await response.text()}"
|
|
})
|
|
self.busy = False
|
|
return result
|
|
response_data = await response.json()
|
|
|
|
self.busy = False
|
|
self.name = f"{service_type}-gpu{self.device_id}"
|
|
|
|
return response_data
|
|
except Exception as e:
|
|
self.busy = False
|
|
logger.error(f"invocations error:{e}")
|
|
return json.dumps({
|
|
"status_code": 500,
|
|
"detail": f"service returned an error: {str(e)}"
|
|
})
|
|
|
|
|
|
apps: List[App] = []
|
|
|
|
|
|
def get_gpu_count():
|
|
try:
|
|
result = subprocess.run(['nvidia-smi', '-L'], capture_output=True, text=True, check=True)
|
|
gpu_count = result.stdout.count('\n')
|
|
return gpu_count
|
|
except subprocess.CalledProcessError as e:
|
|
print("Failed to run nvidia-smi:", e)
|
|
return 0
|
|
except Exception as e:
|
|
print("An error occurred:", e)
|
|
return 0
|
|
|
|
|
|
def signal_handler(signum, frame):
|
|
logger.info(f"Received signal {signum} ({signal.strsignal(signum)})")
|
|
if signum in [signal.SIGINT, signal.SIGTERM]:
|
|
global should_exit
|
|
should_exit = 1
|
|
sys.exit(0)
|
|
|
|
|
|
def setup_signal_handlers():
|
|
catchable_sigs = set(signal.Signals) - {signal.SIGKILL, signal.SIGSTOP}
|
|
for sig in catchable_sigs:
|
|
try:
|
|
signal.signal(sig, signal_handler)
|
|
except Exception as exc:
|
|
logger.info(f"Signal {sig} cannot be caught")
|
|
|
|
|
|
def get_poll_app():
|
|
for sd_app in apps:
|
|
if sd_app.process and sd_app.process.poll() is None:
|
|
return sd_app
|
|
return None
|
|
|
|
|
|
def get_all_available_apps():
|
|
list: List[App] = []
|
|
for app in apps:
|
|
if app.is_port_ready() and not app.busy:
|
|
list.append(app)
|
|
|
|
return list
|
|
|
|
|
|
def get_available_app():
|
|
apps = get_all_available_apps()
|
|
|
|
if apps:
|
|
return apps[0]
|
|
|
|
return None
|
|
|
|
|
|
def start_apps(nums: int):
|
|
logger.info(f"GPU count: {nums}")
|
|
logger.info(f"Safe start port: {start_port}")
|
|
for device_id in range(nums):
|
|
sd_app = App(device_id)
|
|
sd_app.start()
|
|
apps.append(sd_app)
|
|
|
|
|
|
def check_sync():
|
|
logger.info("start check_sync!")
|
|
while True:
|
|
try:
|
|
app = get_available_app()
|
|
if app:
|
|
logger.info("start check_sync! checking function-------")
|
|
response = requests.post(f"http://127.0.0.1:{app.port}/sync_instance")
|
|
logger.info(f"sync response:{response.json()} time : {datetime.datetime.now()}")
|
|
|
|
logger.info("start check_reboot! checking function-------")
|
|
response2 = requests.post(f"http://127.0.0.1:{app.port}/reboot")
|
|
logger.info(f"reboot response:{response.json()} time : {datetime.datetime.now()}")
|
|
time.sleep(SLEEP_TIME)
|
|
except Exception as e:
|
|
logger.info(f"check_sync error:{e}")
|
|
time.sleep(SLEEP_TIME)
|
|
|
|
|
|
def check_apps():
|
|
logger.info("start check apps!")
|
|
while True:
|
|
time.sleep(SLEEP_TIME)
|
|
if should_exit:
|
|
return
|
|
try:
|
|
logger.info(f"all_apps: {len(apps)} all_available_apps: {len(get_all_available_apps())}")
|
|
except Exception as e:
|
|
logger.info(f"check_and_reboot error:{e}")
|
|
time.sleep(SLEEP_TIME)
|
|
|
|
|
|
@app.get("/ping")
|
|
async def ping():
|
|
global should_exit
|
|
if should_exit:
|
|
await asyncio.sleep(1800)
|
|
return Response(content="pong", status_code=status.HTTP_502_BAD_GATEWAY)
|
|
return {"message": "pong"}
|
|
|
|
|
|
@app.post("/invocations")
|
|
async def invocations(request: Request):
|
|
payload = await request.json()
|
|
|
|
if service_type == 'sd':
|
|
infer_id = payload['id']
|
|
else:
|
|
infer_id = payload['prompt_id']
|
|
|
|
logger.info(f"controller_invocation {infer_id} received")
|
|
|
|
while True:
|
|
app = get_available_app()
|
|
if app:
|
|
return await app.invocations(payload=payload, infer_id=infer_id)
|
|
else:
|
|
await asyncio.sleep(1)
|
|
logger.info(f'controller_invocation {infer_id} waiting for an available app...')
|
|
|
|
|
|
def stop():
|
|
global should_exit
|
|
should_exit = 1
|
|
|
|
logger.info("stopping...")
|
|
for cur_app in apps:
|
|
cur_app.stop()
|
|
|
|
|
|
def check_endpoint():
|
|
while True:
|
|
time.sleep(10)
|
|
if should_exit:
|
|
return
|
|
try:
|
|
sagemaker.describe_endpoint(EndpointName=endpoint_name)
|
|
except Exception as e:
|
|
if 'Could not find endpoint' in str(e):
|
|
logger.info(f"Endpoint {endpoint_name} not found, stopping...")
|
|
stop()
|
|
|
|
|
|
def run_server():
|
|
uvicorn.run(app, host="0.0.0.0", port=8080, log_level="info")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
setup_signal_handlers()
|
|
|
|
server = threading.Thread(target=run_server)
|
|
server.start()
|
|
|
|
gpu_nums = get_gpu_count()
|
|
start_apps(gpu_nums)
|
|
|
|
check_apps_thread = threading.Thread(target=check_apps, daemon=True)
|
|
check_apps_thread.start()
|
|
|
|
check_endpoint_thread = threading.Thread(target=check_endpoint, daemon=True)
|
|
check_endpoint_thread.start()
|
|
|
|
if service_type == 'comfy':
|
|
queue_lock = threading.Lock()
|
|
check_sync_thread = threading.Thread(target=check_sync, daemon=True)
|
|
check_sync_thread.start()
|