From ec7c4fe58aa49fb899fcb2da38e3afa49014c4e5 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sat, 25 May 2024 00:29:56 -0500 Subject: [PATCH 1/9] skip grid if only one image returned --- scripts/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index 41422c1..ef1cc0b 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -197,7 +197,7 @@ class DistributedScript(scripts.Script): return # generate and inject grid - if opts.return_grid: + if opts.return_grid and len(processed.images) > 1: grid = image_grid(processed.images, len(processed.images)) processed_inject_image( image=grid, From 78035ac433f675ee1bdb21f1d8e9b290b14c8e3a Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sun, 26 May 2024 17:47:06 -0500 Subject: [PATCH 2/9] state validation --- scripts/spartan/worker.py | 77 ++++++++++++++++++++++++++------------- scripts/spartan/world.py | 2 +- 2 files changed, 53 insertions(+), 26 deletions(-) diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 832cdba..cc258cb 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -306,7 +306,7 @@ class Worker: if waited >= (0.85 * max_wait): logger.warning("this seems long, so if you see this message often, consider reporting an issue") - self.state = State.WORKING + self.set_state(State.WORKING) # query memory available on worker and store for future reference if self.queried is False: @@ -423,7 +423,6 @@ class Worker: sampler_name = payload.get('sampler_name', None) if sampler_index is None: if sampler_name is not None: - logger.debug("had to substitute sampler index with name") payload['sampler_index'] = sampler_name try: @@ -490,7 +489,7 @@ class Worker: logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n") except Exception as e: - self.state = State.IDLE + self.set_state(State.IDLE) if payload['batch_size'] == 0: raise InvalidWorkerResponse("Tried to request a null amount of images") @@ -498,10 +497,10 @@ class Worker: raise InvalidWorkerResponse(e) except requests.RequestException: - self.mark_unreachable() + self.set_state(State.UNAVAILABLE) return - self.state = State.IDLE + self.set_state(State.IDLE) self.jobs_requested += 1 return @@ -539,7 +538,6 @@ class Worker: # this was due to something torch does at startup according to auto and is now done at sdwui startup for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up" if self.state == State.UNAVAILABLE: - self.response = None return 0 try: # if the worker is unreachable/offline then handle that here @@ -574,7 +572,7 @@ class Worker: self.response = None self.benchmarked = True self.eta_percent_error = [] # likely inaccurate after rebenching - self.state = State.IDLE + self.set_state(State.IDLE) return avg_ipm_result def refresh_checkpoints(self): @@ -593,17 +591,17 @@ class Worker: logger.error(msg) # gradio.Warning("Distributed: "+msg) except requests.exceptions.ConnectionError: - self.mark_unreachable() + self.set_state(State.UNAVAILABLE) def interrupt(self): try: response = self.session.post(self.full_url('interrupt')) if response.status_code == 200: - self.state = State.INTERRUPTED + self.set_state(State.INTERRUPTED) logger.debug(f"successfully interrupted worker {self.label}") except requests.exceptions.ConnectionError: - self.mark_unreachable() + self.set_state(State.UNAVAILABLE) def reachable(self) -> bool: """returns false if worker is unreachable""" @@ -622,18 +620,6 @@ class Worker: logger.error(e) return False - def mark_unreachable(self): - if self.state == State.DISABLED: - logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable") - else: - msg = f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection" - logger.error(msg) - # gradio.Warning("Distributed: "+msg) - self.state = State.UNAVAILABLE - # invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models - self.loaded_model = None - self.loaded_vae = None - def available_models(self) -> [List[str]]: if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master: return [] @@ -654,7 +640,7 @@ class Worker: titles = [model['title'] for model in response.json()] return titles except requests.RequestException: - self.mark_unreachable() + self.set_state(State.UNAVAILABLE) return [] def load_options(self, model, vae=None): @@ -671,14 +657,14 @@ class Worker: if vae is not None: payload['sd_vae'] = vae - self.state = State.WORKING + self.set_state(State.WORKING) start = time.time() response = self.session.post( self.full_url("options"), json=payload ) elapsed = time.time() - start - self.state = State.IDLE + self.set_state(State.IDLE) if response.status_code != 200: logger.debug(f"failed to load options for worker '{self.label}'") @@ -720,3 +706,44 @@ class Worker: logger.error(f"{err_msg}: {response}") return False + + def set_state(self, state: State): + state_cache = self.state + + def transition(ns: State): + if ns == self.state: + logger.critical(f"{self.label} was already {self.state.name}") + return + + logger.debug(f"{self.label}: {self.state.name} -> {ns.name}") + self.state = ns + + match self.state: + case State.IDLE: + if state in (State.IDLE, State.WORKING): + transition(state) + + case State.WORKING: + if state in (State.WORKING, State.IDLE, State.INTERRUPTED): + transition(state) + + case State.UNAVAILABLE: + if state in State.IDLE: + transition(state) + + case State.INTERRUPTED: + if state in State.WORKING: + transition(state) + + if state == State.UNAVAILABLE: + if self.state == State.DISABLED: + logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable") + else: + logger.error(f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection") + # invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models + self.loaded_model = None + self.loaded_vae = None + transition(state) + + if self.state == state_cache and self.state != state: + logger.error(f"{self.label}: invalid transition {self.state.name} -> {state.name}") diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index f555ca1..307868d 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -727,7 +727,7 @@ class World: msg = f"worker '{worker.label}' is online" logger.info(msg) gradio.Info("Distributed: "+msg) - worker.state = State.IDLE + worker.set_state(State.IDLE) else: msg = f"worker '{worker.label}' is unreachable" logger.info(msg) From 429de773f843afd5f88d85d85204aa24c82869bd Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sun, 26 May 2024 18:35:53 -0500 Subject: [PATCH 3/9] make more concise --- scripts/spartan/worker.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index cc258cb..d55ae5d 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -718,22 +718,14 @@ class Worker: logger.debug(f"{self.label}: {self.state.name} -> {ns.name}") self.state = ns - match self.state: - case State.IDLE: - if state in (State.IDLE, State.WORKING): - transition(state) - - case State.WORKING: - if state in (State.WORKING, State.IDLE, State.INTERRUPTED): - transition(state) - - case State.UNAVAILABLE: - if state in State.IDLE: - transition(state) - - case State.INTERRUPTED: - if state in State.WORKING: - transition(state) + transitions = { + State.IDLE: {State.IDLE, State.WORKING}, + State.WORKING: {State.WORKING, State.IDLE, State.INTERRUPTED}, + State.UNAVAILABLE: {State.IDLE}, + State.INTERRUPTED: {State.WORKING}, + } + if state in transitions.get(self.state, {}): + transition(state) if state == State.UNAVAILABLE: if self.state == State.DISABLED: @@ -746,4 +738,4 @@ class Worker: transition(state) if self.state == state_cache and self.state != state: - logger.error(f"{self.label}: invalid transition {self.state.name} -> {state.name}") + logger.error(f"{self.label}: invalid or redundant transition {self.state.name} -> {state.name}") From 00d0e1a11a1e809d1bf41c92dd94165775f64c26 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sun, 26 May 2024 21:16:53 -0500 Subject: [PATCH 4/9] logging --- scripts/spartan/worker.py | 4 ++-- scripts/spartan/world.py | 16 +++++++--------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index d55ae5d..b4ebbd8 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -712,7 +712,7 @@ class Worker: def transition(ns: State): if ns == self.state: - logger.critical(f"{self.label} was already {self.state.name}") + logger.debug(f"{self.label}: potentially redundant transition {self.state.name} -> {ns.name}") return logger.debug(f"{self.label}: {self.state.name} -> {ns.name}") @@ -738,4 +738,4 @@ class Worker: transition(state) if self.state == state_cache and self.state != state: - logger.error(f"{self.label}: invalid or redundant transition {self.state.name} -> {state.name}") + logger.debug(f"{self.label}: invalid transition {self.state.name} -> {state.name}") diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 307868d..fe06d07 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -160,13 +160,7 @@ class World: return new else: for key in kwargs: - if hasattr(original, key): - # TODO only necessary because this is skipping Worker.__init__ and the pyd model is saving the state as an int instead of an actual enum - if key == 'state': - original.state = kwargs[key] if type(kwargs[key]) is State else State(kwargs[key]) - continue - - setattr(original, key, kwargs[key]) + setattr(original, key, kwargs[key]) return original @@ -535,7 +529,7 @@ class World: seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1) realtime_samples = slack_time // seconds_per_sample logger.debug( - f"job for '{job.worker.label}' downscaled to {realtime_samples} samples to meet time constraints\n" + f"job for '{job.worker.label}' downscaled to {realtime_samples:.0f} samples to meet time constraints\n" f"{realtime_samples:.0f} samples = {slack_time:.2f}s slack รท {seconds_per_sample:.2f}s/sample\n" f" step reduction: {payload['steps']} -> {realtime_samples:.0f}" ) @@ -654,6 +648,10 @@ class World: fields['label'] = label # TODO must be overridden everytime here or later converted to a config file variable at some point fields['verify_remotes'] = self.verify_remotes + # cast enum id to actual enum type and then prime state + fields['state'] = State(fields['state']) + if fields['state'] != State.DISABLED: + fields['state'] = State.IDLE self.add_worker(**fields) @@ -667,7 +665,7 @@ class World: def save_config(self): """ - Saves the config file. + Saves current state to the config file. """ config = ConfigModel( From 54a1f66c886144af25467246592ab6f05824837c Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sun, 26 May 2024 21:51:35 -0500 Subject: [PATCH 5/9] add state machine utility parameter for clarity when debugging --- scripts/spartan/worker.py | 12 ++++++++++-- scripts/spartan/world.py | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index b4ebbd8..6f505fc 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -707,11 +707,19 @@ class Worker: logger.error(f"{err_msg}: {response}") return False - def set_state(self, state: State): + def set_state(self, state: State, expect_cycle: bool = False): + """ + Updates the state of a worker if considered a valid operation + + Args: + state: the new state to try transitioning to + expect_cycle: whether this transition might be a no-op/self-loop + + """ state_cache = self.state def transition(ns: State): - if ns == self.state: + if ns == self.state and expect_cycle is False: logger.debug(f"{self.label}: potentially redundant transition {self.state.name} -> {ns.name}") return diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index fe06d07..4b6c57a 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -725,7 +725,7 @@ class World: msg = f"worker '{worker.label}' is online" logger.info(msg) gradio.Info("Distributed: "+msg) - worker.set_state(State.IDLE) + worker.set_state(State.IDLE, expect_cycle=True) else: msg = f"worker '{worker.label}' is unreachable" logger.info(msg) From 27c0523ec650e130100b355a0e45d9de64407547 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Wed, 29 May 2024 21:34:58 -0500 Subject: [PATCH 6/9] prevent deadlock when rebenchmarking after reloading from config file at runtime --- scripts/spartan/world.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 4b6c57a..4ab5c24 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -18,8 +18,9 @@ from . import shared as sh from .pmodels import ConfigModel, Benchmark_Payload from .shared import logger, extension_path from .worker import Worker, State -from modules.call_queue import wrap_queued_call +from modules.call_queue import wrap_queued_call, queue_lock from modules import processing +from modules import progress class NotBenchmarked(Exception): @@ -180,7 +181,10 @@ class World: Thread(target=worker.refresh_checkpoints, args=()).start() def sample_master(self) -> float: - # wrap our benchmark payload + # progress.finish_task(progress.current_task) + if queue_lock._lock.locked(): + queue_lock.release() + master_bench_payload = StableDiffusionProcessingTxt2Img() d = sh.benchmark_payload.dict() for key in d: @@ -188,16 +192,17 @@ class World: # Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to. master_bench_payload.do_not_save_samples = True - # shared.state.begin(job='distributed_master_bench') - wrapped = (wrap_queued_call(process_images)) + shared.state.end() + wrapped = wrap_queued_call(process_images) start = time.time() wrapped(master_bench_payload) - # wrap_gradio_gpu_call(process_images)(master_bench_payload) - # shared.state.end() + # seems counter-intuitive but the lock will later be released again once the original task is ended by wui + # only doing things this way so that we can bench and then have an original user request immediately resume + if progress.current_task is not None: # could be no task, ie. running bench from utils tab + queue_lock.acquire() return time.time() - start - def benchmark(self, rebenchmark: bool = False): """ Attempts to benchmark all workers a part of the world. @@ -373,16 +378,16 @@ class World: def update(self, p): """preps world for another run""" - if not self.initialized: - self.benchmark() - self.initialized = True - logger.debug("world initialized!") - else: - logger.debug("world was already initialized") self.p = p + self.benchmark() self.make_jobs() + if not self.initialized: + self.initialized = True + logger.debug("world initialized!") + + def get_workers(self): filtered: List[Worker] = [] for worker in self._workers: @@ -661,7 +666,7 @@ class World: self.complement_production = config.complement_production self.step_scaling = config.step_scaling - logger.debug("config loaded") + logger.debug(f"config loaded from '{os.path.abspath(self.config_path)}'") def save_config(self): """ From 9cd7c7c3518104ce4ca08503d74a6a56ec6e6f60 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sat, 1 Jun 2024 00:33:55 -0500 Subject: [PATCH 7/9] make benchmarking queueable --- scripts/spartan/ui.py | 6 ++++ scripts/spartan/world.py | 69 +++++++++++++++++++++------------------- 2 files changed, 43 insertions(+), 32 deletions(-) diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index cd0b87a..5aac350 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -7,6 +7,8 @@ from modules.shared import opts from modules.shared import state as webui_state from .shared import logger, LOG_LEVEL, gui_handler from .worker import State +from modules.call_queue import queue_lock +from modules import progress worker_select_dropdown = None @@ -61,6 +63,10 @@ class UI: """debug utility that will clear the internal webui queue. sometimes good for jams""" logger.debug(webui_state.__dict__) webui_state.end() + progress.pending_tasks.clear() + progress.current_task = None + if queue_lock._lock.locked(): + queue_lock.release() def status_btn(self): """updates a simplified overview of registered workers and their jobs""" diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 4ab5c24..0e3df16 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -181,26 +181,14 @@ class World: Thread(target=worker.refresh_checkpoints, args=()).start() def sample_master(self) -> float: - # progress.finish_task(progress.current_task) - if queue_lock._lock.locked(): - queue_lock.release() - - master_bench_payload = StableDiffusionProcessingTxt2Img() + p = StableDiffusionProcessingTxt2Img() d = sh.benchmark_payload.dict() for key in d: - setattr(master_bench_payload, key, d[key]) + setattr(p, key, d[key]) + p.do_not_save_samples = True - # Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to. - master_bench_payload.do_not_save_samples = True - shared.state.end() - wrapped = wrap_queued_call(process_images) start = time.time() - wrapped(master_bench_payload) - # seems counter-intuitive but the lock will later be released again once the original task is ended by wui - # only doing things this way so that we can bench and then have an original user request immediately resume - if progress.current_task is not None: # could be no task, ie. running bench from utils tab - queue_lock.acquire() - + process_images(p) return time.time() - start def benchmark(self, rebenchmark: bool = False): @@ -208,6 +196,7 @@ class World: Attempts to benchmark all workers a part of the world. """ + local_task_id = 'task(distributed_bench)' unbenched_workers = [] if rebenchmark: for worker in self._workers: @@ -246,26 +235,42 @@ class World: futures.clear() # benchmark those that haven't been - for worker in unbenched_workers: - if worker.state in (State.DISABLED, State.UNAVAILABLE): - logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") - continue + if len(unbenched_workers) > 0: + queue_lock.acquire() + gradio.Info("Distributed: benchmarking in progress, please wait") + for worker in unbenched_workers: + if worker.state in (State.DISABLED, State.UNAVAILABLE): + logger.debug(f"worker '{worker.label}' is {worker.state.name}, refusing to benchmark") + continue - if worker.model_override is not None: - logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" - f"*all workers should be evaluated against the same model") + if worker.model_override is not None: + logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" + f"*all workers should be evaluated against the same model") - chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master) - futures.append(executor.submit(chosen, worker)) - logger.info(f"benchmarking worker '{worker.label}'") + if worker.master: + if progress.current_task is None: + progress.add_task_to_queue(local_task_id) + progress.start_task(local_task_id) + shared.state.begin(job=local_task_id) + shared.state.job_count = sh.warmup_samples + sh.samples - # wait for all benchmarks to finish and update stats on newly benchmarked workers - concurrent.futures.wait(futures) - logger.info("benchmarking finished") + chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master) + futures.append(executor.submit(chosen, worker)) + logger.info(f"benchmarking worker '{worker.label}'") - # save benchmark results to workers.json - self.save_config() - logger.info(self.speed_summary()) + if len(futures) > 0: + # wait for all benchmarks to finish and update stats on newly benchmarked workers + concurrent.futures.wait(futures) + + if progress.current_task == local_task_id: + shared.state.end() + progress.finish_task(local_task_id) + queue_lock.release() + + logger.info("benchmarking finished") + logger.info(self.speed_summary()) + gradio.Info("Distributed: benchmarking complete!") + self.save_config() def get_current_output_size(self) -> int: """ From bff6d16e42ee0d2e1dee18514e7979d724d6ad7a Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 30 Aug 2024 11:51:46 -0500 Subject: [PATCH 8/9] refactoring, state fix --- scripts/distributed.py | 26 +++++++++++++++----------- scripts/spartan/worker.py | 8 ++------ scripts/spartan/world.py | 38 +++++++++++++++++++++++--------------- 3 files changed, 40 insertions(+), 32 deletions(-) diff --git a/scripts/distributed.py b/scripts/distributed.py index ef1cc0b..492d4d4 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -24,7 +24,7 @@ from modules.shared import state as webui_state from scripts.spartan.control_net import pack_control_net from scripts.spartan.shared import logger from scripts.spartan.ui import UI -from scripts.spartan.world import World +from scripts.spartan.world import World, State old_sigint_handler = signal.getsignal(signal.SIGINT) old_sigterm_handler = signal.getsignal(signal.SIGTERM) @@ -33,7 +33,6 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM) # noinspection PyMissingOrEmptyDocstring class DistributedScript(scripts.Script): # global old_sigterm_handler, old_sigterm_handler - worker_threads: List[Thread] = [] # Whether to verify worker certificates. Can be useful if your remotes are self-signed. verify_remotes = not cmd_opts.distributed_skip_verify_remotes master_start = None @@ -150,17 +149,19 @@ class DistributedScript(scripts.Script): # wait for response from all workers webui_state.textinfo = "Distributed - receiving results" - for thread in self.worker_threads: - logger.debug(f"waiting for worker thread '{thread.name}'") - thread.join() - self.worker_threads.clear() + for job in self.world.jobs: + if job.thread is None: + continue + + logger.debug(f"waiting for worker thread '{job.thread.name}'") + job.thread.join() logger.debug("all worker request threads returned") webui_state.textinfo = "Distributed - injecting images" # some worker which we know has a good response that we can use for generating the grid donor_worker = None for job in self.world.jobs: - if job.batch_size < 1 or job.worker.master: + if job.worker.response is None or job.batch_size < 1 or job.worker.master: continue try: @@ -304,6 +305,9 @@ class DistributedScript(scripts.Script): return for job in self.world.jobs: + if job.worker.state in (State.UNAVAILABLE, State.DISABLED): + continue + payload_temp = copy.copy(payload) del payload_temp['scripts_value'] payload_temp = copy.deepcopy(payload_temp) @@ -332,11 +336,9 @@ class DistributedScript(scripts.Script): job.worker.loaded_model = name job.worker.loaded_vae = vae - t = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,), + job.thread = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,), name=f"{job.worker.label}_request") - - t.start() - self.worker_threads.append(t) + job.thread.start() started_jobs.append(job) # if master batch size was changed again due to optimization change it to the updated value @@ -358,6 +360,8 @@ class DistributedScript(scripts.Script): # restore process_images_inner if it was monkey-patched processing.process_images_inner = self.original_process_images_inner + # save any dangling state to prevent load_config in next iteration overwriting it + self.world.save_config() @staticmethod def signal_handler(sig, frame): diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 6f505fc..f4eed0a 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -344,7 +344,7 @@ class Worker: # remove anything that is not serializable # s_tmax can be float('inf') which is not serializable, so we convert it to the max float value s_tmax = payload.get('s_tmax', 0.0) - if s_tmax > 1e308: + if s_tmax is not None and s_tmax > 1e308: payload['s_tmax'] = 1e308 # remove unserializable caches payload.pop('cached_uc', None) @@ -490,11 +490,7 @@ class Worker: except Exception as e: self.set_state(State.IDLE) - - if payload['batch_size'] == 0: - raise InvalidWorkerResponse("Tried to request a null amount of images") - else: - raise InvalidWorkerResponse(e) + raise InvalidWorkerResponse(e) except requests.RequestException: self.set_state(State.UNAVAILABLE) diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 0e3df16..f08faf3 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -45,6 +45,7 @@ class Job: self.batch_size: int = batch_size self.complementary: bool = False self.step_override = None + self.thread = None def __str__(self): prefix = '' @@ -373,7 +374,7 @@ class World: batch_size = self.default_batch_size() for worker in self.get_workers(): - if worker.state != State.DISABLED and worker.state != State.UNAVAILABLE: + if worker.state not in (State.DISABLED, State.UNAVAILABLE): if worker.avg_ipm is None or worker.avg_ipm <= 0: logger.debug(f"No recorded speed for worker '{worker.label}, benchmarking'") worker.benchmark() @@ -401,7 +402,7 @@ class World: continue if worker.master and self.thin_client_mode: continue - if worker.state != State.UNAVAILABLE and worker.state != State.DISABLED: + if worker.state not in (State.UNAVAILABLE, State.DISABLED): filtered.append(worker) return filtered @@ -550,17 +551,7 @@ class World: else: logger.debug("complementary image production is disabled") - iterations = payload['n_iter'] - num_returning = self.get_current_output_size() - num_complementary = num_returning - self.p.batch_size - distro_summary = "Job distribution:\n" - distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)" - if num_complementary > 0: - distro_summary += f" + {num_complementary} complementary" - distro_summary += f": {num_returning} images total\n" - for job in self.jobs: - distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n" - logger.info(distro_summary) + logger.info(self.distro_summary(payload)) if self.thin_client_mode is True or self.master_job().batch_size == 0: # save original process_images_inner for later so we can restore once we're done @@ -585,6 +576,20 @@ class World: del self.jobs[last] last -= 1 + def distro_summary(self, payload): + # iterations = dict(payload)['n_iter'] + iterations = self.p.n_iter + num_returning = self.get_current_output_size() + num_complementary = num_returning - self.p.batch_size + distro_summary = "Job distribution:\n" + distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)" + if num_complementary > 0: + distro_summary += f" + {num_complementary} complementary" + distro_summary += f": {num_returning} images total\n" + for job in self.jobs: + distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n" + return distro_summary + def config(self) -> dict: """ { @@ -660,8 +665,8 @@ class World: fields['verify_remotes'] = self.verify_remotes # cast enum id to actual enum type and then prime state fields['state'] = State(fields['state']) - if fields['state'] != State.DISABLED: - fields['state'] = State.IDLE + if fields['state'] not in (State.DISABLED, State.UNAVAILABLE): + fields['state'] = State.IDLE self.add_worker(**fields) @@ -740,6 +745,9 @@ class World: msg = f"worker '{worker.label}' is unreachable" logger.info(msg) gradio.Warning("Distributed: "+msg) + worker.set_state(State.UNAVAILABLE) + + self.save_config() def restart_all(self): for worker in self._workers: From afc3ec8ba3ded2627dd3454b8a0fc487a8b2d71f Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 30 Aug 2024 12:00:52 -0500 Subject: [PATCH 9/9] update changelog --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fab4b45..83f3d94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,12 @@ # Change Log Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html) +## [2.2.2] - 2024-8-30 + +### Fixed + +- Unavailable state sometimes being ignored + ## [2.2.1] - 2024-5-16 ### Fixed