state validation
parent
ec7c4fe58a
commit
78035ac433
|
|
@ -306,7 +306,7 @@ class Worker:
|
|||
if waited >= (0.85 * max_wait):
|
||||
logger.warning("this seems long, so if you see this message often, consider reporting an issue")
|
||||
|
||||
self.state = State.WORKING
|
||||
self.set_state(State.WORKING)
|
||||
|
||||
# query memory available on worker and store for future reference
|
||||
if self.queried is False:
|
||||
|
|
@ -423,7 +423,6 @@ class Worker:
|
|||
sampler_name = payload.get('sampler_name', None)
|
||||
if sampler_index is None:
|
||||
if sampler_name is not None:
|
||||
logger.debug("had to substitute sampler index with name")
|
||||
payload['sampler_index'] = sampler_name
|
||||
|
||||
try:
|
||||
|
|
@ -490,7 +489,7 @@ class Worker:
|
|||
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
|
||||
|
||||
except Exception as e:
|
||||
self.state = State.IDLE
|
||||
self.set_state(State.IDLE)
|
||||
|
||||
if payload['batch_size'] == 0:
|
||||
raise InvalidWorkerResponse("Tried to request a null amount of images")
|
||||
|
|
@ -498,10 +497,10 @@ class Worker:
|
|||
raise InvalidWorkerResponse(e)
|
||||
|
||||
except requests.RequestException:
|
||||
self.mark_unreachable()
|
||||
self.set_state(State.UNAVAILABLE)
|
||||
return
|
||||
|
||||
self.state = State.IDLE
|
||||
self.set_state(State.IDLE)
|
||||
self.jobs_requested += 1
|
||||
return
|
||||
|
||||
|
|
@ -539,7 +538,6 @@ class Worker:
|
|||
# this was due to something torch does at startup according to auto and is now done at sdwui startup
|
||||
for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
||||
if self.state == State.UNAVAILABLE:
|
||||
self.response = None
|
||||
return 0
|
||||
|
||||
try: # if the worker is unreachable/offline then handle that here
|
||||
|
|
@ -574,7 +572,7 @@ class Worker:
|
|||
self.response = None
|
||||
self.benchmarked = True
|
||||
self.eta_percent_error = [] # likely inaccurate after rebenching
|
||||
self.state = State.IDLE
|
||||
self.set_state(State.IDLE)
|
||||
return avg_ipm_result
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
|
|
@ -593,17 +591,17 @@ class Worker:
|
|||
logger.error(msg)
|
||||
# gradio.Warning("Distributed: "+msg)
|
||||
except requests.exceptions.ConnectionError:
|
||||
self.mark_unreachable()
|
||||
self.set_state(State.UNAVAILABLE)
|
||||
|
||||
def interrupt(self):
|
||||
try:
|
||||
response = self.session.post(self.full_url('interrupt'))
|
||||
|
||||
if response.status_code == 200:
|
||||
self.state = State.INTERRUPTED
|
||||
self.set_state(State.INTERRUPTED)
|
||||
logger.debug(f"successfully interrupted worker {self.label}")
|
||||
except requests.exceptions.ConnectionError:
|
||||
self.mark_unreachable()
|
||||
self.set_state(State.UNAVAILABLE)
|
||||
|
||||
def reachable(self) -> bool:
|
||||
"""returns false if worker is unreachable"""
|
||||
|
|
@ -622,18 +620,6 @@ class Worker:
|
|||
logger.error(e)
|
||||
return False
|
||||
|
||||
def mark_unreachable(self):
    """Flag this worker as unavailable unless it has been disabled.

    A disabled worker is left untouched and only a debug message is logged.
    Otherwise an error is logged, the state becomes ``State.UNAVAILABLE`` and
    the cached model/VAE names are cleared.
    """
    if self.state == State.DISABLED:
        logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
        return

    unreachable_notice = f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection"
    logger.error(unreachable_notice)
    # gradio.Warning("Distributed: "+msg)
    self.state = State.UNAVAILABLE

    # invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
    self.loaded_model = self.loaded_vae = None
|
||||
|
||||
def available_models(self) -> [List[str]]:
|
||||
if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master:
|
||||
return []
|
||||
|
|
@ -654,7 +640,7 @@ class Worker:
|
|||
titles = [model['title'] for model in response.json()]
|
||||
return titles
|
||||
except requests.RequestException:
|
||||
self.mark_unreachable()
|
||||
self.set_state(State.UNAVAILABLE)
|
||||
return []
|
||||
|
||||
def load_options(self, model, vae=None):
|
||||
|
|
@ -671,14 +657,14 @@ class Worker:
|
|||
if vae is not None:
|
||||
payload['sd_vae'] = vae
|
||||
|
||||
self.state = State.WORKING
|
||||
self.set_state(State.WORKING)
|
||||
start = time.time()
|
||||
response = self.session.post(
|
||||
self.full_url("options"),
|
||||
json=payload
|
||||
)
|
||||
elapsed = time.time() - start
|
||||
self.state = State.IDLE
|
||||
self.set_state(State.IDLE)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.debug(f"failed to load options for worker '{self.label}'")
|
||||
|
|
@ -720,3 +706,44 @@ class Worker:
|
|||
|
||||
logger.error(f"{err_msg}: {response}")
|
||||
return False
|
||||
|
||||
def set_state(self, state: State):
|
||||
state_cache = self.state
|
||||
|
||||
def transition(ns: State):
|
||||
if ns == self.state:
|
||||
logger.critical(f"{self.label} was already {self.state.name}")
|
||||
return
|
||||
|
||||
logger.debug(f"{self.label}: {self.state.name} -> {ns.name}")
|
||||
self.state = ns
|
||||
|
||||
match self.state:
|
||||
case State.IDLE:
|
||||
if state in (State.IDLE, State.WORKING):
|
||||
transition(state)
|
||||
|
||||
case State.WORKING:
|
||||
if state in (State.WORKING, State.IDLE, State.INTERRUPTED):
|
||||
transition(state)
|
||||
|
||||
case State.UNAVAILABLE:
|
||||
if state in State.IDLE:
|
||||
transition(state)
|
||||
|
||||
case State.INTERRUPTED:
|
||||
if state in State.WORKING:
|
||||
transition(state)
|
||||
|
||||
if state == State.UNAVAILABLE:
|
||||
if self.state == State.DISABLED:
|
||||
logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
|
||||
else:
|
||||
logger.error(f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection")
|
||||
# invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
|
||||
self.loaded_model = None
|
||||
self.loaded_vae = None
|
||||
transition(state)
|
||||
|
||||
if self.state == state_cache and self.state != state:
|
||||
logger.error(f"{self.label}: invalid transition {self.state.name} -> {state.name}")
|
||||
|
|
|
|||
|
|
@ -727,7 +727,7 @@ class World:
|
|||
msg = f"worker '{worker.label}' is online"
|
||||
logger.info(msg)
|
||||
gradio.Info("Distributed: "+msg)
|
||||
worker.state = State.IDLE
|
||||
worker.set_state(State.IDLE)
|
||||
else:
|
||||
msg = f"worker '{worker.label}' is unreachable"
|
||||
logger.info(msg)
|
||||
|
|
|
|||
Loading…
Reference in New Issue