diff --git a/preload.py b/preload.py index a8c4977..89be003 100644 --- a/preload.py +++ b/preload.py @@ -23,4 +23,3 @@ def preload(parser): help="Enable debug information", action="store_true" ) - # args = parser.parse_args() diff --git a/scripts/extension.py b/scripts/extension.py index 0579550..b1b84f9 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -21,7 +21,7 @@ from pathlib import Path import os import subprocess from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized -from scripts.spartan.Worker import Worker +from scripts.spartan.Worker import Worker, State from modules.shared import opts # TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers? @@ -60,13 +60,16 @@ class Script(scripts.Script): with gradio.Tab('Status') as status_tab: status = gradio.Textbox(elem_id='status', show_label=False) - status_tab.select(fn=Script.ui_connect_test, inputs=[], outputs=[status]) + status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[status]) + + jobs = gradio.Textbox(elem_id='jobs', label='Jobs', show_label=True) + # status_tab.select(fn=Script.world.__str__, inputs=[], outputs=[jobs, status]), refresh_status_btn = gradio.Button(value='Refresh') refresh_status_btn.style(size='sm') - refresh_status_btn.click(Script.ui_connect_test, inputs=[], outputs=[status]) + refresh_status_btn.click(Script.ui_connect_status, inputs=[], outputs=[jobs, status]) - with gradio.Tab('Remote Utils'): + with gradio.Tab('Utils'): refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints') refresh_checkpoints_btn.style(full_width=False) refresh_checkpoints_btn.click(Script.ui_connect_refresh_ckpts_btn, inputs=[], outputs=[]) @@ -79,8 +82,19 @@ class Script(scripts.Script): interrupt_all_btn.style(full_width=False) interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[]) + # redo benchmarks button + redo_benchmarks_btn = gradio.Button(value='Redo benchmarks') + redo_benchmarks_btn.style(full_width=False) + redo_benchmarks_btn.click(Script.ui_connect_benchmark_button, inputs=[], outputs=[]) + + return + @staticmethod + def ui_connect_benchmark_button(): + print("Redoing benchmarks...") + Script.world.benchmark(rebenchmark=True) + @staticmethod def user_sync_script(): user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user') @@ -120,23 +134,29 @@ class Script(scripts.Script): print("Distributed system not initialized") @staticmethod - def ui_connect_test(): + def ui_connect_status(): try: - temp = '' + worker_status = '' for worker in Script.world.workers: if worker.master: continue - temp += f"{worker.uuid} at {worker.address} is {worker.state.name}\n" + worker_status += f"{worker.uuid} at {worker.address} is {worker.state.name}\n" - return temp + # TODO replace this with a single check to a state flag that we should make in the world class + for worker in Script.world.workers: + if worker.state == State.WORKING: + return Script.world.__str__(), worker_status + + return 'No active jobs!', worker_status # init system if it isn't already - except AttributeError: - # batch size will be clobbered later once an actual request is made anyway so I just pass 1 + except AttributeError as e: + print(e) + # batch size will be clobbered later once an actual request is made anyway Script.initialize(initial_payload=None) - Script.ui_connect_test() + return 'refresh!', 'refresh!' @staticmethod @@ -222,33 +242,38 @@ class Script(scripts.Script): @staticmethod def initialize(initial_payload): - if Script.verify_remotes is False: - print(f"WARNING: you have chosen to forego the verification of worker TLS certificates") - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - + # get default batch size try: batch_size = initial_payload.batch_size except AttributeError: batch_size = 1 - Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes) + if Script.world is None: + if Script.verify_remotes is False: + print(f"WARNING: you have chosen to forego the verification of worker TLS certificates") + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - # add workers to the world - for worker in cmd_opts.distributed_remotes: - Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2]) + # construct World + Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes) - try: - Script.world.initialize(batch_size) - print("World initialized!") - except WorldAlreadyInitialized: - Script.world.update_world(total_batch_size=batch_size) + # add workers to the world + for worker in cmd_opts.distributed_remotes: + Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2]) + + else: + # update world or initialize and update if necessary + try: + Script.world.initialize(batch_size) + print("World initialized!") + except WorldAlreadyInitialized: + Script.world.update_world(total_batch_size=batch_size) def run(self, p, *args): if cmd_opts.distributed_remotes is None: raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)") # register gallery callback - script_callbacks.on_after_image_processed(Script.add_to_gallery) + script_callbacks.on_after_batch_processed(Script.add_to_gallery) Script.initialize(initial_payload=p) diff --git a/scripts/spartan/Worker.py b/scripts/spartan/Worker.py index 4f10f6a..89e71c7 100644 --- a/scripts/spartan/Worker.py +++ b/scripts/spartan/Worker.py @@ -121,7 +121,7 @@ class Worker: These things are used to draw certain conclusions after the first session. Args: - benchmark_payload (dict): The payload used the benchmark. + benchmark_payload (dict): The payload used in the benchmark. Returns: dict: Worker info, including how it was benchmarked. @@ -137,9 +137,18 @@ class Worker: d[self.uuid] = data return d + def treat_mpe(self): + """ + In collecting percent errors to calculate the MPE, there may be influential outliers that skew the results. + Here we cull those outliers from them in order to get a more accurate end result. + """ + + # TODO implement this + pass + def eta_mpe(self): """ - Returns the mean absolute percent error using all the currently stored eta percent errors. + Returns the mean percent error using all the currently stored eta percent errors. Returns: mpe (float): The mean percent error of a worker's calculation estimates. @@ -165,7 +174,7 @@ class Worker: """ # TODO check if using http or https - return f"https://{self.__str__()}/sdapi/v1/{route}" + return f"http://{self.__str__()}/sdapi/v1/{route}" def batch_eta_hr(self, payload: dict) -> float: """ @@ -233,13 +242,21 @@ class Worker: if len(self.eta_percent_error) > 0: correction = eta * (self.eta_mpe() / 100) - if cmd_opts.distributed_debug: - print(f"worker '{self.uuid}'s last ETA was off by {correction}%") + # if abs(correction) > 300: + # print(f"correction {abs(correction)} exceeds 300%... .") + if cmd_opts.distributed_debug: + print(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%") + print(f"{self.uuid} eta before correction: ", eta) + + # do regression if correction > 0: - eta += correction - else: eta -= correction + else: + eta += correction + + if cmd_opts.distributed_debug: + print(f"{self.uuid} eta after correction: ", eta) return eta except Exception as e: @@ -308,10 +325,18 @@ class Worker: print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n") print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n") - if self.eta_percent_error == 0: - self.eta_percent_error[0] = variance + # if the variance is greater than 500% then we ignore it to prevent variation inflation + if abs(variance) < 500: + # check if there are already 5 samples and if so, remove the oldest + # this should help adjust to the user changing tasks + if len(self.eta_percent_error) > 4: + self.eta_percent_error.pop(0) + if self.eta_percent_error == 0: # init + self.eta_percent_error[0] = variance + else: # normal case + self.eta_percent_error.append(variance) else: - self.eta_percent_error.append(variance) + print(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n") except Exception as e: if payload['batch_size'] == 0: diff --git a/scripts/spartan/World.py b/scripts/spartan/World.py index f303fdc..9d67ed9 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/World.py @@ -51,6 +51,14 @@ class Job: self.batch_size: int = batch_size self.complementary: bool = False + def __str__(self): + prefix = '' + suffix = f"Job: {self.batch_size} images. Owned by '{self.worker.uuid}'. Rate: {self.worker.avg_ipm}ipm" + if self.complementary: + prefix = "(complementary) " + + return prefix + suffix + class World: """ @@ -70,7 +78,7 @@ class World: self.total_batch_size: int = 0 self.workers: List[Worker] = [master_worker] self.jobs: List[Job] = [] - self.job_timeout: int = 10 # seconds + self.job_timeout: int = 0 # seconds self.initialized: bool = False self.verify_remotes = verify_remotes self.initial_payload = copy.copy(initial_payload) @@ -188,15 +196,19 @@ class World: t = Thread(target=worker.refresh_checkpoints, args=()) t.start() - def benchmark(self): + def benchmark(self, rebenchmark: bool = False): """ Attempts to benchmark all workers a part of the world. """ + global benchmark_payload workers_info: dict = {} saved: bool = os.path.exists(self.worker_info_path) benchmark_payload_loaded: bool = False + if rebenchmark: + saved = False + if saved: workers_info = json.load(open(self.worker_info_path, 'r')) @@ -250,6 +262,14 @@ class World: print(f"{i}. worker '{worker}' - {worker.avg_ipm} ipm") i += 1 + def __str__(self): + # print status of all jobs + jobs_str = "" + for job in self.jobs: + jobs_str += job.__str__() + "\n" + + return jobs_str + def realtime_jobs(self) -> List[Job]: """ Determines which jobs are considered real-time by checking which jobs are not(complementary). @@ -263,6 +283,8 @@ class World: if job.complementary is False: fast_jobs.append(job) + print(f"fast jobs: {fast_jobs}") + return fast_jobs def slowest_realtime_job(self) -> Job: @@ -292,6 +314,10 @@ class World: """ fastest_worker = self.fastest_realtime_job().worker + # if the worker is the fastest, then there is no lag + if worker == fastest_worker: + return 0 + lag = worker.batch_eta(payload=payload) - fastest_worker.batch_eta(payload=payload) return lag @@ -304,6 +330,7 @@ class World: Returns: float: Local worker speed in ipm """ + global benchmark_payload master_bench_payload = copy.copy(self.initial_payload) @@ -357,14 +384,14 @@ class World: deferred_images = 0 # the number of images that were not assigned to a worker due to the worker being too slow # the maximum amount of images that a "slow" worker can produce in the slack space where other nodes are working - max_compensation = 4 + # max_compensation = 4 currently unused images_per_job = None for job in self.jobs: lag = self.job_stall(job.worker, payload=payload) - if lag < self.job_timeout: + if lag < self.job_timeout or lag == 0: job.batch_size = payload['batch_size'] continue @@ -396,6 +423,9 @@ class World: slowest_active_worker = self.slowest_realtime_job().worker slack_time = slowest_active_worker.batch_eta(payload=payload) + if cmd_opts.distributed_debug: + print(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'") + # in the case that this worker is now taking on what others workers would have been (if they were real-time) # this means that there will be more slack time for complementary nodes slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job) @@ -412,6 +442,8 @@ class World: # It might be better to just inject a black image. (if master is that slow) master_job = self.master_job() if master_job.batch_size < 1: + if cmd_opts.distributed_debug: + print("Master couldn't keep up... defaulting to 1 image") master_job.batch_size = 1 print("After job optimization, job layout is the following:")