commit
65b0eb75e0
|
|
@ -1,6 +1,12 @@
|
|||
# Change Log
|
||||
Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
|
||||
|
||||
## [2.2.2] - 2024-8-30
|
||||
|
||||
### Fixed
|
||||
|
||||
- Unavailable state sometimes being ignored
|
||||
|
||||
## [2.2.1] - 2024-5-16
|
||||
|
||||
### Fixed
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from modules.shared import state as webui_state
|
|||
from scripts.spartan.control_net import pack_control_net
|
||||
from scripts.spartan.shared import logger
|
||||
from scripts.spartan.ui import UI
|
||||
from scripts.spartan.world import World
|
||||
from scripts.spartan.world import World, State
|
||||
|
||||
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
||||
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||
|
|
@ -33,7 +33,6 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
|||
# noinspection PyMissingOrEmptyDocstring
|
||||
class DistributedScript(scripts.Script):
|
||||
# global old_sigterm_handler, old_sigterm_handler
|
||||
worker_threads: List[Thread] = []
|
||||
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
||||
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
|
||||
master_start = None
|
||||
|
|
@ -150,17 +149,19 @@ class DistributedScript(scripts.Script):
|
|||
|
||||
# wait for response from all workers
|
||||
webui_state.textinfo = "Distributed - receiving results"
|
||||
for thread in self.worker_threads:
|
||||
logger.debug(f"waiting for worker thread '{thread.name}'")
|
||||
thread.join()
|
||||
self.worker_threads.clear()
|
||||
for job in self.world.jobs:
|
||||
if job.thread is None:
|
||||
continue
|
||||
|
||||
logger.debug(f"waiting for worker thread '{job.thread.name}'")
|
||||
job.thread.join()
|
||||
logger.debug("all worker request threads returned")
|
||||
webui_state.textinfo = "Distributed - injecting images"
|
||||
|
||||
# some worker which we know has a good response that we can use for generating the grid
|
||||
donor_worker = None
|
||||
for job in self.world.jobs:
|
||||
if job.batch_size < 1 or job.worker.master:
|
||||
if job.worker.response is None or job.batch_size < 1 or job.worker.master:
|
||||
continue
|
||||
|
||||
try:
|
||||
|
|
@ -197,7 +198,7 @@ class DistributedScript(scripts.Script):
|
|||
return
|
||||
|
||||
# generate and inject grid
|
||||
if opts.return_grid:
|
||||
if opts.return_grid and len(processed.images) > 1:
|
||||
grid = image_grid(processed.images, len(processed.images))
|
||||
processed_inject_image(
|
||||
image=grid,
|
||||
|
|
@ -304,6 +305,9 @@ class DistributedScript(scripts.Script):
|
|||
return
|
||||
|
||||
for job in self.world.jobs:
|
||||
if job.worker.state in (State.UNAVAILABLE, State.DISABLED):
|
||||
continue
|
||||
|
||||
payload_temp = copy.copy(payload)
|
||||
del payload_temp['scripts_value']
|
||||
payload_temp = copy.deepcopy(payload_temp)
|
||||
|
|
@ -332,11 +336,9 @@ class DistributedScript(scripts.Script):
|
|||
job.worker.loaded_model = name
|
||||
job.worker.loaded_vae = vae
|
||||
|
||||
t = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
|
||||
job.thread = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
|
||||
name=f"{job.worker.label}_request")
|
||||
|
||||
t.start()
|
||||
self.worker_threads.append(t)
|
||||
job.thread.start()
|
||||
started_jobs.append(job)
|
||||
|
||||
# if master batch size was changed again due to optimization change it to the updated value
|
||||
|
|
@ -358,6 +360,8 @@ class DistributedScript(scripts.Script):
|
|||
|
||||
# restore process_images_inner if it was monkey-patched
|
||||
processing.process_images_inner = self.original_process_images_inner
|
||||
# save any dangling state to prevent load_config in next iteration overwriting it
|
||||
self.world.save_config()
|
||||
|
||||
@staticmethod
|
||||
def signal_handler(sig, frame):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from modules.shared import opts
|
|||
from modules.shared import state as webui_state
|
||||
from .shared import logger, LOG_LEVEL, gui_handler
|
||||
from .worker import State
|
||||
from modules.call_queue import queue_lock
|
||||
from modules import progress
|
||||
|
||||
worker_select_dropdown = None
|
||||
|
||||
|
|
@ -61,6 +63,10 @@ class UI:
|
|||
"""debug utility that will clear the internal webui queue. sometimes good for jams"""
|
||||
logger.debug(webui_state.__dict__)
|
||||
webui_state.end()
|
||||
progress.pending_tasks.clear()
|
||||
progress.current_task = None
|
||||
if queue_lock._lock.locked():
|
||||
queue_lock.release()
|
||||
|
||||
def status_btn(self):
|
||||
"""updates a simplified overview of registered workers and their jobs"""
|
||||
|
|
|
|||
|
|
@ -306,7 +306,7 @@ class Worker:
|
|||
if waited >= (0.85 * max_wait):
|
||||
logger.warning("this seems long, so if you see this message often, consider reporting an issue")
|
||||
|
||||
self.state = State.WORKING
|
||||
self.set_state(State.WORKING)
|
||||
|
||||
# query memory available on worker and store for future reference
|
||||
if self.queried is False:
|
||||
|
|
@ -344,7 +344,7 @@ class Worker:
|
|||
# remove anything that is not serializable
|
||||
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
|
||||
s_tmax = payload.get('s_tmax', 0.0)
|
||||
if s_tmax > 1e308:
|
||||
if s_tmax is not None and s_tmax > 1e308:
|
||||
payload['s_tmax'] = 1e308
|
||||
# remove unserializable caches
|
||||
payload.pop('cached_uc', None)
|
||||
|
|
@ -423,7 +423,6 @@ class Worker:
|
|||
sampler_name = payload.get('sampler_name', None)
|
||||
if sampler_index is None:
|
||||
if sampler_name is not None:
|
||||
logger.debug("had to substitute sampler index with name")
|
||||
payload['sampler_index'] = sampler_name
|
||||
|
||||
try:
|
||||
|
|
@ -490,18 +489,14 @@ class Worker:
|
|||
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
|
||||
|
||||
except Exception as e:
|
||||
self.state = State.IDLE
|
||||
|
||||
if payload['batch_size'] == 0:
|
||||
raise InvalidWorkerResponse("Tried to request a null amount of images")
|
||||
else:
|
||||
raise InvalidWorkerResponse(e)
|
||||
self.set_state(State.IDLE)
|
||||
raise InvalidWorkerResponse(e)
|
||||
|
||||
except requests.RequestException:
|
||||
self.mark_unreachable()
|
||||
self.set_state(State.UNAVAILABLE)
|
||||
return
|
||||
|
||||
self.state = State.IDLE
|
||||
self.set_state(State.IDLE)
|
||||
self.jobs_requested += 1
|
||||
return
|
||||
|
||||
|
|
@ -539,7 +534,6 @@ class Worker:
|
|||
# this was due to something torch does at startup according to auto and is now done at sdwui startup
|
||||
for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
||||
if self.state == State.UNAVAILABLE:
|
||||
self.response = None
|
||||
return 0
|
||||
|
||||
try: # if the worker is unreachable/offline then handle that here
|
||||
|
|
@ -574,7 +568,7 @@ class Worker:
|
|||
self.response = None
|
||||
self.benchmarked = True
|
||||
self.eta_percent_error = [] # likely inaccurate after rebenching
|
||||
self.state = State.IDLE
|
||||
self.set_state(State.IDLE)
|
||||
return avg_ipm_result
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
|
|
@ -593,17 +587,17 @@ class Worker:
|
|||
logger.error(msg)
|
||||
# gradio.Warning("Distributed: "+msg)
|
||||
except requests.exceptions.ConnectionError:
|
||||
self.mark_unreachable()
|
||||
self.set_state(State.UNAVAILABLE)
|
||||
|
||||
def interrupt(self):
|
||||
try:
|
||||
response = self.session.post(self.full_url('interrupt'))
|
||||
|
||||
if response.status_code == 200:
|
||||
self.state = State.INTERRUPTED
|
||||
self.set_state(State.INTERRUPTED)
|
||||
logger.debug(f"successfully interrupted worker {self.label}")
|
||||
except requests.exceptions.ConnectionError:
|
||||
self.mark_unreachable()
|
||||
self.set_state(State.UNAVAILABLE)
|
||||
|
||||
def reachable(self) -> bool:
|
||||
"""returns false if worker is unreachable"""
|
||||
|
|
@ -622,18 +616,6 @@ class Worker:
|
|||
logger.error(e)
|
||||
return False
|
||||
|
||||
def mark_unreachable(self):
|
||||
if self.state == State.DISABLED:
|
||||
logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
|
||||
else:
|
||||
msg = f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection"
|
||||
logger.error(msg)
|
||||
# gradio.Warning("Distributed: "+msg)
|
||||
self.state = State.UNAVAILABLE
|
||||
# invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
|
||||
self.loaded_model = None
|
||||
self.loaded_vae = None
|
||||
|
||||
def available_models(self) -> [List[str]]:
|
||||
if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master:
|
||||
return []
|
||||
|
|
@ -654,7 +636,7 @@ class Worker:
|
|||
titles = [model['title'] for model in response.json()]
|
||||
return titles
|
||||
except requests.RequestException:
|
||||
self.mark_unreachable()
|
||||
self.set_state(State.UNAVAILABLE)
|
||||
return []
|
||||
|
||||
def load_options(self, model, vae=None):
|
||||
|
|
@ -671,14 +653,14 @@ class Worker:
|
|||
if vae is not None:
|
||||
payload['sd_vae'] = vae
|
||||
|
||||
self.state = State.WORKING
|
||||
self.set_state(State.WORKING)
|
||||
start = time.time()
|
||||
response = self.session.post(
|
||||
self.full_url("options"),
|
||||
json=payload
|
||||
)
|
||||
elapsed = time.time() - start
|
||||
self.state = State.IDLE
|
||||
self.set_state(State.IDLE)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.debug(f"failed to load options for worker '{self.label}'")
|
||||
|
|
@ -720,3 +702,44 @@ class Worker:
|
|||
|
||||
logger.error(f"{err_msg}: {response}")
|
||||
return False
|
||||
|
||||
def set_state(self, state: State, expect_cycle: bool = False):
|
||||
"""
|
||||
Updates the state of a worker if considered a valid operation
|
||||
|
||||
Args:
|
||||
state: the new state to try transitioning to
|
||||
expect_cycle: whether this transition might be a no-op/self-loop
|
||||
|
||||
"""
|
||||
state_cache = self.state
|
||||
|
||||
def transition(ns: State):
|
||||
if ns == self.state and expect_cycle is False:
|
||||
logger.debug(f"{self.label}: potentially redundant transition {self.state.name} -> {ns.name}")
|
||||
return
|
||||
|
||||
logger.debug(f"{self.label}: {self.state.name} -> {ns.name}")
|
||||
self.state = ns
|
||||
|
||||
transitions = {
|
||||
State.IDLE: {State.IDLE, State.WORKING},
|
||||
State.WORKING: {State.WORKING, State.IDLE, State.INTERRUPTED},
|
||||
State.UNAVAILABLE: {State.IDLE},
|
||||
State.INTERRUPTED: {State.WORKING},
|
||||
}
|
||||
if state in transitions.get(self.state, {}):
|
||||
transition(state)
|
||||
|
||||
if state == State.UNAVAILABLE:
|
||||
if self.state == State.DISABLED:
|
||||
logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
|
||||
else:
|
||||
logger.error(f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection")
|
||||
# invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
|
||||
self.loaded_model = None
|
||||
self.loaded_vae = None
|
||||
transition(state)
|
||||
|
||||
if self.state == state_cache and self.state != state:
|
||||
logger.debug(f"{self.label}: invalid transition {self.state.name} -> {state.name}")
|
||||
|
|
|
|||
|
|
@ -18,8 +18,9 @@ from . import shared as sh
|
|||
from .pmodels import ConfigModel, Benchmark_Payload
|
||||
from .shared import logger, extension_path
|
||||
from .worker import Worker, State
|
||||
from modules.call_queue import wrap_queued_call
|
||||
from modules.call_queue import wrap_queued_call, queue_lock
|
||||
from modules import processing
|
||||
from modules import progress
|
||||
|
||||
|
||||
class NotBenchmarked(Exception):
|
||||
|
|
@ -44,6 +45,7 @@ class Job:
|
|||
self.batch_size: int = batch_size
|
||||
self.complementary: bool = False
|
||||
self.step_override = None
|
||||
self.thread = None
|
||||
|
||||
def __str__(self):
|
||||
prefix = ''
|
||||
|
|
@ -160,13 +162,7 @@ class World:
|
|||
return new
|
||||
else:
|
||||
for key in kwargs:
|
||||
if hasattr(original, key):
|
||||
# TODO only necessary because this is skipping Worker.__init__ and the pyd model is saving the state as an int instead of an actual enum
|
||||
if key == 'state':
|
||||
original.state = kwargs[key] if type(kwargs[key]) is State else State(kwargs[key])
|
||||
continue
|
||||
|
||||
setattr(original, key, kwargs[key])
|
||||
setattr(original, key, kwargs[key])
|
||||
|
||||
return original
|
||||
|
||||
|
|
@ -186,29 +182,22 @@ class World:
|
|||
Thread(target=worker.refresh_checkpoints, args=()).start()
|
||||
|
||||
def sample_master(self) -> float:
|
||||
# wrap our benchmark payload
|
||||
master_bench_payload = StableDiffusionProcessingTxt2Img()
|
||||
p = StableDiffusionProcessingTxt2Img()
|
||||
d = sh.benchmark_payload.dict()
|
||||
for key in d:
|
||||
setattr(master_bench_payload, key, d[key])
|
||||
setattr(p, key, d[key])
|
||||
p.do_not_save_samples = True
|
||||
|
||||
# Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to.
|
||||
master_bench_payload.do_not_save_samples = True
|
||||
# shared.state.begin(job='distributed_master_bench')
|
||||
wrapped = (wrap_queued_call(process_images))
|
||||
start = time.time()
|
||||
wrapped(master_bench_payload)
|
||||
# wrap_gradio_gpu_call(process_images)(master_bench_payload)
|
||||
# shared.state.end()
|
||||
|
||||
process_images(p)
|
||||
return time.time() - start
|
||||
|
||||
|
||||
def benchmark(self, rebenchmark: bool = False):
|
||||
"""
|
||||
Attempts to benchmark all workers a part of the world.
|
||||
"""
|
||||
|
||||
local_task_id = 'task(distributed_bench)'
|
||||
unbenched_workers = []
|
||||
if rebenchmark:
|
||||
for worker in self._workers:
|
||||
|
|
@ -247,26 +236,42 @@ class World:
|
|||
futures.clear()
|
||||
|
||||
# benchmark those that haven't been
|
||||
for worker in unbenched_workers:
|
||||
if worker.state in (State.DISABLED, State.UNAVAILABLE):
|
||||
logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark")
|
||||
continue
|
||||
if len(unbenched_workers) > 0:
|
||||
queue_lock.acquire()
|
||||
gradio.Info("Distributed: benchmarking in progress, please wait")
|
||||
for worker in unbenched_workers:
|
||||
if worker.state in (State.DISABLED, State.UNAVAILABLE):
|
||||
logger.debug(f"worker '{worker.label}' is {worker.state.name}, refusing to benchmark")
|
||||
continue
|
||||
|
||||
if worker.model_override is not None:
|
||||
logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
|
||||
f"*all workers should be evaluated against the same model")
|
||||
if worker.model_override is not None:
|
||||
logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
|
||||
f"*all workers should be evaluated against the same model")
|
||||
|
||||
chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master)
|
||||
futures.append(executor.submit(chosen, worker))
|
||||
logger.info(f"benchmarking worker '{worker.label}'")
|
||||
if worker.master:
|
||||
if progress.current_task is None:
|
||||
progress.add_task_to_queue(local_task_id)
|
||||
progress.start_task(local_task_id)
|
||||
shared.state.begin(job=local_task_id)
|
||||
shared.state.job_count = sh.warmup_samples + sh.samples
|
||||
|
||||
# wait for all benchmarks to finish and update stats on newly benchmarked workers
|
||||
concurrent.futures.wait(futures)
|
||||
logger.info("benchmarking finished")
|
||||
chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master)
|
||||
futures.append(executor.submit(chosen, worker))
|
||||
logger.info(f"benchmarking worker '{worker.label}'")
|
||||
|
||||
# save benchmark results to workers.json
|
||||
self.save_config()
|
||||
logger.info(self.speed_summary())
|
||||
if len(futures) > 0:
|
||||
# wait for all benchmarks to finish and update stats on newly benchmarked workers
|
||||
concurrent.futures.wait(futures)
|
||||
|
||||
if progress.current_task == local_task_id:
|
||||
shared.state.end()
|
||||
progress.finish_task(local_task_id)
|
||||
queue_lock.release()
|
||||
|
||||
logger.info("benchmarking finished")
|
||||
logger.info(self.speed_summary())
|
||||
gradio.Info("Distributed: benchmarking complete!")
|
||||
self.save_config()
|
||||
|
||||
def get_current_output_size(self) -> int:
|
||||
"""
|
||||
|
|
@ -369,7 +374,7 @@ class World:
|
|||
|
||||
batch_size = self.default_batch_size()
|
||||
for worker in self.get_workers():
|
||||
if worker.state != State.DISABLED and worker.state != State.UNAVAILABLE:
|
||||
if worker.state not in (State.DISABLED, State.UNAVAILABLE):
|
||||
if worker.avg_ipm is None or worker.avg_ipm <= 0:
|
||||
logger.debug(f"No recorded speed for worker '{worker.label}, benchmarking'")
|
||||
worker.benchmark()
|
||||
|
|
@ -379,16 +384,16 @@ class World:
|
|||
|
||||
def update(self, p):
|
||||
"""preps world for another run"""
|
||||
if not self.initialized:
|
||||
self.benchmark()
|
||||
self.initialized = True
|
||||
logger.debug("world initialized!")
|
||||
else:
|
||||
logger.debug("world was already initialized")
|
||||
|
||||
self.p = p
|
||||
self.benchmark()
|
||||
self.make_jobs()
|
||||
|
||||
if not self.initialized:
|
||||
self.initialized = True
|
||||
logger.debug("world initialized!")
|
||||
|
||||
|
||||
def get_workers(self):
|
||||
filtered: List[Worker] = []
|
||||
for worker in self._workers:
|
||||
|
|
@ -397,7 +402,7 @@ class World:
|
|||
continue
|
||||
if worker.master and self.thin_client_mode:
|
||||
continue
|
||||
if worker.state != State.UNAVAILABLE and worker.state != State.DISABLED:
|
||||
if worker.state not in (State.UNAVAILABLE, State.DISABLED):
|
||||
filtered.append(worker)
|
||||
|
||||
return filtered
|
||||
|
|
@ -535,7 +540,7 @@ class World:
|
|||
seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1)
|
||||
realtime_samples = slack_time // seconds_per_sample
|
||||
logger.debug(
|
||||
f"job for '{job.worker.label}' downscaled to {realtime_samples} samples to meet time constraints\n"
|
||||
f"job for '{job.worker.label}' downscaled to {realtime_samples:.0f} samples to meet time constraints\n"
|
||||
f"{realtime_samples:.0f} samples = {slack_time:.2f}s slack ÷ {seconds_per_sample:.2f}s/sample\n"
|
||||
f" step reduction: {payload['steps']} -> {realtime_samples:.0f}"
|
||||
)
|
||||
|
|
@ -546,17 +551,7 @@ class World:
|
|||
else:
|
||||
logger.debug("complementary image production is disabled")
|
||||
|
||||
iterations = payload['n_iter']
|
||||
num_returning = self.get_current_output_size()
|
||||
num_complementary = num_returning - self.p.batch_size
|
||||
distro_summary = "Job distribution:\n"
|
||||
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
|
||||
if num_complementary > 0:
|
||||
distro_summary += f" + {num_complementary} complementary"
|
||||
distro_summary += f": {num_returning} images total\n"
|
||||
for job in self.jobs:
|
||||
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
|
||||
logger.info(distro_summary)
|
||||
logger.info(self.distro_summary(payload))
|
||||
|
||||
if self.thin_client_mode is True or self.master_job().batch_size == 0:
|
||||
# save original process_images_inner for later so we can restore once we're done
|
||||
|
|
@ -581,6 +576,20 @@ class World:
|
|||
del self.jobs[last]
|
||||
last -= 1
|
||||
|
||||
def distro_summary(self, payload):
|
||||
# iterations = dict(payload)['n_iter']
|
||||
iterations = self.p.n_iter
|
||||
num_returning = self.get_current_output_size()
|
||||
num_complementary = num_returning - self.p.batch_size
|
||||
distro_summary = "Job distribution:\n"
|
||||
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
|
||||
if num_complementary > 0:
|
||||
distro_summary += f" + {num_complementary} complementary"
|
||||
distro_summary += f": {num_returning} images total\n"
|
||||
for job in self.jobs:
|
||||
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
|
||||
return distro_summary
|
||||
|
||||
def config(self) -> dict:
|
||||
"""
|
||||
{
|
||||
|
|
@ -654,6 +663,10 @@ class World:
|
|||
fields['label'] = label
|
||||
# TODO must be overridden everytime here or later converted to a config file variable at some point
|
||||
fields['verify_remotes'] = self.verify_remotes
|
||||
# cast enum id to actual enum type and then prime state
|
||||
fields['state'] = State(fields['state'])
|
||||
if fields['state'] not in (State.DISABLED, State.UNAVAILABLE):
|
||||
fields['state'] = State.IDLE
|
||||
|
||||
self.add_worker(**fields)
|
||||
|
||||
|
|
@ -663,11 +676,11 @@ class World:
|
|||
self.complement_production = config.complement_production
|
||||
self.step_scaling = config.step_scaling
|
||||
|
||||
logger.debug("config loaded")
|
||||
logger.debug(f"config loaded from '{os.path.abspath(self.config_path)}'")
|
||||
|
||||
def save_config(self):
|
||||
"""
|
||||
Saves the config file.
|
||||
Saves current state to the config file.
|
||||
"""
|
||||
|
||||
config = ConfigModel(
|
||||
|
|
@ -727,11 +740,14 @@ class World:
|
|||
msg = f"worker '{worker.label}' is online"
|
||||
logger.info(msg)
|
||||
gradio.Info("Distributed: "+msg)
|
||||
worker.state = State.IDLE
|
||||
worker.set_state(State.IDLE, expect_cycle=True)
|
||||
else:
|
||||
msg = f"worker '{worker.label}' is unreachable"
|
||||
logger.info(msg)
|
||||
gradio.Warning("Distributed: "+msg)
|
||||
worker.set_state(State.UNAVAILABLE)
|
||||
|
||||
self.save_config()
|
||||
|
||||
def restart_all(self):
|
||||
for worker in self._workers:
|
||||
|
|
|
|||
Loading…
Reference in New Issue