#!/usr/bin/env python
#pylint: disable=redefined-outer-name
"""
Helper methods that create an HTTP session with a managed connection pool.

Provides async HTTP get/post methods (aiohttp), their synchronous
counterparts (requests) and several helper methods built on top of them.
"""

import os
import sys
import ssl
import asyncio
import logging
import aiohttp
import requests
import urllib3
from util import Map, log

sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860") # automatic1111 api url root
sd_username = os.environ.get('SDAPI_USR', None)
sd_password = os.environ.get('SDAPI_PWD', None)

use_session = True
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# NOTE: process-wide switch to unverified SSL contexts; certificate checks are intentionally disabled
ssl.create_default_context = ssl._create_unverified_context # pylint: disable=protected-access
timeout = aiohttp.ClientTimeout(total = None, sock_connect = 10, sock_read = None) # default value is 5 minutes, we need longer for training
sess = None
quiet = False
BaseThreadPolicy = asyncio.WindowsSelectorEventLoopPolicy if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy") else asyncio.DefaultEventLoopPolicy


class AnyThreadEventLoopPolicy(BaseThreadPolicy):
    """Event loop policy that creates and installs a new loop when the base policy cannot provide one."""

    def get_event_loop(self) -> asyncio.AbstractEventLoop:
        """Return the current event loop, creating and registering a new one if the base policy raises."""
        try:
            return super().get_event_loop()
        except (RuntimeError, AssertionError):
            loop = self.new_event_loop()
            self.set_event_loop(loop)
            return loop


asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())


def authsync():
    """Return `requests` basic-auth credentials, or None unless both username and password are set."""
    if sd_username is not None and sd_password is not None:
        return requests.auth.HTTPBasicAuth(sd_username, sd_password)
    return None


def auth():
    """Return `aiohttp` basic-auth credentials, or None unless both username and password are set."""
    if sd_username is not None and sd_password is not None:
        return aiohttp.BasicAuth(sd_username, sd_password)
    return None


async def result(req):
    """
    Convert an aiohttp response into a result object.

    Non-200 responses are logged (unless `quiet`) and returned as a `Map`
    with `error`, `reason` and `url` keys; when `use_session` is false the
    shared session is also closed. For 200 responses the decoded JSON body is
    returned: lists as-is, `None` as an empty dict, anything else wrapped in `Map`.
    """
    if req.status != 200:
        if not quiet:
            log.error({ 'request error': req.status, 'reason': req.reason, 'url': req.url })
        if not use_session and sess is not None:
            await sess.close()
        return Map({ 'error': req.status, 'reason': req.reason, 'url': req.url })
    payload = await req.json()
    if isinstance(payload, list):
        res = payload
    elif payload is None:
        res = {}
    else:
        res = Map(payload)
    log.debug({ 'request': req.status, 'url': req.url, 'reason': req.reason })
    return res


def resultsync(req: requests.Response):
    """
    Convert a `requests` response into a result object.

    Non-200 responses are logged (unless `quiet`) and returned as a `Map`
    with `error`, `reason` and `url` keys. For 200 responses the decoded JSON
    body is returned: lists as-is, `None` as an empty dict, anything else
    wrapped in `Map`.
    """
    if req.status_code != 200:
        if not quiet:
            log.error({ 'request error': req.status_code, 'reason': req.reason, 'url': req.url })
        return Map({ 'error': req.status_code, 'reason': req.reason, 'url': req.url })
    payload = req.json()
    if isinstance(payload, list):
        res = payload
    elif payload is None:
        res = {}
    else:
        res = Map(payload)
    log.debug({ 'request': req.status_code, 'url': req.url, 'reason': req.reason })
    return res


async def get(endpoint: str, json: dict = None):
    """
    Issue an async GET against `endpoint` (relative to `sd_url`) using the shared session.

    The shared session is created on first use and re-created if it has been
    closed in the meantime. Returns the value produced by `result()`, or an
    empty dict if the request raised (the exception is logged, not propagated).
    """
    global sess # pylint: disable=global-statement
    # a closed session cannot serve requests, so treat it the same as a missing one
    if sess is None or sess.closed:
        sess = await session()
    try:
        # ssl=False skips certificate validation (replacement for the deprecated verify_ssl=False)
        async with sess.get(url=endpoint, json=json, ssl=False) as req:
            res = await result(req)
            return res
    except Exception as err:
        log.error({ 'session': err })
        return {}


def getsync(endpoint: str, json: dict = None):
    """
    Issue a blocking GET against `sd_url` + `endpoint`.

    Returns the value produced by `resultsync()`, or an empty dict if the
    request raised (the exception is logged, not propagated).
    """
    try:
        req = requests.get(f'{sd_url}{endpoint}', json=json, verify=False, auth=authsync()) # pylint: disable=missing-timeout
        res = resultsync(req)
        return res
    except Exception as err:
        log.error({ 'session': err })
        return {}


async def post(endpoint: str, json: dict = None):
    """
    Issue an async POST against `endpoint` (relative to `sd_url`).

    Unlike `get()`, every call closes any open shared session and creates a
    fresh one before sending. Returns the value produced by `result()`, or an
    empty dict if the request raised (the exception is logged, not propagated).
    """
    global sess # pylint: disable=global-statement
    if sess is not None and not sess.closed:
        await sess.close()
    sess = await session()
    try:
        # ssl=False skips certificate validation (replacement for the deprecated verify_ssl=False)
        async with sess.post(url=endpoint, json=json, ssl=False) as req:
            res = await result(req)
            return res
    except Exception as err:
        log.error({ 'session': err })
        return {}


def postsync(endpoint: str, json: dict = None):
    """
    Issue a blocking POST against `sd_url` + `endpoint` and return `resultsync()` of the response.

    Exceptions are deliberately not caught here; callers such as `shutdown()` handle them.
    """
    req = requests.post(f'{sd_url}{endpoint}', json=json, verify=False, auth=authsync()) # pylint: disable=missing-timeout
    res = resultsync(req)
    return res


async def interrupt():
    """
    Interrupt the server if the progress endpoint reports at least one job.

    Returns the response of the interrupt request (after a one second pause),
    or `{ 'interrupt': 'idle' }` when no job is reported.
    """
    res = await get('/sdapi/v1/progress?skip_current_image=true')
    if 'state' in res and res.state.job_count > 0:
        log.debug({ 'interrupt': res.state })
        res = await post('/sdapi/v1/interrupt')
        await asyncio.sleep(1)
        return res
    log.debug({ 'interrupt': 'idle' })
    return { 'interrupt': 'idle' }


def interruptsync():
    """
    Blocking variant of `interrupt()`.

    Returns the response of the interrupt request, or `{ 'interrupt': 'idle' }`
    when the progress endpoint reports no job.
    """
    res = getsync('/sdapi/v1/progress?skip_current_image=true')
    if 'state' in res and res.state.job_count > 0:
        log.debug({ 'interrupt': res.state })
        res = postsync('/sdapi/v1/interrupt')
        return res
    log.debug({ 'interrupt': 'idle' })
    return { 'interrupt': 'idle' }


async def progress():
    """Fetch the progress endpoint asynchronously, log the response and return it."""
    res = await get('/sdapi/v1/progress?skip_current_image=true')
    log.debug({ 'progress': res })
    return res


def progresssync():
    """Fetch the progress endpoint synchronously, log the response and return it."""
    res = getsync('/sdapi/v1/progress?skip_current_image=true')
    log.debug({ 'progress': res })
    return res


def get_log():
    """Fetch the log endpoint, write every item of the response to the debug log and return the response."""
    res = getsync('/sdapi/v1/log')
    for line in res:
        log.debug(line)
    return res


def get_info():
    """Fetch the full system-info status, print it together with the request duration in milliseconds and return it."""
    import time
    t0 = time.time()
    res = getsync('/sdapi/v1/system-info/status?full=true&refresh=true')
    t1 = time.time()
    print({ 'duration': 1000 * round(t1-t0, 3), **res })
    return res


def options():
    """Return a dict with the server `options` and command-line `flags` responses."""
    opts = getsync('/sdapi/v1/options')
    flags = getsync('/sdapi/v1/cmd-flags')
    return { 'options': opts, 'flags': flags }


def shutdown():
    """Post to the shutdown endpoint; any exception raised by the request is logged at info level."""
    try:
        postsync('/sdapi/v1/shutdown')
    except Exception as e:
        log.info({ 'shutdown': e })


async def session():
    """
    Create a new aiohttp client session, store it in the module-level `sess` and return it.

    The session uses the module-level `timeout`, `sd_url` as base url and the
    credentials returned by `auth()`.
    """
    global sess # pylint: disable=global-statement
    sess = aiohttp.ClientSession(timeout = timeout, base_url = sd_url, auth=auth())
    log.debug({ 'sdapi': 'session created', 'endpoint': sd_url })
    return sess


async def close():
    """Close the shared session if one has been created."""
    if sess is not None:
        await asyncio.sleep(0) # yield to the event loop once before closing
        await sess.close()
        log.debug({ 'sdapi': 'session closed', 'endpoint': sd_url })


if __name__ == "__main__":
    log.setLevel(logging.DEBUG)
    if 'interrupt' in sys.argv:
        asyncio.run(interrupt())
    if 'progress' in sys.argv:
        asyncio.run(progress())
    if 'progresssync' in sys.argv:
        progresssync()
    if 'options' in sys.argv:
        opt = options()
        log.debug({ 'options' })
        import json
        print(json.dumps(opt['options'], indent = 2))
        log.debug({ 'cmd-flags' })
        print(json.dumps(opt['flags'], indent = 2))
    if 'log' in sys.argv:
        get_log()
    if 'info' in sys.argv:
        get_info()
    if 'shutdown' in sys.argv:
        shutdown()
    asyncio.run(close(), debug=True)
    asyncio.run(asyncio.sleep(0.5))