commit
65b0eb75e0
|
|
@ -1,6 +1,12 @@
|
||||||
# Change Log
|
# Change Log
|
||||||
Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
|
Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
|
||||||
|
|
||||||
|
## [2.2.2] - 2024-8-30
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Unavailable state sometimes being ignored
|
||||||
|
|
||||||
## [2.2.1] - 2024-5-16
|
## [2.2.1] - 2024-5-16
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ from modules.shared import state as webui_state
|
||||||
from scripts.spartan.control_net import pack_control_net
|
from scripts.spartan.control_net import pack_control_net
|
||||||
from scripts.spartan.shared import logger
|
from scripts.spartan.shared import logger
|
||||||
from scripts.spartan.ui import UI
|
from scripts.spartan.ui import UI
|
||||||
from scripts.spartan.world import World
|
from scripts.spartan.world import World, State
|
||||||
|
|
||||||
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
||||||
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||||
|
|
@ -33,7 +33,6 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||||
# noinspection PyMissingOrEmptyDocstring
|
# noinspection PyMissingOrEmptyDocstring
|
||||||
class DistributedScript(scripts.Script):
|
class DistributedScript(scripts.Script):
|
||||||
# global old_sigterm_handler, old_sigterm_handler
|
# global old_sigterm_handler, old_sigterm_handler
|
||||||
worker_threads: List[Thread] = []
|
|
||||||
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
||||||
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
|
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
|
||||||
master_start = None
|
master_start = None
|
||||||
|
|
@ -150,17 +149,19 @@ class DistributedScript(scripts.Script):
|
||||||
|
|
||||||
# wait for response from all workers
|
# wait for response from all workers
|
||||||
webui_state.textinfo = "Distributed - receiving results"
|
webui_state.textinfo = "Distributed - receiving results"
|
||||||
for thread in self.worker_threads:
|
for job in self.world.jobs:
|
||||||
logger.debug(f"waiting for worker thread '{thread.name}'")
|
if job.thread is None:
|
||||||
thread.join()
|
continue
|
||||||
self.worker_threads.clear()
|
|
||||||
|
logger.debug(f"waiting for worker thread '{job.thread.name}'")
|
||||||
|
job.thread.join()
|
||||||
logger.debug("all worker request threads returned")
|
logger.debug("all worker request threads returned")
|
||||||
webui_state.textinfo = "Distributed - injecting images"
|
webui_state.textinfo = "Distributed - injecting images"
|
||||||
|
|
||||||
# some worker which we know has a good response that we can use for generating the grid
|
# some worker which we know has a good response that we can use for generating the grid
|
||||||
donor_worker = None
|
donor_worker = None
|
||||||
for job in self.world.jobs:
|
for job in self.world.jobs:
|
||||||
if job.batch_size < 1 or job.worker.master:
|
if job.worker.response is None or job.batch_size < 1 or job.worker.master:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -197,7 +198,7 @@ class DistributedScript(scripts.Script):
|
||||||
return
|
return
|
||||||
|
|
||||||
# generate and inject grid
|
# generate and inject grid
|
||||||
if opts.return_grid:
|
if opts.return_grid and len(processed.images) > 1:
|
||||||
grid = image_grid(processed.images, len(processed.images))
|
grid = image_grid(processed.images, len(processed.images))
|
||||||
processed_inject_image(
|
processed_inject_image(
|
||||||
image=grid,
|
image=grid,
|
||||||
|
|
@ -304,6 +305,9 @@ class DistributedScript(scripts.Script):
|
||||||
return
|
return
|
||||||
|
|
||||||
for job in self.world.jobs:
|
for job in self.world.jobs:
|
||||||
|
if job.worker.state in (State.UNAVAILABLE, State.DISABLED):
|
||||||
|
continue
|
||||||
|
|
||||||
payload_temp = copy.copy(payload)
|
payload_temp = copy.copy(payload)
|
||||||
del payload_temp['scripts_value']
|
del payload_temp['scripts_value']
|
||||||
payload_temp = copy.deepcopy(payload_temp)
|
payload_temp = copy.deepcopy(payload_temp)
|
||||||
|
|
@ -332,11 +336,9 @@ class DistributedScript(scripts.Script):
|
||||||
job.worker.loaded_model = name
|
job.worker.loaded_model = name
|
||||||
job.worker.loaded_vae = vae
|
job.worker.loaded_vae = vae
|
||||||
|
|
||||||
t = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
|
job.thread = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
|
||||||
name=f"{job.worker.label}_request")
|
name=f"{job.worker.label}_request")
|
||||||
|
job.thread.start()
|
||||||
t.start()
|
|
||||||
self.worker_threads.append(t)
|
|
||||||
started_jobs.append(job)
|
started_jobs.append(job)
|
||||||
|
|
||||||
# if master batch size was changed again due to optimization change it to the updated value
|
# if master batch size was changed again due to optimization change it to the updated value
|
||||||
|
|
@ -358,6 +360,8 @@ class DistributedScript(scripts.Script):
|
||||||
|
|
||||||
# restore process_images_inner if it was monkey-patched
|
# restore process_images_inner if it was monkey-patched
|
||||||
processing.process_images_inner = self.original_process_images_inner
|
processing.process_images_inner = self.original_process_images_inner
|
||||||
|
# save any dangling state to prevent load_config in next iteration overwriting it
|
||||||
|
self.world.save_config()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def signal_handler(sig, frame):
|
def signal_handler(sig, frame):
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ from modules.shared import opts
|
||||||
from modules.shared import state as webui_state
|
from modules.shared import state as webui_state
|
||||||
from .shared import logger, LOG_LEVEL, gui_handler
|
from .shared import logger, LOG_LEVEL, gui_handler
|
||||||
from .worker import State
|
from .worker import State
|
||||||
|
from modules.call_queue import queue_lock
|
||||||
|
from modules import progress
|
||||||
|
|
||||||
worker_select_dropdown = None
|
worker_select_dropdown = None
|
||||||
|
|
||||||
|
|
@ -61,6 +63,10 @@ class UI:
|
||||||
"""debug utility that will clear the internal webui queue. sometimes good for jams"""
|
"""debug utility that will clear the internal webui queue. sometimes good for jams"""
|
||||||
logger.debug(webui_state.__dict__)
|
logger.debug(webui_state.__dict__)
|
||||||
webui_state.end()
|
webui_state.end()
|
||||||
|
progress.pending_tasks.clear()
|
||||||
|
progress.current_task = None
|
||||||
|
if queue_lock._lock.locked():
|
||||||
|
queue_lock.release()
|
||||||
|
|
||||||
def status_btn(self):
|
def status_btn(self):
|
||||||
"""updates a simplified overview of registered workers and their jobs"""
|
"""updates a simplified overview of registered workers and their jobs"""
|
||||||
|
|
|
||||||
|
|
@ -306,7 +306,7 @@ class Worker:
|
||||||
if waited >= (0.85 * max_wait):
|
if waited >= (0.85 * max_wait):
|
||||||
logger.warning("this seems long, so if you see this message often, consider reporting an issue")
|
logger.warning("this seems long, so if you see this message often, consider reporting an issue")
|
||||||
|
|
||||||
self.state = State.WORKING
|
self.set_state(State.WORKING)
|
||||||
|
|
||||||
# query memory available on worker and store for future reference
|
# query memory available on worker and store for future reference
|
||||||
if self.queried is False:
|
if self.queried is False:
|
||||||
|
|
@ -344,7 +344,7 @@ class Worker:
|
||||||
# remove anything that is not serializable
|
# remove anything that is not serializable
|
||||||
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
|
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
|
||||||
s_tmax = payload.get('s_tmax', 0.0)
|
s_tmax = payload.get('s_tmax', 0.0)
|
||||||
if s_tmax > 1e308:
|
if s_tmax is not None and s_tmax > 1e308:
|
||||||
payload['s_tmax'] = 1e308
|
payload['s_tmax'] = 1e308
|
||||||
# remove unserializable caches
|
# remove unserializable caches
|
||||||
payload.pop('cached_uc', None)
|
payload.pop('cached_uc', None)
|
||||||
|
|
@ -423,7 +423,6 @@ class Worker:
|
||||||
sampler_name = payload.get('sampler_name', None)
|
sampler_name = payload.get('sampler_name', None)
|
||||||
if sampler_index is None:
|
if sampler_index is None:
|
||||||
if sampler_name is not None:
|
if sampler_name is not None:
|
||||||
logger.debug("had to substitute sampler index with name")
|
|
||||||
payload['sampler_index'] = sampler_name
|
payload['sampler_index'] = sampler_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -490,18 +489,14 @@ class Worker:
|
||||||
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
|
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.state = State.IDLE
|
self.set_state(State.IDLE)
|
||||||
|
raise InvalidWorkerResponse(e)
|
||||||
if payload['batch_size'] == 0:
|
|
||||||
raise InvalidWorkerResponse("Tried to request a null amount of images")
|
|
||||||
else:
|
|
||||||
raise InvalidWorkerResponse(e)
|
|
||||||
|
|
||||||
except requests.RequestException:
|
except requests.RequestException:
|
||||||
self.mark_unreachable()
|
self.set_state(State.UNAVAILABLE)
|
||||||
return
|
return
|
||||||
|
|
||||||
self.state = State.IDLE
|
self.set_state(State.IDLE)
|
||||||
self.jobs_requested += 1
|
self.jobs_requested += 1
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -539,7 +534,6 @@ class Worker:
|
||||||
# this was due to something torch does at startup according to auto and is now done at sdwui startup
|
# this was due to something torch does at startup according to auto and is now done at sdwui startup
|
||||||
for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
||||||
if self.state == State.UNAVAILABLE:
|
if self.state == State.UNAVAILABLE:
|
||||||
self.response = None
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
try: # if the worker is unreachable/offline then handle that here
|
try: # if the worker is unreachable/offline then handle that here
|
||||||
|
|
@ -574,7 +568,7 @@ class Worker:
|
||||||
self.response = None
|
self.response = None
|
||||||
self.benchmarked = True
|
self.benchmarked = True
|
||||||
self.eta_percent_error = [] # likely inaccurate after rebenching
|
self.eta_percent_error = [] # likely inaccurate after rebenching
|
||||||
self.state = State.IDLE
|
self.set_state(State.IDLE)
|
||||||
return avg_ipm_result
|
return avg_ipm_result
|
||||||
|
|
||||||
def refresh_checkpoints(self):
|
def refresh_checkpoints(self):
|
||||||
|
|
@ -593,17 +587,17 @@ class Worker:
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
# gradio.Warning("Distributed: "+msg)
|
# gradio.Warning("Distributed: "+msg)
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
self.mark_unreachable()
|
self.set_state(State.UNAVAILABLE)
|
||||||
|
|
||||||
def interrupt(self):
|
def interrupt(self):
|
||||||
try:
|
try:
|
||||||
response = self.session.post(self.full_url('interrupt'))
|
response = self.session.post(self.full_url('interrupt'))
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
self.state = State.INTERRUPTED
|
self.set_state(State.INTERRUPTED)
|
||||||
logger.debug(f"successfully interrupted worker {self.label}")
|
logger.debug(f"successfully interrupted worker {self.label}")
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
self.mark_unreachable()
|
self.set_state(State.UNAVAILABLE)
|
||||||
|
|
||||||
def reachable(self) -> bool:
|
def reachable(self) -> bool:
|
||||||
"""returns false if worker is unreachable"""
|
"""returns false if worker is unreachable"""
|
||||||
|
|
@ -622,18 +616,6 @@ class Worker:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def mark_unreachable(self):
|
|
||||||
if self.state == State.DISABLED:
|
|
||||||
logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
|
|
||||||
else:
|
|
||||||
msg = f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection"
|
|
||||||
logger.error(msg)
|
|
||||||
# gradio.Warning("Distributed: "+msg)
|
|
||||||
self.state = State.UNAVAILABLE
|
|
||||||
# invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
|
|
||||||
self.loaded_model = None
|
|
||||||
self.loaded_vae = None
|
|
||||||
|
|
||||||
def available_models(self) -> [List[str]]:
|
def available_models(self) -> [List[str]]:
|
||||||
if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master:
|
if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master:
|
||||||
return []
|
return []
|
||||||
|
|
@ -654,7 +636,7 @@ class Worker:
|
||||||
titles = [model['title'] for model in response.json()]
|
titles = [model['title'] for model in response.json()]
|
||||||
return titles
|
return titles
|
||||||
except requests.RequestException:
|
except requests.RequestException:
|
||||||
self.mark_unreachable()
|
self.set_state(State.UNAVAILABLE)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def load_options(self, model, vae=None):
|
def load_options(self, model, vae=None):
|
||||||
|
|
@ -671,14 +653,14 @@ class Worker:
|
||||||
if vae is not None:
|
if vae is not None:
|
||||||
payload['sd_vae'] = vae
|
payload['sd_vae'] = vae
|
||||||
|
|
||||||
self.state = State.WORKING
|
self.set_state(State.WORKING)
|
||||||
start = time.time()
|
start = time.time()
|
||||||
response = self.session.post(
|
response = self.session.post(
|
||||||
self.full_url("options"),
|
self.full_url("options"),
|
||||||
json=payload
|
json=payload
|
||||||
)
|
)
|
||||||
elapsed = time.time() - start
|
elapsed = time.time() - start
|
||||||
self.state = State.IDLE
|
self.set_state(State.IDLE)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
logger.debug(f"failed to load options for worker '{self.label}'")
|
logger.debug(f"failed to load options for worker '{self.label}'")
|
||||||
|
|
@ -720,3 +702,44 @@ class Worker:
|
||||||
|
|
||||||
logger.error(f"{err_msg}: {response}")
|
logger.error(f"{err_msg}: {response}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def set_state(self, state: State, expect_cycle: bool = False):
|
||||||
|
"""
|
||||||
|
Updates the state of a worker if considered a valid operation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: the new state to try transitioning to
|
||||||
|
expect_cycle: whether this transition might be a no-op/self-loop
|
||||||
|
|
||||||
|
"""
|
||||||
|
state_cache = self.state
|
||||||
|
|
||||||
|
def transition(ns: State):
|
||||||
|
if ns == self.state and expect_cycle is False:
|
||||||
|
logger.debug(f"{self.label}: potentially redundant transition {self.state.name} -> {ns.name}")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"{self.label}: {self.state.name} -> {ns.name}")
|
||||||
|
self.state = ns
|
||||||
|
|
||||||
|
transitions = {
|
||||||
|
State.IDLE: {State.IDLE, State.WORKING},
|
||||||
|
State.WORKING: {State.WORKING, State.IDLE, State.INTERRUPTED},
|
||||||
|
State.UNAVAILABLE: {State.IDLE},
|
||||||
|
State.INTERRUPTED: {State.WORKING},
|
||||||
|
}
|
||||||
|
if state in transitions.get(self.state, {}):
|
||||||
|
transition(state)
|
||||||
|
|
||||||
|
if state == State.UNAVAILABLE:
|
||||||
|
if self.state == State.DISABLED:
|
||||||
|
logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
|
||||||
|
else:
|
||||||
|
logger.error(f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection")
|
||||||
|
# invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
|
||||||
|
self.loaded_model = None
|
||||||
|
self.loaded_vae = None
|
||||||
|
transition(state)
|
||||||
|
|
||||||
|
if self.state == state_cache and self.state != state:
|
||||||
|
logger.debug(f"{self.label}: invalid transition {self.state.name} -> {state.name}")
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,9 @@ from . import shared as sh
|
||||||
from .pmodels import ConfigModel, Benchmark_Payload
|
from .pmodels import ConfigModel, Benchmark_Payload
|
||||||
from .shared import logger, extension_path
|
from .shared import logger, extension_path
|
||||||
from .worker import Worker, State
|
from .worker import Worker, State
|
||||||
from modules.call_queue import wrap_queued_call
|
from modules.call_queue import wrap_queued_call, queue_lock
|
||||||
from modules import processing
|
from modules import processing
|
||||||
|
from modules import progress
|
||||||
|
|
||||||
|
|
||||||
class NotBenchmarked(Exception):
|
class NotBenchmarked(Exception):
|
||||||
|
|
@ -44,6 +45,7 @@ class Job:
|
||||||
self.batch_size: int = batch_size
|
self.batch_size: int = batch_size
|
||||||
self.complementary: bool = False
|
self.complementary: bool = False
|
||||||
self.step_override = None
|
self.step_override = None
|
||||||
|
self.thread = None
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
prefix = ''
|
prefix = ''
|
||||||
|
|
@ -160,13 +162,7 @@ class World:
|
||||||
return new
|
return new
|
||||||
else:
|
else:
|
||||||
for key in kwargs:
|
for key in kwargs:
|
||||||
if hasattr(original, key):
|
setattr(original, key, kwargs[key])
|
||||||
# TODO only necessary because this is skipping Worker.__init__ and the pyd model is saving the state as an int instead of an actual enum
|
|
||||||
if key == 'state':
|
|
||||||
original.state = kwargs[key] if type(kwargs[key]) is State else State(kwargs[key])
|
|
||||||
continue
|
|
||||||
|
|
||||||
setattr(original, key, kwargs[key])
|
|
||||||
|
|
||||||
return original
|
return original
|
||||||
|
|
||||||
|
|
@ -186,29 +182,22 @@ class World:
|
||||||
Thread(target=worker.refresh_checkpoints, args=()).start()
|
Thread(target=worker.refresh_checkpoints, args=()).start()
|
||||||
|
|
||||||
def sample_master(self) -> float:
|
def sample_master(self) -> float:
|
||||||
# wrap our benchmark payload
|
p = StableDiffusionProcessingTxt2Img()
|
||||||
master_bench_payload = StableDiffusionProcessingTxt2Img()
|
|
||||||
d = sh.benchmark_payload.dict()
|
d = sh.benchmark_payload.dict()
|
||||||
for key in d:
|
for key in d:
|
||||||
setattr(master_bench_payload, key, d[key])
|
setattr(p, key, d[key])
|
||||||
|
p.do_not_save_samples = True
|
||||||
|
|
||||||
# Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to.
|
|
||||||
master_bench_payload.do_not_save_samples = True
|
|
||||||
# shared.state.begin(job='distributed_master_bench')
|
|
||||||
wrapped = (wrap_queued_call(process_images))
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
wrapped(master_bench_payload)
|
process_images(p)
|
||||||
# wrap_gradio_gpu_call(process_images)(master_bench_payload)
|
|
||||||
# shared.state.end()
|
|
||||||
|
|
||||||
return time.time() - start
|
return time.time() - start
|
||||||
|
|
||||||
|
|
||||||
def benchmark(self, rebenchmark: bool = False):
|
def benchmark(self, rebenchmark: bool = False):
|
||||||
"""
|
"""
|
||||||
Attempts to benchmark all workers a part of the world.
|
Attempts to benchmark all workers a part of the world.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
local_task_id = 'task(distributed_bench)'
|
||||||
unbenched_workers = []
|
unbenched_workers = []
|
||||||
if rebenchmark:
|
if rebenchmark:
|
||||||
for worker in self._workers:
|
for worker in self._workers:
|
||||||
|
|
@ -247,26 +236,42 @@ class World:
|
||||||
futures.clear()
|
futures.clear()
|
||||||
|
|
||||||
# benchmark those that haven't been
|
# benchmark those that haven't been
|
||||||
for worker in unbenched_workers:
|
if len(unbenched_workers) > 0:
|
||||||
if worker.state in (State.DISABLED, State.UNAVAILABLE):
|
queue_lock.acquire()
|
||||||
logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark")
|
gradio.Info("Distributed: benchmarking in progress, please wait")
|
||||||
continue
|
for worker in unbenched_workers:
|
||||||
|
if worker.state in (State.DISABLED, State.UNAVAILABLE):
|
||||||
|
logger.debug(f"worker '{worker.label}' is {worker.state.name}, refusing to benchmark")
|
||||||
|
continue
|
||||||
|
|
||||||
if worker.model_override is not None:
|
if worker.model_override is not None:
|
||||||
logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
|
logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
|
||||||
f"*all workers should be evaluated against the same model")
|
f"*all workers should be evaluated against the same model")
|
||||||
|
|
||||||
chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master)
|
if worker.master:
|
||||||
futures.append(executor.submit(chosen, worker))
|
if progress.current_task is None:
|
||||||
logger.info(f"benchmarking worker '{worker.label}'")
|
progress.add_task_to_queue(local_task_id)
|
||||||
|
progress.start_task(local_task_id)
|
||||||
|
shared.state.begin(job=local_task_id)
|
||||||
|
shared.state.job_count = sh.warmup_samples + sh.samples
|
||||||
|
|
||||||
# wait for all benchmarks to finish and update stats on newly benchmarked workers
|
chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master)
|
||||||
concurrent.futures.wait(futures)
|
futures.append(executor.submit(chosen, worker))
|
||||||
logger.info("benchmarking finished")
|
logger.info(f"benchmarking worker '{worker.label}'")
|
||||||
|
|
||||||
# save benchmark results to workers.json
|
if len(futures) > 0:
|
||||||
self.save_config()
|
# wait for all benchmarks to finish and update stats on newly benchmarked workers
|
||||||
logger.info(self.speed_summary())
|
concurrent.futures.wait(futures)
|
||||||
|
|
||||||
|
if progress.current_task == local_task_id:
|
||||||
|
shared.state.end()
|
||||||
|
progress.finish_task(local_task_id)
|
||||||
|
queue_lock.release()
|
||||||
|
|
||||||
|
logger.info("benchmarking finished")
|
||||||
|
logger.info(self.speed_summary())
|
||||||
|
gradio.Info("Distributed: benchmarking complete!")
|
||||||
|
self.save_config()
|
||||||
|
|
||||||
def get_current_output_size(self) -> int:
|
def get_current_output_size(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
@ -369,7 +374,7 @@ class World:
|
||||||
|
|
||||||
batch_size = self.default_batch_size()
|
batch_size = self.default_batch_size()
|
||||||
for worker in self.get_workers():
|
for worker in self.get_workers():
|
||||||
if worker.state != State.DISABLED and worker.state != State.UNAVAILABLE:
|
if worker.state not in (State.DISABLED, State.UNAVAILABLE):
|
||||||
if worker.avg_ipm is None or worker.avg_ipm <= 0:
|
if worker.avg_ipm is None or worker.avg_ipm <= 0:
|
||||||
logger.debug(f"No recorded speed for worker '{worker.label}, benchmarking'")
|
logger.debug(f"No recorded speed for worker '{worker.label}, benchmarking'")
|
||||||
worker.benchmark()
|
worker.benchmark()
|
||||||
|
|
@ -379,16 +384,16 @@ class World:
|
||||||
|
|
||||||
def update(self, p):
|
def update(self, p):
|
||||||
"""preps world for another run"""
|
"""preps world for another run"""
|
||||||
if not self.initialized:
|
|
||||||
self.benchmark()
|
|
||||||
self.initialized = True
|
|
||||||
logger.debug("world initialized!")
|
|
||||||
else:
|
|
||||||
logger.debug("world was already initialized")
|
|
||||||
|
|
||||||
self.p = p
|
self.p = p
|
||||||
|
self.benchmark()
|
||||||
self.make_jobs()
|
self.make_jobs()
|
||||||
|
|
||||||
|
if not self.initialized:
|
||||||
|
self.initialized = True
|
||||||
|
logger.debug("world initialized!")
|
||||||
|
|
||||||
|
|
||||||
def get_workers(self):
|
def get_workers(self):
|
||||||
filtered: List[Worker] = []
|
filtered: List[Worker] = []
|
||||||
for worker in self._workers:
|
for worker in self._workers:
|
||||||
|
|
@ -397,7 +402,7 @@ class World:
|
||||||
continue
|
continue
|
||||||
if worker.master and self.thin_client_mode:
|
if worker.master and self.thin_client_mode:
|
||||||
continue
|
continue
|
||||||
if worker.state != State.UNAVAILABLE and worker.state != State.DISABLED:
|
if worker.state not in (State.UNAVAILABLE, State.DISABLED):
|
||||||
filtered.append(worker)
|
filtered.append(worker)
|
||||||
|
|
||||||
return filtered
|
return filtered
|
||||||
|
|
@ -535,7 +540,7 @@ class World:
|
||||||
seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1)
|
seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1)
|
||||||
realtime_samples = slack_time // seconds_per_sample
|
realtime_samples = slack_time // seconds_per_sample
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"job for '{job.worker.label}' downscaled to {realtime_samples} samples to meet time constraints\n"
|
f"job for '{job.worker.label}' downscaled to {realtime_samples:.0f} samples to meet time constraints\n"
|
||||||
f"{realtime_samples:.0f} samples = {slack_time:.2f}s slack ÷ {seconds_per_sample:.2f}s/sample\n"
|
f"{realtime_samples:.0f} samples = {slack_time:.2f}s slack ÷ {seconds_per_sample:.2f}s/sample\n"
|
||||||
f" step reduction: {payload['steps']} -> {realtime_samples:.0f}"
|
f" step reduction: {payload['steps']} -> {realtime_samples:.0f}"
|
||||||
)
|
)
|
||||||
|
|
@ -546,17 +551,7 @@ class World:
|
||||||
else:
|
else:
|
||||||
logger.debug("complementary image production is disabled")
|
logger.debug("complementary image production is disabled")
|
||||||
|
|
||||||
iterations = payload['n_iter']
|
logger.info(self.distro_summary(payload))
|
||||||
num_returning = self.get_current_output_size()
|
|
||||||
num_complementary = num_returning - self.p.batch_size
|
|
||||||
distro_summary = "Job distribution:\n"
|
|
||||||
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
|
|
||||||
if num_complementary > 0:
|
|
||||||
distro_summary += f" + {num_complementary} complementary"
|
|
||||||
distro_summary += f": {num_returning} images total\n"
|
|
||||||
for job in self.jobs:
|
|
||||||
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
|
|
||||||
logger.info(distro_summary)
|
|
||||||
|
|
||||||
if self.thin_client_mode is True or self.master_job().batch_size == 0:
|
if self.thin_client_mode is True or self.master_job().batch_size == 0:
|
||||||
# save original process_images_inner for later so we can restore once we're done
|
# save original process_images_inner for later so we can restore once we're done
|
||||||
|
|
@ -581,6 +576,20 @@ class World:
|
||||||
del self.jobs[last]
|
del self.jobs[last]
|
||||||
last -= 1
|
last -= 1
|
||||||
|
|
||||||
|
def distro_summary(self, payload):
|
||||||
|
# iterations = dict(payload)['n_iter']
|
||||||
|
iterations = self.p.n_iter
|
||||||
|
num_returning = self.get_current_output_size()
|
||||||
|
num_complementary = num_returning - self.p.batch_size
|
||||||
|
distro_summary = "Job distribution:\n"
|
||||||
|
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
|
||||||
|
if num_complementary > 0:
|
||||||
|
distro_summary += f" + {num_complementary} complementary"
|
||||||
|
distro_summary += f": {num_returning} images total\n"
|
||||||
|
for job in self.jobs:
|
||||||
|
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
|
||||||
|
return distro_summary
|
||||||
|
|
||||||
def config(self) -> dict:
|
def config(self) -> dict:
|
||||||
"""
|
"""
|
||||||
{
|
{
|
||||||
|
|
@ -654,6 +663,10 @@ class World:
|
||||||
fields['label'] = label
|
fields['label'] = label
|
||||||
# TODO must be overridden everytime here or later converted to a config file variable at some point
|
# TODO must be overridden everytime here or later converted to a config file variable at some point
|
||||||
fields['verify_remotes'] = self.verify_remotes
|
fields['verify_remotes'] = self.verify_remotes
|
||||||
|
# cast enum id to actual enum type and then prime state
|
||||||
|
fields['state'] = State(fields['state'])
|
||||||
|
if fields['state'] not in (State.DISABLED, State.UNAVAILABLE):
|
||||||
|
fields['state'] = State.IDLE
|
||||||
|
|
||||||
self.add_worker(**fields)
|
self.add_worker(**fields)
|
||||||
|
|
||||||
|
|
@ -663,11 +676,11 @@ class World:
|
||||||
self.complement_production = config.complement_production
|
self.complement_production = config.complement_production
|
||||||
self.step_scaling = config.step_scaling
|
self.step_scaling = config.step_scaling
|
||||||
|
|
||||||
logger.debug("config loaded")
|
logger.debug(f"config loaded from '{os.path.abspath(self.config_path)}'")
|
||||||
|
|
||||||
def save_config(self):
|
def save_config(self):
|
||||||
"""
|
"""
|
||||||
Saves the config file.
|
Saves current state to the config file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config = ConfigModel(
|
config = ConfigModel(
|
||||||
|
|
@ -727,11 +740,14 @@ class World:
|
||||||
msg = f"worker '{worker.label}' is online"
|
msg = f"worker '{worker.label}' is online"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
gradio.Info("Distributed: "+msg)
|
gradio.Info("Distributed: "+msg)
|
||||||
worker.state = State.IDLE
|
worker.set_state(State.IDLE, expect_cycle=True)
|
||||||
else:
|
else:
|
||||||
msg = f"worker '{worker.label}' is unreachable"
|
msg = f"worker '{worker.label}' is unreachable"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
gradio.Warning("Distributed: "+msg)
|
gradio.Warning("Distributed: "+msg)
|
||||||
|
worker.set_state(State.UNAVAILABLE)
|
||||||
|
|
||||||
|
self.save_config()
|
||||||
|
|
||||||
def restart_all(self):
|
def restart_all(self):
|
||||||
for worker in self._workers:
|
for worker in self._workers:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue