From cd77acbb5564d728a5fa56f994538607dc94067f Mon Sep 17 00:00:00 2001 From: papuSpartan <30642826+papuSpartan@users.noreply.github.com> Date: Sat, 16 Mar 2024 20:18:22 -0500 Subject: [PATCH 01/20] catch and explain when user script is missing --- scripts/spartan/ui.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index 2d1cfbe..f8c9fd2 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -25,10 +25,17 @@ class UI: """executes a script placed by the user at /user/sync*""" user_scripts = Path(os.path.abspath(__file__)).parent.parent.joinpath('user') + user_script = None for file in user_scripts.iterdir(): logger.debug(f"found possible script {file.name}") if file.is_file() and file.name.startswith('sync'): user_script = file + if user_script is None: + logger.error( + "couldn't find user script\n" + "script must be placed under /user/ and filename must begin with sync" + ) + return False suffix = user_script.suffix[1:] From 76d9a32871159300587a9c3d74f73a6fd80818f3 Mon Sep 17 00:00:00 2001 From: papuSpartan <30642826+papuSpartan@users.noreply.github.com> Date: Sat, 16 Mar 2024 20:23:14 -0500 Subject: [PATCH 02/20] remove stop variant styling from util buttons --- scripts/spartan/ui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index f8c9fd2..6769042 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -247,7 +247,7 @@ class UI: reload_config_btn = gradio.Button(value='πŸ“œ Reload config') reload_config_btn.click(self.world.load_config) - redo_benchmarks_btn = gradio.Button(value='πŸ“Š Redo benchmarks', variant='stop') + redo_benchmarks_btn = gradio.Button(value='πŸ“Š Redo benchmarks') redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[]) run_usr_btn = gradio.Button(value='βš™οΈ Run script') @@ -259,10 +259,10 @@ class UI: reconnect_lost_workers_btn = gradio.Button(value='πŸ”Œ Reconnect workers') 
reconnect_lost_workers_btn.click(self.world.ping_remotes) - interrupt_all_btn = gradio.Button(value='⏸️ Interrupt all', variant='stop') + interrupt_all_btn = gradio.Button(value='⏸️ Interrupt all') interrupt_all_btn.click(self.world.interrupt_remotes) - restart_workers_btn = gradio.Button(value="πŸ” Restart All", variant='stop') + restart_workers_btn = gradio.Button(value="πŸ” Restart All") restart_workers_btn.click( _js="confirm_restart_workers", fn=lambda confirmed: self.world.restart_all() if confirmed else None, From 608fbcc76fe1812b537b716dc57258f796b8fe0b Mon Sep 17 00:00:00 2001 From: papuSpartan <30642826+papuSpartan@users.noreply.github.com> Date: Sat, 16 Mar 2024 20:49:29 -0500 Subject: [PATCH 03/20] fix verification setting potentially being forgotten --- scripts/spartan/world.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 0646a40..af83a2b 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -631,6 +631,8 @@ class World: label = next(iter(w.keys())) fields = w[label].__dict__ fields['label'] = label + # TODO must be overridden everytime here or later converted to a config file variable at some point + fields['verify_remotes'] = self.verify_remotes self.add_worker(**fields) @@ -679,8 +681,6 @@ class World: if worker.queried and worker.state == State.IDLE: # TODO worker.queried continue - # for now skip/remove scripts that are not "always on" since there is currently no way to run - # them at the same time as distributed supported_scripts = { 'txt2img': [], 'img2img': [] From 0b203e338e0a29fe0406b037d8cbb69fbce43e2d Mon Sep 17 00:00:00 2001 From: papuSpartan <30642826+papuSpartan@users.noreply.github.com> Date: Sat, 16 Mar 2024 20:01:22 -0500 Subject: [PATCH 04/20] default port to prevent confusion --- scripts/spartan/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index 
6769042..d23f065 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -109,7 +109,7 @@ class UI: self.world.add_worker( label=label, address=address, - port=port, + port=port if len(port) > 0 else 7860, tls=tls, state=state ) From e99d3f644b1651041b5ccbea92f040886c614c6d Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 19 Mar 2024 06:56:53 -0500 Subject: [PATCH 05/20] logging --- scripts/distributed.py | 3 ++- scripts/spartan/worker.py | 2 +- scripts/spartan/world.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index 5733edf..015c7f8 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -85,7 +85,6 @@ class Script(scripts.Script): @staticmethod def add_to_gallery(processed, p): """adds generated images to the image gallery after waiting for all workers to finish""" - webui_state.textinfo = "Distributed - injecting images" def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None): image_params: json = response['parameters'] @@ -161,11 +160,13 @@ class Script(scripts.Script): logger.debug(f"Took master {master_elapsed:.2f}s") # wait for response from all workers + webui_state.textinfo = "Distributed - receiving results" for thread in Script.worker_threads: logger.debug(f"waiting for worker thread '{thread.name}'") thread.join() Script.worker_threads.clear() logger.debug("all worker request threads returned") + webui_state.textinfo = "Distributed - injecting images" # some worker which we know has a good response that we can use for generating the grid donor_worker = None diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 85316f4..9bc1057 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -471,7 +471,7 @@ class Worker: self.response_time = time.time() - start variance = ((eta - self.response_time) / self.response_time) * 100 - logger.debug(f"Worker '{self.label}'s ETA was off by 
{variance:.2f}%.\n" + logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%\n" f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n") # if the variance is greater than 500% then we ignore it to prevent variation inflation diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index af83a2b..4070b70 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -530,7 +530,7 @@ class World: logger.debug( f"worker '{job.worker.label}':\n" f"{num_images_compensate} complementary image(s) = {slack_time:.2f}s slack" - f"/ {secs_per_batch_image:.2f}s per requested image" + f" Γ· {secs_per_batch_image:.2f}s per requested image" ) if not job.add_work(payload, batch_size=num_images_compensate): From 2e16d6149b4c5815e5d39a638e8c66b92df2b148 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 19 Mar 2024 19:20:59 -0500 Subject: [PATCH 06/20] compatability fixes for sdwui forge --- scripts/distributed.py | 8 +++-- scripts/spartan/control_net.py | 55 +++++++++++++++++++++++----------- scripts/spartan/world.py | 20 ++++++++----- 3 files changed, 55 insertions(+), 28 deletions(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index 015c7f8..e8160f8 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -178,7 +178,7 @@ class Script(scripts.Script): images: json = job.worker.response["images"] # if we for some reason get more than we asked for if (job.batch_size * p.n_iter) < len(images): - logger.debug(f"Requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}") + logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}") if donor_worker is None: donor_worker = job.worker @@ -262,8 +262,9 @@ class Script(scripts.Script): # grab all controlnet units cn_units = [] cn_args = p.script_args[script.args_from:script.args_to] + for cn_arg in cn_args: - if type(cn_arg).__name__ == "UiControlNetUnit": + if "ControlNetUnit" in type(cn_arg).__name__: 
cn_units.append(cn_arg) logger.debug(f"Detected {len(cn_units)} controlnet unit(s)") @@ -301,8 +302,9 @@ class Script(scripts.Script): # TODO api for some reason returns 200 even if something failed to be set. # for now we may have to make redundant GET requests to check if actually successful... # https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146 + name = re.sub(r'\s?\[[^]]*]$', '', opts.data["sd_model_checkpoint"]) - vae = opts.data["sd_vae"] + vae = opts.data.get('sd_vae') option_payload = { "sd_model_checkpoint": name, "sd_vae": vae diff --git a/scripts/spartan/control_net.py b/scripts/spartan/control_net.py index 18a240e..a6c70fa 100644 --- a/scripts/spartan/control_net.py +++ b/scripts/spartan/control_net.py @@ -2,6 +2,15 @@ import copy from PIL import Image from modules.api.api import encode_pil_to_base64 from scripts.spartan.shared import logger +import numpy as np + + +def np_to_b64(image: np.ndarray): + pil = Image.fromarray(image) + image_b64 = str(encode_pil_to_base64(pil), 'utf-8') + image_b64 = 'data:image/png;base64,' + image_b64 + + return image_b64 def pack_control_net(cn_units) -> dict: @@ -17,28 +26,40 @@ def pack_control_net(cn_units) -> dict: cn_args = controlnet['controlnet']['args'] for i in range(0, len(cn_units)): - # copy control net unit to payload - cn_args.append(copy.copy(cn_units[i].__dict__)) + if cn_units[i].enabled: + cn_args.append(copy.deepcopy(cn_units[i].__dict__)) + else: + logger.debug(f"controlnet unit {i} is not enabled (ignoring)") + + for i in range(0, len(cn_args)): unit = cn_args[i] - # if unit isn't enabled then don't bother including - if not unit['enabled']: - del unit['input_mode'] - del unit['image'] - logger.debug(f"Controlnet unit {i} is not enabled. 
Ignoring") - continue - # serialize image - if unit['image'] is not None: - image = unit['image']['image'] - # mask = unit['image']['mask'] - pil = Image.fromarray(image) - image_b64 = encode_pil_to_base64(pil) - image_b64 = str(image_b64, 'utf-8') - unit['input_image'] = image_b64 + image_pair = unit.get('image') + if image_pair is not None: + image_b64 = np_to_b64(image_pair['image']) + unit['input_image'] = image_b64 # mikubill + unit['image'] = image_b64 # forge + if np.all(image_pair['mask'] == 0): + # stand-alone mask from second gradio component + standalone_mask = unit.get('mask_image') + if standalone_mask is not None: + logger.debug(f"found stand-alone mask for controlnet unit {i}") + mask_b64 = np_to_b64(unit['mask_image']['mask']) + unit['mask'] = mask_b64 # mikubill + unit['mask_image'] = mask_b64 # forge + + else: + # mask from singular gradio image component + logger.debug(f"found mask for controlnet unit {i}") + mask_b64 = np_to_b64(image_pair['mask']) + unit['mask'] = mask_b64 # mikubill + unit['mask_image'] = mask_b64 # forge + + # avoid returning duplicate detection maps since master should return the same one + unit['save_detected_map'] = False # remove anything unserializable del unit['input_mode'] - del unit['image'] return controlnet diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 4070b70..59f0b66 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -685,16 +685,20 @@ class World: 'txt2img': [], 'img2img': [] } - script_info = worker.session.get(url=worker.full_url('script-info')).json() - for key in script_info: - name = key.get('name', None) + response = worker.session.get(url=worker.full_url('script-info')) + if response.status_code == 200: + script_info = response.json() + for key in script_info: + name = key.get('name', None) - if name is not None: - is_alwayson = key.get('is_alwayson', False) - is_img2img = key.get('is_img2img', False) - if is_alwayson: - supported_scripts['img2img' if 
is_img2img else 'txt2img'].append(name) + if name is not None: + is_alwayson = key.get('is_alwayson', False) + is_img2img = key.get('is_img2img', False) + if is_alwayson: + supported_scripts['img2img' if is_img2img else 'txt2img'].append(name) + else: + logger.error(f"failed to query script-info for worker '{worker.label}': {response}") worker.supported_scripts = supported_scripts msg = f"worker '{worker.label}' is online" From c84d7c8a333620106a864a946b0ae3add6076845 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 19 Mar 2024 19:54:47 -0500 Subject: [PATCH 07/20] fully skip workers in invalid states when re-benching --- scripts/spartan/world.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 59f0b66..df6edd1 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -233,7 +233,7 @@ class World: # have every unbenched worker load the same weights before the benchmark for worker in unbenched_workers: - if worker.master or worker.state == State.DISABLED: + if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE): continue sync_thread = Thread(target=worker.load_options, args=(shared.opts.sd_model_checkpoint, shared.opts.sd_vae)) @@ -244,6 +244,10 @@ class World: # benchmark those that haven't been for worker in unbenched_workers: + if worker.state in (State.DISABLED, State.UNAVAILABLE): + logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") + continue + t = Thread(target=benchmark_wrapped, args=(worker, ), name=f"{worker.label}_benchmark") benchmark_threads.append(t) t.start() From 3a9d87f82138a332af190e20cf2d4e698bbc6e38 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 22 Mar 2024 15:14:26 -0500 Subject: [PATCH 08/20] bench threads -> coro --- scripts/spartan/control_net.py | 7 ++++++ scripts/spartan/ui.py | 2 +- scripts/spartan/worker.py | 11 +++++++-- scripts/spartan/world.py | 45 +++++++++++++++++++++++++--------- 4 files 
changed, 51 insertions(+), 14 deletions(-) diff --git a/scripts/spartan/control_net.py b/scripts/spartan/control_net.py index a6c70fa..dc55cd9 100644 --- a/scripts/spartan/control_net.py +++ b/scripts/spartan/control_net.py @@ -3,6 +3,7 @@ from PIL import Image from modules.api.api import encode_pil_to_base64 from scripts.spartan.shared import logger import numpy as np +import json def np_to_b64(image: np.ndarray): @@ -62,4 +63,10 @@ def pack_control_net(cn_units) -> dict: # remove anything unserializable del unit['input_mode'] + try: + json.dumps(controlnet) + except Exception as e: + logger.error(f"failed to serialize controlnet\nfirst unit:\n{controlnet['controlnet']['args'][0]}") + return {} + return controlnet diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index d23f065..4c710b6 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -312,7 +312,7 @@ class UI: # API authentication worker_api_auth_cbx = gradio.Checkbox(label='API Authentication') worker_user_field = gradio.Textbox(label='Username') - worker_password_field = gradio.Textbox(label='Password') + worker_password_field = gradio.Textbox(label='Password', type='password') update_credentials_btn = gradio.Button(value='Update API Credentials') update_credentials_btn.click(self.update_credentials_btn, inputs=[ worker_api_auth_cbx, diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 9bc1057..cf92520 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -1,3 +1,4 @@ +import asyncio import base64 import copy import io @@ -156,6 +157,11 @@ class Worker: def __repr__(self): return f"'{self.label}'@{self.address}:{self.port}, speed: {self.avg_ipm} ipm, state: {self.state}" + def __eq__(self, other): + if isinstance(other, Worker) and other.label == self.label: + return True + return False + @property def model(self) -> Worker_Model: return Worker_Model(**self.__dict__) @@ -510,7 +516,7 @@ class Worker: t: Thread samples = 2 # number of times to 
benchmark the remote / accuracy - if self.state == State.DISABLED or self.state == State.UNAVAILABLE: + if self.state in (State.DISABLED, State.UNAVAILABLE): logger.debug(f"worker '{self.label}' is unavailable or disabled, refusing to benchmark") return 0 @@ -533,7 +539,6 @@ class Worker: results: List[float] = [] # it used to be lower for the first couple of generations # this was due to something torch does at startup according to auto and is now done at sdwui startup - self.state = State.WORKING for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up" if self.state == State.UNAVAILABLE: self.response = None @@ -677,6 +682,8 @@ class Worker: if vae is not None: self.loaded_vae = vae + return response + def restart(self) -> bool: err_msg = f"could not restart worker '{self.label}'" success_msg = f"worker '{self.label}' is restarting" diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index df6edd1..ae70006 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -18,6 +18,7 @@ from . 
import shared as sh from .pmodels import ConfigModel, Benchmark_Payload from .shared import logger, warmup_samples, extension_path from .worker import Worker, State +import asyncio class NotBenchmarked(Exception): @@ -231,38 +232,60 @@ class World: else: worker.benchmarked = True + tasks = [] + loop = asyncio.new_event_loop() # have every unbenched worker load the same weights before the benchmark for worker in unbenched_workers: if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE): continue - sync_thread = Thread(target=worker.load_options, args=(shared.opts.sd_model_checkpoint, shared.opts.sd_vae)) - sync_threads.append(sync_thread) - sync_thread.start() - for thread in sync_threads: - thread.join() + tasks.append( + loop.create_task( + asyncio.to_thread(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae) + , name=worker.label + ) + ) + if len(tasks) > 0: + results = loop.run_until_complete(asyncio.wait(tasks)) + for task in results[0]: + worker = self[task.get_name()] + response = task.result() + if response.status_code != 200: + logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n" + f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers") + unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers)) # benchmark those that haven't been + tasks = [] for worker in unbenched_workers: if worker.state in (State.DISABLED, State.UNAVAILABLE): logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") continue - t = Thread(target=benchmark_wrapped, args=(worker, ), name=f"{worker.label}_benchmark") - benchmark_threads.append(t) - t.start() + if worker.model_override is not None: + logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" + 
f"*all workers should be evaluated against the same model") + + tasks.append( + loop.create_task( + asyncio.to_thread(benchmark_wrapped, worker), + name=worker.label + ) + ) logger.info(f"benchmarking worker '{worker.label}'") # wait for all benchmarks to finish and update stats on newly benchmarked workers - if len(benchmark_threads) > 0: - for t in benchmark_threads: - t.join() + if len(tasks) > 0: + results = loop.run_until_complete(asyncio.wait(tasks)) logger.info("benchmarking finished") + logger.debug(results) # save benchmark results to workers.json self.save_config() logger.info(self.speed_summary()) + loop.close() + def get_current_output_size(self) -> int: """ returns how many images would be returned from all jobs From 4ddb137b563c0abdeaadcb82810d56a4f6f7dffd Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 22 Mar 2024 16:26:00 -0500 Subject: [PATCH 09/20] use concurrent.futures for benchmarking --- scripts/spartan/worker.py | 4 +- scripts/spartan/world.py | 77 ++++++++++++++++----------------------- 2 files changed, 35 insertions(+), 46 deletions(-) diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index cf92520..2ff1c71 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -682,7 +682,9 @@ class Worker: if vae is not None: self.loaded_vae = vae - return response + self.response = response + + return self def restart(self) -> bool: err_msg = f"could not restart worker '{self.label}'" diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index ae70006..f1e238b 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -4,7 +4,7 @@ This module facilitates the creation of a stable-diffusion-webui centered distri World: The main class which should be instantiated in order to create a new sdwui distributed system. 
""" - +import concurrent.futures import copy import json import os @@ -210,8 +210,6 @@ class World: """ unbenched_workers = [] - benchmark_threads: List[Thread] = [] - sync_threads: List[Thread] = [] def benchmark_wrapped(worker): bench_func = worker.benchmark if not worker.master else self.benchmark_master @@ -232,60 +230,49 @@ class World: else: worker.benchmarked = True - tasks = [] - loop = asyncio.new_event_loop() - # have every unbenched worker load the same weights before the benchmark - for worker in unbenched_workers: - if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE): - continue + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] - tasks.append( - loop.create_task( - asyncio.to_thread(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae) - , name=worker.label + # have every unbenched worker load the same weights before the benchmark + for worker in unbenched_workers: + if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE): + continue + + futures.append( + executor.submit(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae) ) - ) - if len(tasks) > 0: - results = loop.run_until_complete(asyncio.wait(tasks)) - for task in results[0]: - worker = self[task.get_name()] - response = task.result() - if response.status_code != 200: - logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n" - f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers") - unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers)) + for future in concurrent.futures.as_completed(futures): + worker = future.result() + if worker is None: + continue - # benchmark those that haven't been - tasks = [] - for worker in unbenched_workers: - if worker.state in 
(State.DISABLED, State.UNAVAILABLE): - logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") - continue + if worker.response.status_code != 200: + logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n" + f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers") + unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers)) + futures.clear() - if worker.model_override is not None: - logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" - f"*all workers should be evaluated against the same model") + # benchmark those that haven't been + for worker in unbenched_workers: + if worker.state in (State.DISABLED, State.UNAVAILABLE): + logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") + continue - tasks.append( - loop.create_task( - asyncio.to_thread(benchmark_wrapped, worker), - name=worker.label - ) - ) - logger.info(f"benchmarking worker '{worker.label}'") + if worker.model_override is not None: + logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" + f"*all workers should be evaluated against the same model") - # wait for all benchmarks to finish and update stats on newly benchmarked workers - if len(tasks) > 0: - results = loop.run_until_complete(asyncio.wait(tasks)) + futures.append(executor.submit(benchmark_wrapped, worker)) + logger.info(f"benchmarking worker '{worker.label}'") + + # wait for all benchmarks to finish and update stats on newly benchmarked workers + concurrent.futures.wait(futures) logger.info("benchmarking finished") - logger.debug(results) # save benchmark results to workers.json self.save_config() logger.info(self.speed_summary()) - loop.close() - def 
get_current_output_size(self) -> int: """ returns how many images would be returned from all jobs From 7a96547151d6fd4abafd9358f1b14cda0adfc051 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 25 Mar 2024 17:56:10 -0500 Subject: [PATCH 10/20] refactor to remove benchmark_master --- scripts/spartan/shared.py | 1 + scripts/spartan/worker.py | 32 +++++++++++--------- scripts/spartan/world.py | 63 ++++++++++++++------------------------- 3 files changed, 42 insertions(+), 54 deletions(-) diff --git a/scripts/spartan/shared.py b/scripts/spartan/shared.py index 26e7a2b..7487a52 100644 --- a/scripts/spartan/shared.py +++ b/scripts/spartan/shared.py @@ -61,6 +61,7 @@ logger.addHandler(gui_handler) # end logging warmup_samples = 2 # number of samples to do before recording a valid benchmark sample +samples = 2 # number of times to benchmark worker after warmup benchmarks are completed class BenchmarkPayload(BaseModel): diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 2ff1c71..3b94cfc 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -507,21 +507,22 @@ class Worker: self.jobs_requested += 1 return - def benchmark(self) -> float: + def benchmark(self, sample_function: callable = None) -> float: """ given a worker, run a small benchmark and return its performance in images/minute makes standard request(s) of 512x512 images and averages them to get the result """ t: Thread - samples = 2 # number of times to benchmark the remote / accuracy if self.state in (State.DISABLED, State.UNAVAILABLE): logger.debug(f"worker '{self.label}' is unavailable or disabled, refusing to benchmark") return 0 if self.master is True: - return -1 + if sample_function is None: + logger.critical(f"no function provided for benchmarking master") + return -1 def ipm(seconds: float) -> float: """ @@ -539,18 +540,24 @@ class Worker: results: List[float] = [] # it used to be lower for the first couple of generations # this was due to something torch does at 
startup according to auto and is now done at sdwui startup - for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up" + for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up" if self.state == State.UNAVAILABLE: self.response = None return 0 - t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,), - name=f"{self.label}_benchmark_request") try: # if the worker is unreachable/offline then handle that here - t.start() - start = time.time() - t.join() - elapsed = time.time() - start + elapsed = None + + if sample_function is None: + start = time.time() + t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,), + name=f"{self.label}_benchmark_request") + t.start() + t.join() + elapsed = time.time() - start + else: + elapsed = sample_function() + sample_ipm = ipm(elapsed) except InvalidWorkerResponse as e: raise e @@ -563,10 +570,7 @@ class Worker: logger.debug(f"{self.label} finished warming up\n") # average the sample results for accuracy - ipm_sum = 0 - for ipm_result in results: - ipm_sum += ipm_result - avg_ipm_result = ipm_sum / samples + avg_ipm_result = sum(results) / sh.samples logger.debug(f"Worker '{self.label}' average ipm: {avg_ipm_result:.2f}") self.avg_ipm = avg_ipm_result diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index f1e238b..d5845c5 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -19,6 +19,7 @@ from .pmodels import ConfigModel, Benchmark_Payload from .shared import logger, warmup_samples, extension_path from .worker import Worker, State import asyncio +from modules.call_queue import wrap_queued_call class NotBenchmarked(Exception): @@ -204,18 +205,31 @@ class World: t = Thread(target=worker.refresh_checkpoints, args=()) t.start() + def sample_master(self) -> float: + # wrap our benchmark payload + master_bench_payload = StableDiffusionProcessingTxt2Img() + d = 
sh.benchmark_payload.dict() + for key in d: + setattr(master_bench_payload, key, d[key]) + + # Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to. + master_bench_payload.do_not_save_samples = True + # shared.state.begin(job='distributed_master_bench') + wrapped = (wrap_queued_call(process_images)) + start = time.time() + wrapped(master_bench_payload) + # wrap_gradio_gpu_call(process_images)(master_bench_payload) + # shared.state.end() + + return time.time() - start + + def benchmark(self, rebenchmark: bool = False): """ Attempts to benchmark all workers a part of the world. """ unbenched_workers = [] - - def benchmark_wrapped(worker): - bench_func = worker.benchmark if not worker.master else self.benchmark_master - worker.avg_ipm = bench_func() - worker.benchmarked = True - if rebenchmark: for worker in self._workers: worker.benchmarked = False @@ -230,7 +244,7 @@ class World: else: worker.benchmarked = True - with concurrent.futures.ThreadPoolExecutor() as executor: + with concurrent.futures.ThreadPoolExecutor(thread_name_prefix='distributed_benchmark') as executor: futures = [] # have every unbenched worker load the same weights before the benchmark @@ -262,7 +276,8 @@ class World: logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" f"*all workers should be evaluated against the same model") - futures.append(executor.submit(benchmark_wrapped, worker)) + chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master) + futures.append(executor.submit(chosen, worker)) logger.info(f"benchmarking worker '{worker.label}'") # wait for all benchmarks to finish and update stats on newly benchmarked workers @@ -366,38 +381,6 @@ class World: return lag - def benchmark_master(self) -> float: - """ - Benchmarks the local/master worker. 
- - Returns: - float: Local worker speed in ipm - """ - - # wrap our benchmark payload - master_bench_payload = StableDiffusionProcessingTxt2Img() - d = sh.benchmark_payload.dict() - for key in d: - setattr(master_bench_payload, key, d[key]) - - # Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to. - master_bench_payload.do_not_save_samples = True - - # "warm up" due to initial generation lag - for _ in range(warmup_samples): - process_images(master_bench_payload) - - # get actual sample - start = time.time() - process_images(master_bench_payload) - elapsed = time.time() - start - - ipm = sh.benchmark_payload.batch_size / (elapsed / 60) - - logger.debug(f"Master benchmark took {elapsed:.2f}: {ipm:.2f} ipm") - self.master().benchmarked = True - return ipm - def update_jobs(self): """creates initial jobs (before optimization) """ From b5b4ca716dddc2ae0fc97f82a6efb9d64fe3a14f Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 24 Apr 2024 11:44:24 -0500 Subject: [PATCH 11/20] testing dynamic step downscaling --- scripts/distributed.py | 2 ++ scripts/spartan/shared.py | 2 +- scripts/spartan/worker.py | 26 ++++++++++++-------------- scripts/spartan/world.py | 28 +++++++++++++++++++++------- 4 files changed, 36 insertions(+), 22 deletions(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index e8160f8..83b3a36 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -342,6 +342,8 @@ class Script(scripts.Script): prior_images += j.batch_size * p.n_iter payload_temp['batch_size'] = job.batch_size + if job.step_override is not None: + payload_temp['steps'] = job.step_override payload_temp['subseed'] += prior_images payload_temp['seed'] += prior_images if payload_temp['subseed_strength'] == 0 else 0 logger.debug( diff --git a/scripts/spartan/shared.py b/scripts/spartan/shared.py index 7487a52..52512b7 100644 --- a/scripts/spartan/shared.py +++ b/scripts/spartan/shared.py @@ -61,7 +61,7 @@ 
logger.addHandler(gui_handler) # end logging warmup_samples = 2 # number of samples to do before recording a valid benchmark sample -samples = 2 # number of times to benchmark worker after warmup benchmarks are completed +samples = 4 # number of times to benchmark worker after warmup benchmarks are completed class BenchmarkPayload(BaseModel): diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 3b94cfc..cb07674 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -195,14 +195,14 @@ class Worker: protocol = 'http' if not self.tls else 'https' return f"{protocol}://{self.__str__()}/sdapi/v1/{route}" - def batch_eta_hr(self, payload: dict) -> float: + def eta_hr(self, payload: dict) -> float: """ takes a normal payload and returns the eta of a pseudo payload which mirrors the hr-fix parameters This returns the eta of how long it would take to run hr-fix on the original image """ pseudo_payload = copy.copy(payload) - pseudo_payload['enable_hr'] = False # prevent overflow in self.batch_eta + pseudo_payload['enable_hr'] = False # prevent overflow in self.eta res_ratio = pseudo_payload['hr_scale'] original_steps = pseudo_payload['steps'] second_pass_steps = pseudo_payload['hr_second_pass_steps'] @@ -218,12 +218,11 @@ class Worker: pseudo_payload['width'] = pseudo_width pseudo_payload['height'] = pseudo_height - eta = self.batch_eta(payload=pseudo_payload, quiet=True) - return eta + return self.eta(payload=pseudo_payload, quiet=True) - def batch_eta(self, payload: dict, quiet: bool = False, batch_size: int = None) -> float: + def eta(self, payload: dict, quiet: bool = False, batch_size: int = None, samples: int = None) -> float: """ - estimate how long it will take to generate images on a worker in seconds + estimate how long it will take to generate image(s) on a worker in seconds Args: payload: Sdwui api formatted payload @@ -231,7 +230,7 @@ class Worker: batch_size: Overrides the batch_size parameter of the payload """ - steps = 
payload['steps'] + steps = payload['steps'] if samples is None else samples num_images = payload['batch_size'] if batch_size is None else batch_size # if worker has not yet been benchmarked then @@ -243,7 +242,7 @@ class Worker: # show effect of high-res fix hr = payload.get('enable_hr', False) if hr: - eta += self.batch_eta_hr(payload=payload) + eta += self.eta_hr(payload=payload) # show effect of image size real_pix_to_benched = (payload['width'] * payload['height']) \ @@ -337,7 +336,7 @@ class Worker: self.load_options(model=option_payload['sd_model_checkpoint'], vae=option_payload['sd_vae']) if self.benchmarked: - eta = self.batch_eta(payload=payload) * payload['n_iter'] + eta = self.eta(payload=payload) * payload['n_iter'] logger.debug(f"worker '{self.label}' predicts it will take {eta:.3f}s to generate " f"{payload['batch_size'] * payload['n_iter']} image(s) " f"at a speed of {self.avg_ipm:.2f} ipm\n") @@ -519,10 +518,9 @@ class Worker: logger.debug(f"worker '{self.label}' is unavailable or disabled, refusing to benchmark") return 0 - if self.master is True: - if sample_function is None: - logger.critical(f"no function provided for benchmarking master") - return -1 + if self.master and sample_function is None: + logger.critical(f"no function provided for benchmarking master") + return -1 def ipm(seconds: float) -> float: """ @@ -548,7 +546,7 @@ class Worker: try: # if the worker is unreachable/offline then handle that here elapsed = None - if sample_function is None: + if not callable(sample_function): start = time.time() t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,), name=f"{self.label}_benchmark_request") diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index d5845c5..f9374d6 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -50,6 +50,7 @@ class Job: self.worker: Worker = worker self.batch_size: int = batch_size self.complementary: bool = False + self.step_override = None def 
__str__(self): prefix = '' @@ -194,16 +195,14 @@ class World: if worker.master: continue - t = Thread(target=worker.interrupt, args=()) - t.start() + Thread(target=worker.interrupt, args=()).start() def refresh_checkpoints(self): for worker in self.get_workers(): if worker.master: continue - t = Thread(target=worker.refresh_checkpoints, args=()) - t.start() + Thread(target=worker.refresh_checkpoints, args=()).start() def sample_master(self) -> float: # wrap our benchmark payload @@ -377,7 +376,7 @@ class World: if worker == fastest_worker: return 0 - lag = worker.batch_eta(payload=payload, quiet=True, batch_size=batch_size) - fastest_worker.batch_eta(payload=payload, quiet=True, batch_size=batch_size) + lag = worker.eta(payload=payload, quiet=True, batch_size=batch_size) - fastest_worker.eta(payload=payload, quiet=True, batch_size=batch_size) return lag @@ -518,11 +517,11 @@ class World: fastest_active = self.fastest_realtime_job().worker for j in self.jobs: if j.worker.label == fastest_active.label: - slack_time = fastest_active.batch_eta(payload=payload, batch_size=j.batch_size) + self.job_timeout + slack_time = fastest_active.eta(payload=payload, batch_size=j.batch_size) + self.job_timeout logger.debug(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.label}'") # see how long it would take to produce only 1 image on this complementary worker - secs_per_batch_image = job.worker.batch_eta(payload=payload, batch_size=1) + secs_per_batch_image = job.worker.eta(payload=payload, batch_size=1) num_images_compensate = int(slack_time / secs_per_batch_image) logger.debug( f"worker '{job.worker.label}':\n" @@ -535,6 +534,21 @@ class World: request_img_size = payload['width'] * payload['height'] max_images = job.worker.pixel_cap // request_img_size job.add_work(payload, batch_size=max_images) + + # when not even a singular image can be squeezed out + # if step scaling is enabled, then find how many samples would be considered realtime and adjust + 
if num_images_compensate == 0: + seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1) + realtime_samples = slack_time // seconds_per_sample + logger.critical( + f"job for '{job.worker.label}' downscaled to {realtime_samples} samples to meet time constraints\n" + f"{realtime_samples:.0f} samples = {slack_time:.2f}s slack Γ· {seconds_per_sample:.2f}s/sample\n" + f" step reduction: {payload['steps']} -> {realtime_samples:.0f}" + ) + + job.add_work(payload=payload, batch_size=1) + job.step_override = realtime_samples + else: logger.debug("complementary image production is disabled") From dd7c6a0cfdc09979977b203f167c90adbe9fec05 Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 4 May 2024 05:06:09 -0500 Subject: [PATCH 12/20] add toggle for step scaling --- scripts/spartan/pmodels.py | 1 + scripts/spartan/ui.py | 17 +++++++++++++---- scripts/spartan/worker.py | 1 + scripts/spartan/world.py | 7 +++++-- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/scripts/spartan/pmodels.py b/scripts/spartan/pmodels.py index 7e2bf54..45c8c5b 100644 --- a/scripts/spartan/pmodels.py +++ b/scripts/spartan/pmodels.py @@ -42,3 +42,4 @@ class ConfigModel(BaseModel): job_timeout: Optional[int] = Field(default=3) enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True) complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True) + step_reduction: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints") diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index 4c710b6..bf7f7e7 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -81,7 +81,7 @@ class UI: return 'No active jobs!', worker_status, logs - def save_btn(self, thin_client_mode, job_timeout, complement_production): + def save_btn(self, thin_client_mode, job_timeout, 
complement_production, step_scaling): """updates the options visible on the settings tab""" self.world.thin_client_mode = thin_client_mode @@ -90,6 +90,7 @@ class UI: self.world.job_timeout = job_timeout logger.debug(f"job timeout is now {job_timeout} seconds") self.world.complement_production = complement_production + self.world.step_scaling = step_scaling self.world.save_config() def save_worker_btn(self, label, address, port, tls, disabled): @@ -365,13 +366,21 @@ class UI: complement_production = gradio.Checkbox( label='Complement production', - info='Prevents under-utilization of hardware by requesting additional images', + info='Prevents under-utilization by requesting additional images when possible', value=self.world.complement_production ) + # reduces image quality the more the sample-count must be reduced + # good for mixed setups where each worker may not be around the same speed + step_scaling = gradio.Checkbox( + label='Step scaling', + info='Prevents under-utilization via sample reduction in order to meet time constraints', + value=self.world.step_scaling + ) + save_btn = gradio.Button(value='Update') - save_btn.click(fn=self.save_btn, inputs=[thin_client_cbx, job_timeout, complement_production]) - components += [thin_client_cbx, job_timeout, complement_production, save_btn] + save_btn.click(fn=self.save_btn, inputs=[thin_client_cbx, job_timeout, complement_production, step_scaling]) + components += [thin_client_cbx, job_timeout, complement_production, step_scaling, save_btn] with gradio.Tab('Help'): gradio.Markdown( diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index cb07674..cfb36bb 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -574,6 +574,7 @@ class Worker: self.avg_ipm = avg_ipm_result self.response = None self.benchmarked = True + self.eta_percent_error = [] # likely inaccurate after rebenching self.state = State.IDLE return avg_ipm_result diff --git a/scripts/spartan/world.py 
b/scripts/spartan/world.py index f9374d6..b412d74 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -99,6 +99,7 @@ class World: self.enabled = True self.is_dropdown_handler_injected = False self.complement_production = True + self.step_scaling = False def __getitem__(self, label: str) -> Worker: for worker in self._workers: @@ -537,7 +538,7 @@ class World: # when not even a singular image can be squeezed out # if step scaling is enabled, then find how many samples would be considered realtime and adjust - if num_images_compensate == 0: + if num_images_compensate == 0 and self.step_scaling: seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1) realtime_samples = slack_time // seconds_per_sample logger.critical( @@ -651,6 +652,7 @@ class World: self.job_timeout = config.job_timeout self.enabled = config.enabled self.complement_production = config.complement_production + self.step_scaling = config.step_reduction logger.debug("config loaded") @@ -664,7 +666,8 @@ class World: benchmark_payload=sh.benchmark_payload, job_timeout=self.job_timeout, enabled=self.enabled, - complement_production=self.complement_production + complement_production=self.complement_production, + step_scaling=self.step_scaling ) with open(self.config_path, 'w+') as config_file: From c4abe4d1f1f3b7c51871c78d388bb41ac30fc669 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 8 May 2024 07:59:47 -0500 Subject: [PATCH 13/20] fix thin-client mode and situations where no local generation will take place, avoid failure in some cases when pinging crashed workers --- scripts/distributed.py | 9 ++++++++- scripts/spartan/ui.py | 4 ++-- scripts/spartan/worker.py | 4 ++++ scripts/spartan/world.py | 18 +++++++++++++++++- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index 83b3a36..b771250 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -235,6 +235,7 @@ class Script(scripts.Script): 
try: Script.world.initialize(batch_size) + Script.world.initial_payload = initial_payload logger.debug(f"World initialized!") except WorldAlreadyInitialized: Script.world.update_world(total_batch_size=batch_size) @@ -249,6 +250,9 @@ class Script(scripts.Script): current_thread().name = "distributed_main" Script.initialize(initial_payload=p) + # save original process_images_inner function for later if we monkeypatch it + Script.original_process_images_inner = processing.process_images_inner + # strip scripts that aren't yet supported and warn user packed_script_args: List[dict] = [] # list of api formatted per-script argument objects # { "script_name": { "args": ["value1", "value2", ...] } @@ -387,9 +391,12 @@ class Script(scripts.Script): if not Script.world.enabled: return - if len(processed.images) >= 1 and Script.master_start is not None: + if Script.master_start is not None: Script.add_to_gallery(p=p, processed=processed) + # restore process_images_inner if it was monkey-patched + processing.process_images_inner = Script.original_process_images_inner + @staticmethod def signal_handler(sig, frame): logger.debug("handling interrupt signal") diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index bf7f7e7..bbcdd74 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -354,8 +354,8 @@ class UI: with gradio.Tab('Settings'): thin_client_cbx = gradio.Checkbox( - label='Thin-client mode (experimental)', - info="(BROKEN) Only generate images using remote workers. There will be no previews when enabled.", + label='Thin-client mode', + info="Only generate images using remote workers. 
There will be no previews (yet) when enabled.", value=self.world.thin_client_mode ) job_timeout = gradio.Number( diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index cfb36bb..2dc8149 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -618,6 +618,10 @@ class Worker: except requests.exceptions.ConnectionError as e: logger.error(e) return False + except requests.ReadTimeout as e: + logger.critical(f"worker '{self.label}' is online but not responding (crashed?)") + logger.error(e) + return False def mark_unreachable(self): if self.state == State.DISABLED: diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index b412d74..494c6fe 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -18,8 +18,8 @@ from . import shared as sh from .pmodels import ConfigModel, Benchmark_Payload from .shared import logger, warmup_samples, extension_path from .worker import Worker, State -import asyncio from modules.call_queue import wrap_queued_call +from modules import processing class NotBenchmarked(Exception): @@ -565,6 +565,22 @@ class World: distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n" logger.info(distro_summary) + if self.thin_client_mode is True or self.master_job().batch_size == 0: + # save original process_images_inner for later so we can restore once we're done + logger.debug(f"bypassing local generation completely") + def process_images_inner_bypass(p) -> processing.Processed: + processed = processing.Processed(p, [], p.seed, info="") + processed.all_prompts = [] + processed.all_seeds = [] + processed.all_subseeds = [] + processed.all_negative_prompts = [] + processed.infotexts = [] + processed.prompt = None + + self.initial_payload.scripts.postprocess(p, processed) + return processed + processing.process_images_inner = process_images_inner_bypass + # delete any jobs that have no work last = len(self.jobs) - 1 while last > 0: From 
de18217aaf5da088865c6eafcea2574bd17f2b0c Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 9 May 2024 00:36:37 -0500 Subject: [PATCH 14/20] refactoring and cleanup --- scripts/distributed.py | 55 +++++++---------------------------- scripts/spartan/worker.py | 1 - scripts/spartan/world.py | 60 ++++++++++++++------------------------- 3 files changed, 32 insertions(+), 84 deletions(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index b771250..59f755a 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -24,13 +24,12 @@ from modules.shared import state as webui_state from scripts.spartan.control_net import pack_control_net from scripts.spartan.shared import logger from scripts.spartan.ui import UI -from scripts.spartan.world import World, WorldAlreadyInitialized +from scripts.spartan.world import World old_sigint_handler = signal.getsignal(signal.SIGINT) old_sigterm_handler = signal.getsignal(signal.SIGTERM) -# TODO implement advertisement of some sort in sdwui api to allow extension to automatically discover workers? 
# noinspection PyMissingOrEmptyDocstring class Script(scripts.Script): worker_threads: List[Thread] = [] @@ -50,17 +49,8 @@ class Script(scripts.Script): urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) # build world - world = World(initial_payload=None, verify_remotes=verify_remotes) - # add workers to the world + world = World(verify_remotes=verify_remotes) world.load_config() - if cmd_opts.distributed_remotes is not None and len(cmd_opts.distributed_remotes) > 0: - logger.warning(f"--distributed-remotes is deprecated and may be removed in the future\n" - f"gui/external modification of {world.config_path} will be prioritized going forward") - - for worker in cmd_opts.distributed_remotes: - world.add_worker(uuid=worker[0], address=worker[1], port=worker[2], tls=False) - world.save_config() - # do an early check to see which workers are online logger.info("doing initial ping sweep to see which workers are reachable") world.ping_remotes(indiscriminate=True) @@ -71,7 +61,7 @@ class Script(scripts.Script): return scripts.AlwaysVisible def ui(self, is_img2img): - self.world.load_config() + Script.world.load_config() extension_ui = UI(script=Script, world=Script.world) # root, api_exposed = extension_ui.create_ui() components = extension_ui.create_ui() @@ -128,7 +118,7 @@ class Script(scripts.Script): if p.n_iter > 1: # if splitting by batch count num_remote_images *= p.n_iter - 1 - logger.debug(f"image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, " + logger.debug(f"image {true_image_pos + 1}/{Script.world.p.batch_size * p.n_iter}, " f"info-index: {info_index}") if Script.world.thin_client_mode: @@ -225,30 +215,15 @@ class Script(scripts.Script): p.batch_size = len(processed.images) return - @staticmethod - def initialize(initial_payload): - # get default batch size - try: - batch_size = initial_payload.batch_size - except AttributeError: - batch_size = 1 - - try: - Script.world.initialize(batch_size) - Script.world.initial_payload 
= initial_payload - logger.debug(f"World initialized!") - except WorldAlreadyInitialized: - Script.world.update_world(total_batch_size=batch_size) # p's type is - # "modules.processing.StableDiffusionProcessingTxt2Img" - def before_process(self, p, *args): - if not self.world.enabled: + # "modules.processing.StableDiffusionProcessing*" + @staticmethod + def before_process(p, *args): + if not Script.world.enabled: logger.debug("extension is disabled") return - - current_thread().name = "distributed_main" - Script.initialize(initial_payload=p) + Script.world.update(p) # save original process_images_inner function for later if we monkeypatch it Script.original_process_images_inner = processing.process_images_inner @@ -367,22 +342,12 @@ class Script(scripts.Script): started_jobs.append(job) # if master batch size was changed again due to optimization change it to the updated value - if not self.world.thin_client_mode: + if not Script.world.thin_client_mode: p.batch_size = Script.world.master_job().batch_size Script.master_start = time.time() # generate images assigned to local machine p.do_not_save_grid = True # don't generate grid from master as we are doing this later. 
- if Script.world.thin_client_mode: - p.batch_size = 0 - processed = Processed(p=p, images_list=[]) - processed.all_prompts = [] - processed.all_seeds = [] - processed.all_subseeds = [] - processed.all_negative_prompts = [] - processed.infotexts = [] - processed.prompt = None - Script.runs_since_init += 1 return diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 2dc8149..39cf884 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -1,4 +1,3 @@ -import asyncio import base64 import copy import io diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 494c6fe..5156cc6 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -30,13 +30,6 @@ class NotBenchmarked(Exception): pass -class WorldAlreadyInitialized(Exception): - """ - Raised when attempting to initialize the World when it has already been initialized. - """ - pass - - class Job: """ Keeps track of how much work a given worker should handle. @@ -78,7 +71,7 @@ class World: The frame or "world" which holds all workers (including the local machine). Args: - initial_payload: The original txt2img payload created by the user initiating the generation request on master. + p: The original processing state object created by the user initiating the generation request on master. verify_remotes (bool): Whether to validate remote worker certificates. 
""" @@ -86,15 +79,14 @@ class World: config_path = shared.cmd_opts.distributed_config old_config_path = worker_info_path = extension_path.joinpath('workers.json') - def __init__(self, initial_payload, verify_remotes: bool = True): + def __init__(self, verify_remotes: bool = True): + self.p = None self.master_worker = Worker(master=True) - self.total_batch_size: int = 0 self._workers: List[Worker] = [self.master_worker] self.jobs: List[Job] = [] self.job_timeout: int = 3 # seconds self.initialized: bool = False self.verify_remotes = verify_remotes - self.initial_payload = copy.copy(initial_payload) self.thin_client_mode = False self.enabled = True self.is_dropdown_handler_injected = False @@ -109,32 +101,12 @@ class World: def __repr__(self): return f"{len(self._workers)} workers" - def update_world(self, total_batch_size): - """ - Updates the world with information vital to handling the local generation request after - the world has already been initialized. - - Args: - total_batch_size (int): The total number of images requested by the local/master sdwui instance. - """ - - self.total_batch_size = total_batch_size - self.update_jobs() - - def initialize(self, total_batch_size): - """should be called before a world instance is used for anything""" - if self.initialized: - raise WorldAlreadyInitialized("This world instance was already initialized") - - self.benchmark() - self.update_world(total_batch_size=total_batch_size) - self.initialized = True def default_batch_size(self) -> int: """the amount of images/total images requested that a worker would compute if conditions were perfect and each worker generated at the same speed. 
assumes one batch only""" - return self.total_batch_size // self.size() + return self.p.batch_size // self.size() def size(self) -> int: """ @@ -396,6 +368,18 @@ class World: self.jobs.append(Job(worker=worker, batch_size=batch_size)) + def update(self, p): + """preps world for another run""" + if not self.initialized: + self.benchmark() + self.initialized = True + logger.debug("world initialized!") + else: + logger.debug("world was already initialized") + + self.p = p + self.update_jobs() + def get_workers(self): filtered: List[Worker] = [] for worker in self._workers: @@ -431,7 +415,7 @@ class World: logger.debug(f"worker '{job.worker.label}' would stall the image gallery by ~{lag:.2f}s\n") job.complementary = True - if deferred_images + images_checked + payload['batch_size'] > self.total_batch_size: + if deferred_images + images_checked + payload['batch_size'] > self.p.batch_size: logger.debug(f"would go over actual requested size") else: deferred_images += payload['batch_size'] @@ -474,9 +458,9 @@ class World: ####################### # when total number of requested images was not cleanly divisible by world size then we tack the remainder on - remainder_images = self.total_batch_size - self.get_current_output_size() + remainder_images = self.p.batch_size - self.get_current_output_size() if remainder_images >= 1: - logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed") + logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed") realtime_jobs = self.realtime_jobs() realtime_jobs.sort(key=lambda x: x.batch_size) @@ -555,9 +539,9 @@ class World: iterations = payload['n_iter'] num_returning = self.get_current_output_size() - num_complementary = 
num_returning - self.total_batch_size + num_complementary = num_returning - self.p.batch_size distro_summary = "Job distribution:\n" - distro_summary += f"{self.total_batch_size} * {iterations} iteration(s)" + distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)" if num_complementary > 0: distro_summary += f" + {num_complementary} complementary" distro_summary += f": {num_returning} images total\n" @@ -577,7 +561,7 @@ class World: processed.infotexts = [] processed.prompt = None - self.initial_payload.scripts.postprocess(p, processed) + self.p.scripts.postprocess(p, processed) return processed processing.process_images_inner = process_images_inner_bypass From 0ac23cbd73b8cd9bdd607d4d0d55fa6236f73eb7 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 9 May 2024 01:44:45 -0500 Subject: [PATCH 15/20] refactoring --- scripts/distributed.py | 110 ++++++++++++++++++++--------------------- 1 file changed, 54 insertions(+), 56 deletions(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index 59f755a..ec92e0c 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -32,27 +32,32 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM) # noinspection PyMissingOrEmptyDocstring class Script(scripts.Script): - worker_threads: List[Thread] = [] - # Whether to verify worker certificates. Can be useful if your remotes are self-signed. - verify_remotes = not cmd_opts.distributed_skip_verify_remotes - is_img2img = True - is_txt2img = True - alwayson = True - master_start = None - runs_since_init = 0 - name = "distributed" - is_dropdown_handler_injected = False + def __init__(self): + super().__init__() + self.worker_threads: List[Thread] = [] + # Whether to verify worker certificates. Can be useful if your remotes are self-signed. 
+ self.verify_remotes = not cmd_opts.distributed_skip_verify_remotes + self.is_img2img = True + self.is_txt2img = True + self.alwayson = True + self.master_start = None + self.runs_since_init = 0 + self.name = "distributed" + self.is_dropdown_handler_injected = False - if verify_remotes is False: - logger.warning(f"You have chosen to forego the verification of worker TLS certificates") - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + if self.verify_remotes is False: + logger.warning(f"You have chosen to forego the verification of worker TLS certificates") + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - # build world - world = World(verify_remotes=verify_remotes) - world.load_config() - logger.info("doing initial ping sweep to see which workers are reachable") - world.ping_remotes(indiscriminate=True) + # build world + self.world = World(verify_remotes=self.verify_remotes) + self.world.load_config() + logger.info("doing initial ping sweep to see which workers are reachable") + self.world.ping_remotes(indiscriminate=True) + + signal.signal(signal.SIGINT, self.signal_handler) + signal.signal(signal.SIGTERM, self.signal_handler) def title(self): return "Distribute" @@ -61,19 +66,18 @@ class Script(scripts.Script): return scripts.AlwaysVisible def ui(self, is_img2img): - Script.world.load_config() - extension_ui = UI(script=Script, world=Script.world) + self.world.load_config() + extension_ui = UI(script=Script, world=self.world) # root, api_exposed = extension_ui.create_ui() components = extension_ui.create_ui() # The first injection of handler for the models dropdown(sd_model_checkpoint) which is often present # in the quick-settings bar of a user. Helps ensure model swaps propagate to all nodes ASAP. 
- Script.world.inject_model_dropdown_handler() + self.world.inject_model_dropdown_handler() # return some components that should be exposed to the api return components - @staticmethod - def add_to_gallery(processed, p): + def add_to_gallery(self, processed, p): """adds generated images to the image gallery after waiting for all workers to finish""" def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None): @@ -118,10 +122,10 @@ class Script(scripts.Script): if p.n_iter > 1: # if splitting by batch count num_remote_images *= p.n_iter - 1 - logger.debug(f"image {true_image_pos + 1}/{Script.world.p.batch_size * p.n_iter}, " + logger.debug(f"image {true_image_pos + 1}/{self.world.p.batch_size * p.n_iter}, " f"info-index: {info_index}") - if Script.world.thin_client_mode: + if self.world.thin_client_mode: p.all_negative_prompts = processed.all_negative_prompts try: @@ -146,21 +150,21 @@ class Script(scripts.Script): ) # get master ipm by estimating based on worker speed - master_elapsed = time.time() - Script.master_start + master_elapsed = time.time() - self.master_start logger.debug(f"Took master {master_elapsed:.2f}s") # wait for response from all workers webui_state.textinfo = "Distributed - receiving results" - for thread in Script.worker_threads: + for thread in self.worker_threads: logger.debug(f"waiting for worker thread '{thread.name}'") thread.join() - Script.worker_threads.clear() + self.worker_threads.clear() logger.debug("all worker request threads returned") webui_state.textinfo = "Distributed - injecting images" # some worker which we know has a good response that we can use for generating the grid donor_worker = None - for job in Script.world.jobs: + for job in self.world.jobs: if job.batch_size < 1 or job.worker.master: continue @@ -209,7 +213,7 @@ class Script(scripts.Script): ) # cleanup after we're doing using all the responses - for worker in Script.world.get_workers(): + for worker in 
self.world.get_workers(): worker.response = None p.batch_size = len(processed.images) @@ -218,15 +222,14 @@ class Script(scripts.Script): # p's type is # "modules.processing.StableDiffusionProcessing*" - @staticmethod - def before_process(p, *args): - if not Script.world.enabled: + def before_process(self, p, *args): + if not self.world.enabled: logger.debug("extension is disabled") return - Script.world.update(p) + self.world.update(p) # save original process_images_inner function for later if we monkeypatch it - Script.original_process_images_inner = processing.process_images_inner + self.original_process_images_inner = processing.process_images_inner # strip scripts that aren't yet supported and warn user packed_script_args: List[dict] = [] # list of api formatted per-script argument objects @@ -262,7 +265,7 @@ class Script(scripts.Script): # encapsulating the request object within a txt2imgreq object is deprecated and no longer works # see test/basic_features/txt2img_test.py for an example payload = copy.copy(p.__dict__) - payload['batch_size'] = Script.world.default_batch_size() + payload['batch_size'] = self.world.default_batch_size() payload['scripts'] = None try: del payload['script_args'] @@ -291,11 +294,11 @@ class Script(scripts.Script): # start generating images assigned to remote machines sync = False # should only really need to sync once per job - Script.world.optimize_jobs(payload) # optimize work assignment before dispatching + self.world.optimize_jobs(payload) # optimize work assignment before dispatching started_jobs = [] # check if anything even needs to be done - if len(Script.world.jobs) == 1 and Script.world.jobs[0].worker.master: + if len(self.world.jobs) == 1 and self.world.jobs[0].worker.master: if payload['batch_size'] >= 2: msg = f"all remote workers are offline or unreachable" @@ -306,7 +309,7 @@ class Script(scripts.Script): return - for job in Script.world.jobs: + for job in self.world.jobs: payload_temp = copy.copy(payload) del 
payload_temp['scripts_value'] payload_temp = copy.deepcopy(payload_temp) @@ -338,35 +341,33 @@ class Script(scripts.Script): name=f"{job.worker.label}_request") t.start() - Script.worker_threads.append(t) + self.worker_threads.append(t) started_jobs.append(job) # if master batch size was changed again due to optimization change it to the updated value - if not Script.world.thin_client_mode: - p.batch_size = Script.world.master_job().batch_size - Script.master_start = time.time() + if not self.world.thin_client_mode: + p.batch_size = self.world.master_job().batch_size + self.master_start = time.time() # generate images assigned to local machine p.do_not_save_grid = True # don't generate grid from master as we are doing this later. - Script.runs_since_init += 1 + self.runs_since_init += 1 return - @staticmethod - def postprocess(p, processed, *args): - if not Script.world.enabled: + def postprocess(self, p, processed, *args): + if not self.world.enabled: return - if Script.master_start is not None: - Script.add_to_gallery(p=p, processed=processed) + if self.master_start is not None: + self.add_to_gallery(p=p, processed=processed) # restore process_images_inner if it was monkey-patched - processing.process_images_inner = Script.original_process_images_inner + processing.process_images_inner = self.original_process_images_inner - @staticmethod - def signal_handler(sig, frame): + def signal_handler(self, sig, frame): logger.debug("handling interrupt signal") # do cleanup - Script.world.save_config() + self.world.save_config() if sig == signal.SIGINT: if callable(old_sigint_handler): @@ -376,6 +377,3 @@ class Script(scripts.Script): old_sigterm_handler(sig, frame) else: sys.exit(0) - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) From 574afa409f71b974f1cf2134de76925c41127fc6 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 10 May 2024 10:15:49 -0500 Subject: [PATCH 16/20] step_scaling load/save fix, logging and semantics 
--- scripts/distributed.py | 5 +++-- scripts/spartan/pmodels.py | 2 +- scripts/spartan/shared.py | 2 +- scripts/spartan/ui.py | 10 +++++----- scripts/spartan/world.py | 9 +++++---- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index ec92e0c..e5f84a9 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -10,7 +10,7 @@ import re import signal import sys import time -from threading import Thread, current_thread +from threading import Thread from typing import List import gradio import urllib3 @@ -330,7 +330,8 @@ class Script(scripts.Script): payload_temp['seed'] += prior_images if payload_temp['subseed_strength'] == 0 else 0 logger.debug( f"'{job.worker.label}' job's given starting seed is " - f"{payload_temp['seed']} with {prior_images} coming before it") + f"{payload_temp['seed']} with {prior_images} coming before it" + ) if job.worker.loaded_model != name or job.worker.loaded_vae != vae: sync = True diff --git a/scripts/spartan/pmodels.py b/scripts/spartan/pmodels.py index 45c8c5b..e349dc8 100644 --- a/scripts/spartan/pmodels.py +++ b/scripts/spartan/pmodels.py @@ -42,4 +42,4 @@ class ConfigModel(BaseModel): job_timeout: Optional[int] = Field(default=3) enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True) complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True) - step_reduction: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints") + step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False) diff --git a/scripts/spartan/shared.py b/scripts/spartan/shared.py index 52512b7..5400464 100644 --- a/scripts/spartan/shared.py +++ b/scripts/spartan/shared.py @@ -61,7 +61,7 @@ 
logger.addHandler(gui_handler) # end logging warmup_samples = 2 # number of samples to do before recording a valid benchmark sample -samples = 4 # number of times to benchmark worker after warmup benchmarks are completed +samples = 3 # number of times to benchmark worker after warmup benchmarks are completed class BenchmarkPayload(BaseModel): diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index bbcdd74..d0f4d27 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -85,10 +85,8 @@ class UI: """updates the options visible on the settings tab""" self.world.thin_client_mode = thin_client_mode - logger.debug(f"thin client mode is now {thin_client_mode}") job_timeout = int(job_timeout) self.world.job_timeout = job_timeout - logger.debug(f"job timeout is now {job_timeout} seconds") self.world.complement_production = complement_production self.world.step_scaling = step_scaling self.world.save_config() @@ -216,7 +214,6 @@ class UI: interactive=True ) main_toggle.input(self.main_toggle_btn) - setattr(main_toggle, 'do_not_save_to_config', True) # ui_loadsave.py apply_field() components.append(main_toggle) with gradio.Tab('Status') as status_tab: @@ -355,13 +352,13 @@ class UI: with gradio.Tab('Settings'): thin_client_cbx = gradio.Checkbox( label='Thin-client mode', - info="Only generate images using remote workers. There will be no previews (yet) when enabled.", + info="Only generate images remotely (no image previews yet)", value=self.world.thin_client_mode ) job_timeout = gradio.Number( label='Job timeout', value=self.world.job_timeout, info="Seconds until a worker is considered too slow to be assigned an" - " equal share of the total request. Longer than 2 seconds is recommended." + " equal share of the total request. 
Longer than 2 seconds is recommended" ) complement_production = gradio.Checkbox( @@ -390,4 +387,7 @@ class UI: """ ) + # prevent wui from overriding any values + for component in components: + setattr(component, 'do_not_save_to_config', True) # ui_loadsave.py apply_field() return components diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 5156cc6..1b2bfcb 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -353,7 +353,7 @@ class World: return lag - def update_jobs(self): + def make_jobs(self): """creates initial jobs (before optimization) """ # clear jobs if this is not the first time running @@ -367,6 +367,7 @@ class World: worker.benchmark() self.jobs.append(Job(worker=worker, batch_size=batch_size)) + logger.debug(f"added job for worker {worker.label}") def update(self, p): """preps world for another run""" @@ -378,7 +379,7 @@ class World: logger.debug("world was already initialized") self.p = p - self.update_jobs() + self.make_jobs() def get_workers(self): filtered: List[Worker] = [] @@ -525,7 +526,7 @@ class World: if num_images_compensate == 0 and self.step_scaling: seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1) realtime_samples = slack_time // seconds_per_sample - logger.critical( + logger.debug( f"job for '{job.worker.label}' downscaled to {realtime_samples} samples to meet time constraints\n" f"{realtime_samples:.0f} samples = {slack_time:.2f}s slack Γ· {seconds_per_sample:.2f}s/sample\n" f" step reduction: {payload['steps']} -> {realtime_samples:.0f}" @@ -652,7 +653,7 @@ class World: self.job_timeout = config.job_timeout self.enabled = config.enabled self.complement_production = config.complement_production - self.step_scaling = config.step_reduction + self.step_scaling = config.step_scaling logger.debug("config loaded") From 424a1c88580a379b7221c91f0ae472d26764ec88 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 10 May 2024 14:51:53 -0500 Subject: [PATCH 17/20] remove redundant 
init after last ref --- scripts/distributed.py | 59 ++++++++++++++++++++-------------------- scripts/spartan/ui.py | 3 +- scripts/spartan/world.py | 2 +- 3 files changed, 31 insertions(+), 33 deletions(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index e5f84a9..2fe90a7 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -18,7 +18,7 @@ from PIL import Image from modules import processing from modules import scripts from modules.images import save_image -from modules.processing import fix_seed, Processed +from modules.processing import fix_seed from modules.shared import opts, cmd_opts from modules.shared import state as webui_state from scripts.spartan.control_net import pack_control_net @@ -31,33 +31,29 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM) # noinspection PyMissingOrEmptyDocstring -class Script(scripts.Script): +class DistributedScript(scripts.Script): + # global old_sigterm_handler, old_sigterm_handler + worker_threads: List[Thread] = [] + # Whether to verify worker certificates. Can be useful if your remotes are self-signed. + verify_remotes = not cmd_opts.distributed_skip_verify_remotes + master_start = None + runs_since_init = 0 + name = "distributed" + is_dropdown_handler_injected = False + if verify_remotes is False: + logger.warning(f"You have chosen to forego the verification of worker TLS certificates") + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + # build world + world = World(verify_remotes=verify_remotes) + world.load_config() + logger.info("doing initial ping sweep to see which workers are reachable") + world.ping_remotes(indiscriminate=True) + + # constructed for both txt2img and img2img def __init__(self): super().__init__() - self.worker_threads: List[Thread] = [] - # Whether to verify worker certificates. Can be useful if your remotes are self-signed. 
- self.verify_remotes = not cmd_opts.distributed_skip_verify_remotes - self.is_img2img = True - self.is_txt2img = True - self.alwayson = True - self.master_start = None - self.runs_since_init = 0 - self.name = "distributed" - self.is_dropdown_handler_injected = False - - if self.verify_remotes is False: - logger.warning(f"You have chosen to forego the verification of worker TLS certificates") - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - - # build world - self.world = World(verify_remotes=self.verify_remotes) - self.world.load_config() - logger.info("doing initial ping sweep to see which workers are reachable") - self.world.ping_remotes(indiscriminate=True) - - signal.signal(signal.SIGINT, self.signal_handler) - signal.signal(signal.SIGTERM, self.signal_handler) def title(self): return "Distribute" @@ -67,7 +63,7 @@ class Script(scripts.Script): def ui(self, is_img2img): self.world.load_config() - extension_ui = UI(script=Script, world=self.world) + extension_ui = UI(world=self.world) # root, api_exposed = extension_ui.create_ui() components = extension_ui.create_ui() @@ -203,7 +199,7 @@ class Script(scripts.Script): # generate and inject grid if opts.return_grid: - grid = processing.images.image_grid(processed.images, len(processed.images)) + grid = images.image_grid(processed.images, len(processed.images)) processed_inject_image( image=grid, info_index=0, @@ -219,7 +215,6 @@ class Script(scripts.Script): p.batch_size = len(processed.images) return - # p's type is # "modules.processing.StableDiffusionProcessing*" def before_process(self, p, *args): @@ -365,10 +360,11 @@ class Script(scripts.Script): # restore process_images_inner if it was monkey-patched processing.process_images_inner = self.original_process_images_inner - def signal_handler(self, sig, frame): + @staticmethod + def signal_handler(sig, frame): logger.debug("handling interrupt signal") # do cleanup - self.world.save_config() + DistributedScript.world.save_config() if sig 
== signal.SIGINT: if callable(old_sigint_handler): @@ -378,3 +374,6 @@ class Script(scripts.Script): old_sigterm_handler(sig, frame) else: sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index d0f4d27..cd0b87a 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -14,8 +14,7 @@ worker_select_dropdown = None class UI: """extension user interface related things""" - def __init__(self, script, world): - self.script = script + def __init__(self, world): self.world = world self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 1b2bfcb..2c72a28 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -16,7 +16,7 @@ import modules.shared as shared from modules.processing import process_images, StableDiffusionProcessingTxt2Img from . import shared as sh from .pmodels import ConfigModel, Benchmark_Payload -from .shared import logger, warmup_samples, extension_path +from .shared import logger, extension_path from .worker import Worker, State from modules.call_queue import wrap_queued_call from modules import processing From 56d667dfe8fc5266e39b778f88c4811bfa609e34 Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 11 May 2024 21:57:02 -0500 Subject: [PATCH 18/20] protect against cyclical configurations --- scripts/distributed.py | 4 +++- scripts/spartan/world.py | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index 2fe90a7..5d2208a 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -62,7 +62,9 @@ class DistributedScript(scripts.Script): return scripts.AlwaysVisible def ui(self, is_img2img): - self.world.load_config() + if not is_img2img: # prevents loading twice for no reason + self.world.load_config() + extension_ui = 
UI(world=self.world) # root, api_exposed = extension_ui.create_ui() components = extension_ui.create_ui() diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 2c72a28..f555ca1 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -145,6 +145,14 @@ class World: Worker: The worker object. """ + # protect against user trying to make cyclical setups and connections + is_master = kwargs.get('master') + if is_master is None or not is_master: + m = self.master() + if kwargs['address'] == m.address and kwargs['port'] == m.port: + logger.error(f"refusing to add worker {kwargs['label']} as its socket definition({m.address}:{m.port}) matches master") + return None + original = self[kwargs['label']] # if worker doesn't already exist then just make a new one if original is None: new = Worker(**kwargs) From eef13b03daf373e5ed0acdb4c938e803201e5020 Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 11 May 2024 23:33:15 -0500 Subject: [PATCH 19/20] should no longer need early load from ui --- scripts/distributed.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index 5d2208a..e3d88f8 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -62,9 +62,6 @@ class DistributedScript(scripts.Script): return scripts.AlwaysVisible def ui(self, is_img2img): - if not is_img2img: # prevents loading twice for no reason - self.world.load_config() - extension_ui = UI(world=self.world) # root, api_exposed = extension_ui.create_ui() components = extension_ui.create_ui() From 515979653c66cf24879e1e2071813c68a0547c85 Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 12 May 2024 00:22:15 -0500 Subject: [PATCH 20/20] update changelog --- CHANGELOG.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index da76536..bd16d48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,27 @@ # Change Log Formatting: [Keep a 
Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
+## [2.2.0] - 2024-5-11
+
+### Added
+- Toggle for allowing automatic step scaling, which can increase overall utilization
+
+### Changed
+- Adding workers which have the same socket definition as master will no longer be allowed and an error will be shown (#28)
+- Workers in an invalid state should no longer be benchmarked
+- The worker port under worker config will now default to 7860 to prevent mishaps
+- Config should once again only be loaded once per session startup
+- A warning will be shown when trying to use the user script button but no script exists
+
+### Fixed
+- Thin-client mode
+- Some problems with the sdwui forge branch
+- Certificate verification setting sometimes not saving
+- Generation stopping when the master is assigned no work (same problem as thin-client)
+
+### Removed
+- Adding workers using deprecated cmdline argument
+
 ## [2.1.0] - 2024-3-03

 ### Added