state validation

master
papuSpartan 2024-05-26 17:47:06 -05:00
parent ec7c4fe58a
commit 78035ac433
2 changed files with 53 additions and 26 deletions

View File

@ -306,7 +306,7 @@ class Worker:
if waited >= (0.85 * max_wait): if waited >= (0.85 * max_wait):
logger.warning("this seems long, so if you see this message often, consider reporting an issue") logger.warning("this seems long, so if you see this message often, consider reporting an issue")
self.state = State.WORKING self.set_state(State.WORKING)
# query memory available on worker and store for future reference # query memory available on worker and store for future reference
if self.queried is False: if self.queried is False:
@ -423,7 +423,6 @@ class Worker:
sampler_name = payload.get('sampler_name', None) sampler_name = payload.get('sampler_name', None)
if sampler_index is None: if sampler_index is None:
if sampler_name is not None: if sampler_name is not None:
logger.debug("had to substitute sampler index with name")
payload['sampler_index'] = sampler_name payload['sampler_index'] = sampler_name
try: try:
@ -490,7 +489,7 @@ class Worker:
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n") logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
except Exception as e: except Exception as e:
self.state = State.IDLE self.set_state(State.IDLE)
if payload['batch_size'] == 0: if payload['batch_size'] == 0:
raise InvalidWorkerResponse("Tried to request a null amount of images") raise InvalidWorkerResponse("Tried to request a null amount of images")
@ -498,10 +497,10 @@ class Worker:
raise InvalidWorkerResponse(e) raise InvalidWorkerResponse(e)
except requests.RequestException: except requests.RequestException:
self.mark_unreachable() self.set_state(State.UNAVAILABLE)
return return
self.state = State.IDLE self.set_state(State.IDLE)
self.jobs_requested += 1 self.jobs_requested += 1
return return
@ -539,7 +538,6 @@ class Worker:
# this was due to something torch does at startup according to auto and is now done at sdwui startup # this was due to something torch does at startup according to auto and is now done at sdwui startup
for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up" for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up"
if self.state == State.UNAVAILABLE: if self.state == State.UNAVAILABLE:
self.response = None
return 0 return 0
try: # if the worker is unreachable/offline then handle that here try: # if the worker is unreachable/offline then handle that here
@ -574,7 +572,7 @@ class Worker:
self.response = None self.response = None
self.benchmarked = True self.benchmarked = True
self.eta_percent_error = [] # likely inaccurate after rebenching self.eta_percent_error = [] # likely inaccurate after rebenching
self.state = State.IDLE self.set_state(State.IDLE)
return avg_ipm_result return avg_ipm_result
def refresh_checkpoints(self): def refresh_checkpoints(self):
@ -593,17 +591,17 @@ class Worker:
logger.error(msg) logger.error(msg)
# gradio.Warning("Distributed: "+msg) # gradio.Warning("Distributed: "+msg)
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
self.mark_unreachable() self.set_state(State.UNAVAILABLE)
def interrupt(self): def interrupt(self):
try: try:
response = self.session.post(self.full_url('interrupt')) response = self.session.post(self.full_url('interrupt'))
if response.status_code == 200: if response.status_code == 200:
self.state = State.INTERRUPTED self.set_state(State.INTERRUPTED)
logger.debug(f"successfully interrupted worker {self.label}") logger.debug(f"successfully interrupted worker {self.label}")
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
self.mark_unreachable() self.set_state(State.UNAVAILABLE)
def reachable(self) -> bool: def reachable(self) -> bool:
"""returns false if worker is unreachable""" """returns false if worker is unreachable"""
@ -622,18 +620,6 @@ class Worker:
logger.error(e) logger.error(e)
return False return False
def mark_unreachable(self):
    """Flag this worker as UNAVAILABLE unless it is currently DISABLED.

    A disabled worker is left untouched (only a debug message is logged).
    Otherwise an error is logged, the state becomes State.UNAVAILABLE and
    the cached model/VAE names are cleared.
    """
    # guard clause: a disabled worker must never be flipped to unavailable
    if self.state == State.DISABLED:
        logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
        return

    logger.error(f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection")
    self.state = State.UNAVAILABLE

    # invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
    self.loaded_model = self.loaded_vae = None
def available_models(self) -> [List[str]]: def available_models(self) -> [List[str]]:
if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master: if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master:
return [] return []
@ -654,7 +640,7 @@ class Worker:
titles = [model['title'] for model in response.json()] titles = [model['title'] for model in response.json()]
return titles return titles
except requests.RequestException: except requests.RequestException:
self.mark_unreachable() self.set_state(State.UNAVAILABLE)
return [] return []
def load_options(self, model, vae=None): def load_options(self, model, vae=None):
@ -671,14 +657,14 @@ class Worker:
if vae is not None: if vae is not None:
payload['sd_vae'] = vae payload['sd_vae'] = vae
self.state = State.WORKING self.set_state(State.WORKING)
start = time.time() start = time.time()
response = self.session.post( response = self.session.post(
self.full_url("options"), self.full_url("options"),
json=payload json=payload
) )
elapsed = time.time() - start elapsed = time.time() - start
self.state = State.IDLE self.set_state(State.IDLE)
if response.status_code != 200: if response.status_code != 200:
logger.debug(f"failed to load options for worker '{self.label}'") logger.debug(f"failed to load options for worker '{self.label}'")
@ -720,3 +706,44 @@ class Worker:
logger.error(f"{err_msg}: {response}") logger.error(f"{err_msg}: {response}")
return False return False
def set_state(self, state: State):
state_cache = self.state
def transition(ns: State):
if ns == self.state:
logger.critical(f"{self.label} was already {self.state.name}")
return
logger.debug(f"{self.label}: {self.state.name} -> {ns.name}")
self.state = ns
match self.state:
case State.IDLE:
if state in (State.IDLE, State.WORKING):
transition(state)
case State.WORKING:
if state in (State.WORKING, State.IDLE, State.INTERRUPTED):
transition(state)
case State.UNAVAILABLE:
if state in State.IDLE:
transition(state)
case State.INTERRUPTED:
if state in State.WORKING:
transition(state)
if state == State.UNAVAILABLE:
if self.state == State.DISABLED:
logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
else:
logger.error(f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection")
# invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
self.loaded_model = None
self.loaded_vae = None
transition(state)
if self.state == state_cache and self.state != state:
logger.error(f"{self.label}: invalid transition {self.state.name} -> {state.name}")

View File

@ -727,7 +727,7 @@ class World:
msg = f"worker '{worker.label}' is online" msg = f"worker '{worker.label}' is online"
logger.info(msg) logger.info(msg)
gradio.Info("Distributed: "+msg) gradio.Info("Distributed: "+msg)
worker.state = State.IDLE worker.set_state(State.IDLE)
else: else:
msg = f"worker '{worker.label}' is unreachable" msg = f"worker '{worker.label}' is unreachable"
logger.info(msg) logger.info(msg)