diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 832cdba..cc258cb 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -306,7 +306,7 @@ class Worker: if waited >= (0.85 * max_wait): logger.warning("this seems long, so if you see this message often, consider reporting an issue") - self.state = State.WORKING + self.set_state(State.WORKING) # query memory available on worker and store for future reference if self.queried is False: @@ -423,7 +423,6 @@ class Worker: sampler_name = payload.get('sampler_name', None) if sampler_index is None: if sampler_name is not None: - logger.debug("had to substitute sampler index with name") payload['sampler_index'] = sampler_name try: @@ -490,7 +489,7 @@ class Worker: logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n") except Exception as e: - self.state = State.IDLE + self.set_state(State.IDLE) if payload['batch_size'] == 0: raise InvalidWorkerResponse("Tried to request a null amount of images") @@ -498,10 +497,10 @@ class Worker: raise InvalidWorkerResponse(e) except requests.RequestException: - self.mark_unreachable() + self.set_state(State.UNAVAILABLE) return - self.state = State.IDLE + self.set_state(State.IDLE) self.jobs_requested += 1 return @@ -539,7 +538,6 @@ class Worker: # this was due to something torch does at startup according to auto and is now done at sdwui startup for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up" if self.state == State.UNAVAILABLE: - self.response = None return 0 try: # if the worker is unreachable/offline then handle that here @@ -574,7 +572,7 @@ class Worker: self.response = None self.benchmarked = True self.eta_percent_error = [] # likely inaccurate after rebenching - self.state = State.IDLE + self.set_state(State.IDLE) return avg_ipm_result def refresh_checkpoints(self): @@ -593,17 +591,17 @@ class Worker: logger.error(msg) # gradio.Warning("Distributed: "+msg) except requests.exceptions.ConnectionError: - self.mark_unreachable() + self.set_state(State.UNAVAILABLE) def interrupt(self): try: response = self.session.post(self.full_url('interrupt')) if response.status_code == 200: - self.state = State.INTERRUPTED + self.set_state(State.INTERRUPTED) logger.debug(f"successfully interrupted worker {self.label}") except requests.exceptions.ConnectionError: - self.mark_unreachable() + self.set_state(State.UNAVAILABLE) def reachable(self) -> bool: """returns false if worker is unreachable""" @@ -622,18 +620,6 @@ class Worker: logger.error(e) return False - def mark_unreachable(self): - if self.state == State.DISABLED: - logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable") - else: - msg = f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection" - logger.error(msg) - # gradio.Warning("Distributed: "+msg) - self.state = State.UNAVAILABLE - # invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models - self.loaded_model = None - self.loaded_vae = None - def available_models(self) -> [List[str]]: if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master: return [] @@ -654,7 +640,7 @@ class Worker: titles = [model['title'] for model in response.json()] return titles except requests.RequestException: - self.mark_unreachable() + self.set_state(State.UNAVAILABLE) return [] def load_options(self, model, vae=None): @@ -671,14 +657,14 @@ class Worker: if vae is not None: payload['sd_vae'] = vae - self.state = State.WORKING + self.set_state(State.WORKING) start = time.time() response = self.session.post( self.full_url("options"), json=payload ) elapsed = time.time() - start - self.state = State.IDLE + self.set_state(State.IDLE) if response.status_code != 200: logger.debug(f"failed to load options for worker '{self.label}'") @@ -720,3 +706,44 @@ class Worker: logger.error(f"{err_msg}: {response}") return False + + def set_state(self, state: State): + state_cache = self.state + + def transition(ns: State): + if ns == self.state: + logger.critical(f"{self.label} was already {self.state.name}") + return + + logger.debug(f"{self.label}: {self.state.name} -> {ns.name}") + self.state = ns + + match self.state: + case State.IDLE: + if state in (State.IDLE, State.WORKING): + transition(state) + + case State.WORKING: + if state in (State.WORKING, State.IDLE, State.INTERRUPTED): + transition(state) + + case State.UNAVAILABLE: + if state in State.IDLE: + transition(state) + + case State.INTERRUPTED: + if state in State.WORKING: + transition(state) + + if state == State.UNAVAILABLE: + if self.state == State.DISABLED: + logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable") + else: + logger.error(f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection") + # invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models + self.loaded_model = None + self.loaded_vae = None + transition(state) + + if self.state == state_cache and self.state != state: + logger.error(f"{self.label}: invalid transition {self.state.name} -> {state.name}") diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index f555ca1..307868d 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -727,7 +727,7 @@ class World: msg = f"worker '{worker.label}' is online" logger.info(msg) gradio.Info("Distributed: "+msg) - worker.state = State.IDLE + worker.set_state(State.IDLE) else: msg = f"worker '{worker.label}' is unreachable" logger.info(msg)