diff --git a/scripts/distributed.py b/scripts/distributed.py index ef1cc0b..492d4d4 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -24,7 +24,7 @@ from modules.shared import state as webui_state from scripts.spartan.control_net import pack_control_net from scripts.spartan.shared import logger from scripts.spartan.ui import UI -from scripts.spartan.world import World +from scripts.spartan.world import World, State old_sigint_handler = signal.getsignal(signal.SIGINT) old_sigterm_handler = signal.getsignal(signal.SIGTERM) @@ -33,7 +33,6 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM) # noinspection PyMissingOrEmptyDocstring class DistributedScript(scripts.Script): # global old_sigterm_handler, old_sigterm_handler - worker_threads: List[Thread] = [] # Whether to verify worker certificates. Can be useful if your remotes are self-signed. verify_remotes = not cmd_opts.distributed_skip_verify_remotes master_start = None @@ -150,17 +149,19 @@ class DistributedScript(scripts.Script): # wait for response from all workers webui_state.textinfo = "Distributed - receiving results" - for thread in self.worker_threads: - logger.debug(f"waiting for worker thread '{thread.name}'") - thread.join() - self.worker_threads.clear() + for job in self.world.jobs: + if job.thread is None: + continue + + logger.debug(f"waiting for worker thread '{job.thread.name}'") + job.thread.join() logger.debug("all worker request threads returned") webui_state.textinfo = "Distributed - injecting images" # some worker which we know has a good response that we can use for generating the grid donor_worker = None for job in self.world.jobs: - if job.batch_size < 1 or job.worker.master: + if job.worker.response is None or job.batch_size < 1 or job.worker.master: continue try: @@ -304,6 +305,9 @@ class DistributedScript(scripts.Script): return for job in self.world.jobs: + if job.worker.state in (State.UNAVAILABLE, State.DISABLED): + continue + payload_temp = copy.copy(payload) del payload_temp['scripts_value'] payload_temp = copy.deepcopy(payload_temp) @@ -332,11 +336,9 @@ class DistributedScript(scripts.Script): job.worker.loaded_model = name job.worker.loaded_vae = vae - t = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,), + job.thread = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,), name=f"{job.worker.label}_request") - - t.start() - self.worker_threads.append(t) + job.thread.start() started_jobs.append(job) # if master batch size was changed again due to optimization change it to the updated value @@ -358,6 +360,8 @@ class DistributedScript(scripts.Script): # restore process_images_inner if it was monkey-patched processing.process_images_inner = self.original_process_images_inner + # save any dangling state to prevent load_config in next iteration overwriting it + self.world.save_config() @staticmethod def signal_handler(sig, frame): diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 6f505fc..f4eed0a 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -344,7 +344,7 @@ class Worker: # remove anything that is not serializable # s_tmax can be float('inf') which is not serializable, so we convert it to the max float value s_tmax = payload.get('s_tmax', 0.0) - if s_tmax > 1e308: + if s_tmax is not None and s_tmax > 1e308: payload['s_tmax'] = 1e308 # remove unserializable caches payload.pop('cached_uc', None) @@ -490,11 +490,7 @@ class Worker: except Exception as e: self.set_state(State.IDLE) - - if payload['batch_size'] == 0: - raise InvalidWorkerResponse("Tried to request a null amount of images") - else: - raise InvalidWorkerResponse(e) + raise InvalidWorkerResponse(e) except requests.RequestException: self.set_state(State.UNAVAILABLE) diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 0e3df16..f08faf3 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -45,6 +45,7 @@ class Job: self.batch_size: int = batch_size self.complementary: bool = False self.step_override = None + self.thread = None def __str__(self): prefix = '' @@ -373,7 +374,7 @@ class World: batch_size = self.default_batch_size() for worker in self.get_workers(): - if worker.state != State.DISABLED and worker.state != State.UNAVAILABLE: + if worker.state not in (State.DISABLED, State.UNAVAILABLE): if worker.avg_ipm is None or worker.avg_ipm <= 0: logger.debug(f"No recorded speed for worker '{worker.label}, benchmarking'") worker.benchmark() @@ -401,7 +402,7 @@ class World: continue if worker.master and self.thin_client_mode: continue - if worker.state != State.UNAVAILABLE and worker.state != State.DISABLED: + if worker.state not in (State.UNAVAILABLE, State.DISABLED): filtered.append(worker) return filtered @@ -550,17 +551,7 @@ class World: else: logger.debug("complementary image production is disabled") - iterations = payload['n_iter'] - num_returning = self.get_current_output_size() - num_complementary = num_returning - self.p.batch_size - distro_summary = "Job distribution:\n" - distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)" - if num_complementary > 0: - distro_summary += f" + {num_complementary} complementary" - distro_summary += f": {num_returning} images total\n" - for job in self.jobs: - distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n" - logger.info(distro_summary) + logger.info(self.distro_summary(payload)) if self.thin_client_mode is True or self.master_job().batch_size == 0: # save original process_images_inner for later so we can restore once we're done @@ -585,6 +576,20 @@ class World: del self.jobs[last] last -= 1 + def distro_summary(self, payload): + # iterations = dict(payload)['n_iter'] + iterations = self.p.n_iter + num_returning = self.get_current_output_size() + num_complementary = num_returning - self.p.batch_size + distro_summary = "Job distribution:\n" + distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)" + if num_complementary > 0: + distro_summary += f" + {num_complementary} complementary" + distro_summary += f": {num_returning} images total\n" + for job in self.jobs: + distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n" + return distro_summary + def config(self) -> dict: """ { @@ -660,8 +665,8 @@ class World: fields['verify_remotes'] = self.verify_remotes # cast enum id to actual enum type and then prime state fields['state'] = State(fields['state']) - if fields['state'] != State.DISABLED: - fields['state'] = State.IDLE + if fields['state'] not in (State.DISABLED, State.UNAVAILABLE): + fields['state'] = State.IDLE self.add_worker(**fields) @@ -740,6 +745,9 @@ class World: msg = f"worker '{worker.label}' is unreachable" logger.info(msg) gradio.Warning("Distributed: "+msg) + worker.set_state(State.UNAVAILABLE) + + self.save_config() def restart_all(self): for worker in self._workers: