From 3e6951cf02fc3a606cf17466b2a47c91ec1b5db8 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 8 Jun 2023 09:38:10 -0500 Subject: [PATCH] prevent some ui buttons from blowing up if utils tab is selected first. dedupe --- scripts/extension.py | 55 ++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/scripts/extension.py b/scripts/extension.py index 4d2bd8a..920a7c1 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -88,23 +88,30 @@ class Script(scripts.Script): interrupt_all_btn.style(full_width=False) interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[]) - # redo benchmarks button redo_benchmarks_btn = gradio.Button(value='Redo benchmarks', variant='stop') redo_benchmarks_btn.style(full_width=False) redo_benchmarks_btn.click(Script.ui_connect_benchmark_button, inputs=[], outputs=[]) - return + @staticmethod + def ensure_init(): + try: + _initialized = Script.world.initialized + except AttributeError: + logger.debug("Distributed system not initialized") + Script.initialize(initial_payload=None) + @staticmethod def ui_connect_benchmark_button(): + Script.ensure_init() logger.info("Redoing benchmarks...") Script.world.benchmark(rebenchmark=True) @staticmethod def user_sync_script(): user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user') - # user_script = user_scripts.joinpath('example.sh') + for file in user_scripts.iterdir(): if file.is_file() and file.name.startswith('sync'): user_script = file @@ -124,44 +131,33 @@ class Script(scripts.Script): return False - # World is not constructed until the first generation job, so I use an intermediary call @staticmethod def ui_connect_interrupt_btn(): - try: - Script.world.interrupt_remotes() - except AttributeError: - logger.debug("Nothing to interrupt, Distributed system not initialized") + Script.ensure_init() + Script.world.interrupt_remotes() @staticmethod def ui_connect_refresh_ckpts_btn(): - try: - Script.world.refresh_checkpoints() - except AttributeError: - logger.debug("Distributed system not initialized") + Script.ensure_init() + Script.world.refresh_checkpoints() @staticmethod def ui_connect_status(): - try: - worker_status = '' + Script.ensure_init() + worker_status = '' - for worker in Script.world.workers: - if worker.master: - continue + for worker in Script.world.workers: + if worker.master: + continue - worker_status += f"{worker.uuid} at {worker.address} is {worker.state.name}\n" + worker_status += f"{worker.uuid} at {worker.address} is {worker.state.name}\n" - # TODO replace this with a single check to a state flag that we should make in the world class - for worker in Script.world.workers: - if worker.state == State.WORKING: - return Script.world.__str__(), worker_status + # TODO replace this with a single check to a state flag that we should make in the world class + for worker in Script.world.workers: + if worker.state == State.WORKING: + return Script.world.__str__(), worker_status - return 'No active jobs!', worker_status - - # init system if it isn't already - except AttributeError as e: - # batch size will be clobbered later once an actual request is made anyway - Script.initialize(initial_payload=None) - return Script.ui_connect_status() + return 'No active jobs!', worker_status @staticmethod @@ -371,7 +367,6 @@ class Script(scripts.Script): name = re.sub(r'\s?\[[^\]]*\]$', '', opts.data["sd_model_checkpoint"]) vae = opts.data["sd_vae"] option_payload = { - # "sd_model_checkpoint": opts.data["sd_model_checkpoint"], "sd_model_checkpoint": name, "sd_vae": vae }