state validation
parent
ec7c4fe58a
commit
78035ac433
|
|
@ -306,7 +306,7 @@ class Worker:
|
||||||
if waited >= (0.85 * max_wait):
|
if waited >= (0.85 * max_wait):
|
||||||
logger.warning("this seems long, so if you see this message often, consider reporting an issue")
|
logger.warning("this seems long, so if you see this message often, consider reporting an issue")
|
||||||
|
|
||||||
self.state = State.WORKING
|
self.set_state(State.WORKING)
|
||||||
|
|
||||||
# query memory available on worker and store for future reference
|
# query memory available on worker and store for future reference
|
||||||
if self.queried is False:
|
if self.queried is False:
|
||||||
|
|
@ -423,7 +423,6 @@ class Worker:
|
||||||
sampler_name = payload.get('sampler_name', None)
|
sampler_name = payload.get('sampler_name', None)
|
||||||
if sampler_index is None:
|
if sampler_index is None:
|
||||||
if sampler_name is not None:
|
if sampler_name is not None:
|
||||||
logger.debug("had to substitute sampler index with name")
|
|
||||||
payload['sampler_index'] = sampler_name
|
payload['sampler_index'] = sampler_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -490,7 +489,7 @@ class Worker:
|
||||||
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
|
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.state = State.IDLE
|
self.set_state(State.IDLE)
|
||||||
|
|
||||||
if payload['batch_size'] == 0:
|
if payload['batch_size'] == 0:
|
||||||
raise InvalidWorkerResponse("Tried to request a null amount of images")
|
raise InvalidWorkerResponse("Tried to request a null amount of images")
|
||||||
|
|
@ -498,10 +497,10 @@ class Worker:
|
||||||
raise InvalidWorkerResponse(e)
|
raise InvalidWorkerResponse(e)
|
||||||
|
|
||||||
except requests.RequestException:
|
except requests.RequestException:
|
||||||
self.mark_unreachable()
|
self.set_state(State.UNAVAILABLE)
|
||||||
return
|
return
|
||||||
|
|
||||||
self.state = State.IDLE
|
self.set_state(State.IDLE)
|
||||||
self.jobs_requested += 1
|
self.jobs_requested += 1
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -539,7 +538,6 @@ class Worker:
|
||||||
# this was due to something torch does at startup according to auto and is now done at sdwui startup
|
# this was due to something torch does at startup according to auto and is now done at sdwui startup
|
||||||
for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
||||||
if self.state == State.UNAVAILABLE:
|
if self.state == State.UNAVAILABLE:
|
||||||
self.response = None
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
try: # if the worker is unreachable/offline then handle that here
|
try: # if the worker is unreachable/offline then handle that here
|
||||||
|
|
@ -574,7 +572,7 @@ class Worker:
|
||||||
self.response = None
|
self.response = None
|
||||||
self.benchmarked = True
|
self.benchmarked = True
|
||||||
self.eta_percent_error = [] # likely inaccurate after rebenching
|
self.eta_percent_error = [] # likely inaccurate after rebenching
|
||||||
self.state = State.IDLE
|
self.set_state(State.IDLE)
|
||||||
return avg_ipm_result
|
return avg_ipm_result
|
||||||
|
|
||||||
def refresh_checkpoints(self):
|
def refresh_checkpoints(self):
|
||||||
|
|
@ -593,17 +591,17 @@ class Worker:
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
# gradio.Warning("Distributed: "+msg)
|
# gradio.Warning("Distributed: "+msg)
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
self.mark_unreachable()
|
self.set_state(State.UNAVAILABLE)
|
||||||
|
|
||||||
def interrupt(self):
|
def interrupt(self):
|
||||||
try:
|
try:
|
||||||
response = self.session.post(self.full_url('interrupt'))
|
response = self.session.post(self.full_url('interrupt'))
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
self.state = State.INTERRUPTED
|
self.set_state(State.INTERRUPTED)
|
||||||
logger.debug(f"successfully interrupted worker {self.label}")
|
logger.debug(f"successfully interrupted worker {self.label}")
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
self.mark_unreachable()
|
self.set_state(State.UNAVAILABLE)
|
||||||
|
|
||||||
def reachable(self) -> bool:
|
def reachable(self) -> bool:
|
||||||
"""returns false if worker is unreachable"""
|
"""returns false if worker is unreachable"""
|
||||||
|
|
@ -622,18 +620,6 @@ class Worker:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def mark_unreachable(self):
    """Record that this worker could not be reached.

    A worker whose state is DISABLED is left untouched (only a debug line is
    logged). Otherwise the failure is logged as an error, the state becomes
    UNAVAILABLE, and the cached model/VAE names are cleared.
    """
    if self.state == State.DISABLED:
        logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
        return

    logger.error(f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection")
    self.state = State.UNAVAILABLE

    # Drop the cached model info so that if/when the worker reconnects,
    # a new POST is sent to resync the loaded models.
    self.loaded_model = None
    self.loaded_vae = None
|
|
||||||
|
|
||||||
def available_models(self) -> [List[str]]:
|
def available_models(self) -> [List[str]]:
|
||||||
if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master:
|
if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master:
|
||||||
return []
|
return []
|
||||||
|
|
@ -654,7 +640,7 @@ class Worker:
|
||||||
titles = [model['title'] for model in response.json()]
|
titles = [model['title'] for model in response.json()]
|
||||||
return titles
|
return titles
|
||||||
except requests.RequestException:
|
except requests.RequestException:
|
||||||
self.mark_unreachable()
|
self.set_state(State.UNAVAILABLE)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def load_options(self, model, vae=None):
|
def load_options(self, model, vae=None):
|
||||||
|
|
@ -671,14 +657,14 @@ class Worker:
|
||||||
if vae is not None:
|
if vae is not None:
|
||||||
payload['sd_vae'] = vae
|
payload['sd_vae'] = vae
|
||||||
|
|
||||||
self.state = State.WORKING
|
self.set_state(State.WORKING)
|
||||||
start = time.time()
|
start = time.time()
|
||||||
response = self.session.post(
|
response = self.session.post(
|
||||||
self.full_url("options"),
|
self.full_url("options"),
|
||||||
json=payload
|
json=payload
|
||||||
)
|
)
|
||||||
elapsed = time.time() - start
|
elapsed = time.time() - start
|
||||||
self.state = State.IDLE
|
self.set_state(State.IDLE)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
logger.debug(f"failed to load options for worker '{self.label}'")
|
logger.debug(f"failed to load options for worker '{self.label}'")
|
||||||
|
|
@ -720,3 +706,44 @@ class Worker:
|
||||||
|
|
||||||
logger.error(f"{err_msg}: {response}")
|
logger.error(f"{err_msg}: {response}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def set_state(self, state: State):
|
||||||
|
state_cache = self.state
|
||||||
|
|
||||||
|
def transition(ns: State):
|
||||||
|
if ns == self.state:
|
||||||
|
logger.critical(f"{self.label} was already {self.state.name}")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"{self.label}: {self.state.name} -> {ns.name}")
|
||||||
|
self.state = ns
|
||||||
|
|
||||||
|
match self.state:
|
||||||
|
case State.IDLE:
|
||||||
|
if state in (State.IDLE, State.WORKING):
|
||||||
|
transition(state)
|
||||||
|
|
||||||
|
case State.WORKING:
|
||||||
|
if state in (State.WORKING, State.IDLE, State.INTERRUPTED):
|
||||||
|
transition(state)
|
||||||
|
|
||||||
|
case State.UNAVAILABLE:
|
||||||
|
if state in State.IDLE:
|
||||||
|
transition(state)
|
||||||
|
|
||||||
|
case State.INTERRUPTED:
|
||||||
|
if state in State.WORKING:
|
||||||
|
transition(state)
|
||||||
|
|
||||||
|
if state == State.UNAVAILABLE:
|
||||||
|
if self.state == State.DISABLED:
|
||||||
|
logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
|
||||||
|
else:
|
||||||
|
logger.error(f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection")
|
||||||
|
# invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
|
||||||
|
self.loaded_model = None
|
||||||
|
self.loaded_vae = None
|
||||||
|
transition(state)
|
||||||
|
|
||||||
|
if self.state == state_cache and self.state != state:
|
||||||
|
logger.error(f"{self.label}: invalid transition {self.state.name} -> {state.name}")
|
||||||
|
|
|
||||||
|
|
@ -727,7 +727,7 @@ class World:
|
||||||
msg = f"worker '{worker.label}' is online"
|
msg = f"worker '{worker.label}' is online"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
gradio.Info("Distributed: "+msg)
|
gradio.Info("Distributed: "+msg)
|
||||||
worker.state = State.IDLE
|
worker.set_state(State.IDLE)
|
||||||
else:
|
else:
|
||||||
msg = f"worker '{worker.label}' is unreachable"
|
msg = f"worker '{worker.label}' is unreachable"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue