stable-diffusion-aws-extension/build_scripts/inference/controller.py

337 lines
9.7 KiB
Python
Executable File

import asyncio
import datetime
import json
import logging
import os
import signal
import socket
import subprocess
import sys
import threading
import time
from typing import List
import aiohttp
import boto3
import requests
import uvicorn
from fastapi import FastAPI, Request
from fastapi import Response, status
sagemaker = boto3.client('sagemaker')
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("Controller")
logger.setLevel(logging.INFO)
app = FastAPI()
SLEEP_TIME = 30
service_type = os.getenv('SERVICE_TYPE', 'sd')
endpoint_name = os.getenv('ENDPOINT_NAME')
sagemaker_safe_port_range = os.getenv('SAGEMAKER_SAFE_PORT_RANGE')
start_port = int(sagemaker_safe_port_range.split('-')[0])
should_exit = 0
class App:
def __init__(self, device_id):
self.host = "127.0.0.1"
self.device_id = device_id
self.port = start_port + device_id
self.name = f"{service_type}-gpu{device_id}"
self.process = None
self.busy = False
self.stdout_thread = None
self.stderr_thread = None
self.cmd = None
self.cwd = None
def start(self):
self.cwd = '/home/ubuntu/stable-diffusion-webui'
self.cmd = [
"python", "launch.py",
"--listen",
"--port", str(self.port),
"--device-id", str(self.device_id),
"--enable-insecure-extension-access",
"--api",
"--api-log",
"--log-startup",
"--xformers",
"--no-half-vae",
"--no-download-sd-model",
"--no-hashing",
"--nowebui",
"--skip-torch-cuda-test",
"--skip-load-model-at-start",
"--disable-safe-unpickle",
"--skip-prepare-environment",
"--skip-python-version-check",
"--skip-install",
"--skip-version-check",
"--disable-nan-check",
]
if service_type == 'comfy':
self.cwd = '/home/ubuntu/ComfyUI'
self.cmd = [
"python", "main.py",
"--listen", self.host,
"--port", str(self.port),
"--output-directory", f"/home/ubuntu/ComfyUI/output/{self.device_id}/",
"--temp-directory", f"/home/ubuntu/ComfyUI/temp/{self.device_id}/",
"--cuda-device", str(self.device_id),
]
logger.info("Launching app on device %s, port: %s, command: %s", self.device_id, self.port, self.cmd)
self.process = subprocess.Popen(
self.cmd,
cwd=self.cwd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True
)
self.stdout_thread = threading.Thread(target=self._handle_output, args=(self.process.stdout, "STDOUT"))
self.stderr_thread = threading.Thread(target=self._handle_output, args=(self.process.stderr, "STDERR"))
self.stdout_thread.start()
self.stderr_thread.start()
def _handle_output(self, pipe, _):
with pipe:
for line in iter(pipe.readline, ''):
if line.strip():
sys.stdout.write(f"{self.name}: {line}")
def stop(self):
if self.process:
self.process.terminate()
self.process.wait()
self.stdout_thread.join()
self.stderr_thread.join()
def __del__(self):
self.stop()
def restart(self):
logger.info("app process is going to restart")
self.stop()
self.start()
def is_port_ready(self):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1)
result = sock.connect_ex(('127.0.0.1', self.port))
return result == 0
async def invocations(self, payload, infer_id=None):
try:
self.busy = True
self.name = f"{service_type}-gpu{self.device_id}-{infer_id}"
payload['port'] = self.port
payload['out_path'] = self.device_id
url = f"http://127.0.0.1:{self.port}/invocations"
timeout = aiohttp.ClientTimeout(total=300)
async with aiohttp.ClientSession() as session:
async with session.post(url, json=payload, timeout=timeout) as response:
if response.status != 200:
result = json.dumps({
"status_code": response.status,
"detail": f"service returned an error: {await response.text()}"
})
self.busy = False
return result
response_data = await response.json()
self.busy = False
self.name = f"{service_type}-gpu{self.device_id}"
return response_data
except Exception as e:
self.busy = False
logger.error(f"invocations error:{e}")
return json.dumps({
"status_code": 500,
"detail": f"service returned an error: {str(e)}"
})
apps: List[App] = []
def get_gpu_count():
try:
result = subprocess.run(['nvidia-smi', '-L'], capture_output=True, text=True, check=True)
gpu_count = result.stdout.count('\n')
return gpu_count
except subprocess.CalledProcessError as e:
print("Failed to run nvidia-smi:", e)
return 0
except Exception as e:
print("An error occurred:", e)
return 0
def signal_handler(signum, frame):
logger.info(f"Received signal {signum} ({signal.strsignal(signum)})")
if signum in [signal.SIGINT, signal.SIGTERM]:
global should_exit
should_exit = 1
sys.exit(0)
def setup_signal_handlers():
catchable_sigs = set(signal.Signals) - {signal.SIGKILL, signal.SIGSTOP}
for sig in catchable_sigs:
try:
signal.signal(sig, signal_handler)
except Exception as exc:
logger.info(f"Signal {sig} cannot be caught")
def get_poll_app():
for sd_app in apps:
if sd_app.process and sd_app.process.poll() is None:
return sd_app
return None
def get_all_available_apps():
list: List[App] = []
for app in apps:
if app.is_port_ready() and not app.busy:
list.append(app)
return list
def get_available_app():
apps = get_all_available_apps()
if apps:
return apps[0]
return None
def start_apps(nums: int):
logger.info(f"GPU count: {nums}")
logger.info(f"Safe start port: {start_port}")
for device_id in range(nums):
sd_app = App(device_id)
sd_app.start()
apps.append(sd_app)
def check_sync():
logger.info("start check_sync!")
while True:
try:
app = get_available_app()
if app:
logger.info("start check_sync! checking function-------")
response = requests.post(f"http://127.0.0.1:{app.port}/sync_instance")
logger.info(f"sync response:{response.json()} time : {datetime.datetime.now()}")
logger.info("start check_reboot! checking function-------")
response2 = requests.post(f"http://127.0.0.1:{app.port}/reboot")
logger.info(f"reboot response:{response.json()} time : {datetime.datetime.now()}")
time.sleep(SLEEP_TIME)
except Exception as e:
logger.info(f"check_sync error:{e}")
time.sleep(SLEEP_TIME)
def check_apps():
logger.info("start check apps!")
while True:
time.sleep(SLEEP_TIME)
if should_exit:
return
try:
logger.info(f"all_apps: {len(apps)} all_available_apps: {len(get_all_available_apps())}")
except Exception as e:
logger.info(f"check_and_reboot error:{e}")
time.sleep(SLEEP_TIME)
@app.get("/ping")
async def ping():
global should_exit
if should_exit:
await asyncio.sleep(1800)
return Response(content="pong", status_code=status.HTTP_502_BAD_GATEWAY)
return {"message": "pong"}
@app.post("/invocations")
async def invocations(request: Request):
payload = await request.json()
if service_type == 'sd':
infer_id = payload['id']
else:
infer_id = payload['prompt_id']
logger.info(f"controller_invocation {infer_id} received")
while True:
app = get_available_app()
if app:
return await app.invocations(payload=payload, infer_id=infer_id)
else:
await asyncio.sleep(1)
logger.info(f'controller_invocation {infer_id} waiting for an available app...')
def stop():
global should_exit
should_exit = 1
logger.info("stopping...")
for cur_app in apps:
cur_app.stop()
def check_endpoint():
while True:
time.sleep(10)
if should_exit:
return
try:
sagemaker.describe_endpoint(EndpointName=endpoint_name)
except Exception as e:
if 'Could not find endpoint' in str(e):
logger.info(f"Endpoint {endpoint_name} not found, stopping...")
stop()
def run_server():
uvicorn.run(app, host="0.0.0.0", port=8080, log_level="info")
if __name__ == "__main__":
setup_signal_handlers()
server = threading.Thread(target=run_server)
server.start()
gpu_nums = get_gpu_count()
start_apps(gpu_nums)
check_apps_thread = threading.Thread(target=check_apps, daemon=True)
check_apps_thread.start()
check_endpoint_thread = threading.Thread(target=check_endpoint, daemon=True)
check_endpoint_thread.start()
if service_type == 'comfy':
queue_lock = threading.Lock()
check_sync_thread = threading.Thread(target=check_sync, daemon=True)
check_sync_thread.start()