merge dev making 2.2.2

master v2.2.2
papuSpartan 2024-08-30 12:11:38 -05:00
commit 65b0eb75e0
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
5 changed files with 157 additions and 102 deletions

View File

@ -1,6 +1,12 @@
# Change Log # Change Log
Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html) Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
## [2.2.2] - 2024-8-30
### Fixed
- Unavailable state sometimes being ignored
## [2.2.1] - 2024-5-16 ## [2.2.1] - 2024-5-16
### Fixed ### Fixed

View File

@ -24,7 +24,7 @@ from modules.shared import state as webui_state
from scripts.spartan.control_net import pack_control_net from scripts.spartan.control_net import pack_control_net
from scripts.spartan.shared import logger from scripts.spartan.shared import logger
from scripts.spartan.ui import UI from scripts.spartan.ui import UI
from scripts.spartan.world import World from scripts.spartan.world import World, State
old_sigint_handler = signal.getsignal(signal.SIGINT) old_sigint_handler = signal.getsignal(signal.SIGINT)
old_sigterm_handler = signal.getsignal(signal.SIGTERM) old_sigterm_handler = signal.getsignal(signal.SIGTERM)
@ -33,7 +33,6 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM)
# noinspection PyMissingOrEmptyDocstring # noinspection PyMissingOrEmptyDocstring
class DistributedScript(scripts.Script): class DistributedScript(scripts.Script):
# global old_sigterm_handler, old_sigterm_handler # global old_sigterm_handler, old_sigterm_handler
worker_threads: List[Thread] = []
# Whether to verify worker certificates. Can be useful if your remotes are self-signed. # Whether to verify worker certificates. Can be useful if your remotes are self-signed.
verify_remotes = not cmd_opts.distributed_skip_verify_remotes verify_remotes = not cmd_opts.distributed_skip_verify_remotes
master_start = None master_start = None
@ -150,17 +149,19 @@ class DistributedScript(scripts.Script):
# wait for response from all workers # wait for response from all workers
webui_state.textinfo = "Distributed - receiving results" webui_state.textinfo = "Distributed - receiving results"
for thread in self.worker_threads: for job in self.world.jobs:
logger.debug(f"waiting for worker thread '{thread.name}'") if job.thread is None:
thread.join() continue
self.worker_threads.clear()
logger.debug(f"waiting for worker thread '{job.thread.name}'")
job.thread.join()
logger.debug("all worker request threads returned") logger.debug("all worker request threads returned")
webui_state.textinfo = "Distributed - injecting images" webui_state.textinfo = "Distributed - injecting images"
# some worker which we know has a good response that we can use for generating the grid # some worker which we know has a good response that we can use for generating the grid
donor_worker = None donor_worker = None
for job in self.world.jobs: for job in self.world.jobs:
if job.batch_size < 1 or job.worker.master: if job.worker.response is None or job.batch_size < 1 or job.worker.master:
continue continue
try: try:
@ -197,7 +198,7 @@ class DistributedScript(scripts.Script):
return return
# generate and inject grid # generate and inject grid
if opts.return_grid: if opts.return_grid and len(processed.images) > 1:
grid = image_grid(processed.images, len(processed.images)) grid = image_grid(processed.images, len(processed.images))
processed_inject_image( processed_inject_image(
image=grid, image=grid,
@ -304,6 +305,9 @@ class DistributedScript(scripts.Script):
return return
for job in self.world.jobs: for job in self.world.jobs:
if job.worker.state in (State.UNAVAILABLE, State.DISABLED):
continue
payload_temp = copy.copy(payload) payload_temp = copy.copy(payload)
del payload_temp['scripts_value'] del payload_temp['scripts_value']
payload_temp = copy.deepcopy(payload_temp) payload_temp = copy.deepcopy(payload_temp)
@ -332,11 +336,9 @@ class DistributedScript(scripts.Script):
job.worker.loaded_model = name job.worker.loaded_model = name
job.worker.loaded_vae = vae job.worker.loaded_vae = vae
t = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,), job.thread = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
name=f"{job.worker.label}_request") name=f"{job.worker.label}_request")
job.thread.start()
t.start()
self.worker_threads.append(t)
started_jobs.append(job) started_jobs.append(job)
# if master batch size was changed again due to optimization change it to the updated value # if master batch size was changed again due to optimization change it to the updated value
@ -358,6 +360,8 @@ class DistributedScript(scripts.Script):
# restore process_images_inner if it was monkey-patched # restore process_images_inner if it was monkey-patched
processing.process_images_inner = self.original_process_images_inner processing.process_images_inner = self.original_process_images_inner
# save any dangling state to prevent load_config in next iteration overwriting it
self.world.save_config()
@staticmethod @staticmethod
def signal_handler(sig, frame): def signal_handler(sig, frame):

View File

@ -7,6 +7,8 @@ from modules.shared import opts
from modules.shared import state as webui_state from modules.shared import state as webui_state
from .shared import logger, LOG_LEVEL, gui_handler from .shared import logger, LOG_LEVEL, gui_handler
from .worker import State from .worker import State
from modules.call_queue import queue_lock
from modules import progress
worker_select_dropdown = None worker_select_dropdown = None
@ -61,6 +63,10 @@ class UI:
"""debug utility that will clear the internal webui queue. sometimes good for jams""" """debug utility that will clear the internal webui queue. sometimes good for jams"""
logger.debug(webui_state.__dict__) logger.debug(webui_state.__dict__)
webui_state.end() webui_state.end()
progress.pending_tasks.clear()
progress.current_task = None
if queue_lock._lock.locked():
queue_lock.release()
def status_btn(self): def status_btn(self):
"""updates a simplified overview of registered workers and their jobs""" """updates a simplified overview of registered workers and their jobs"""

View File

@ -306,7 +306,7 @@ class Worker:
if waited >= (0.85 * max_wait): if waited >= (0.85 * max_wait):
logger.warning("this seems long, so if you see this message often, consider reporting an issue") logger.warning("this seems long, so if you see this message often, consider reporting an issue")
self.state = State.WORKING self.set_state(State.WORKING)
# query memory available on worker and store for future reference # query memory available on worker and store for future reference
if self.queried is False: if self.queried is False:
@ -344,7 +344,7 @@ class Worker:
# remove anything that is not serializable # remove anything that is not serializable
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value # s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
s_tmax = payload.get('s_tmax', 0.0) s_tmax = payload.get('s_tmax', 0.0)
if s_tmax > 1e308: if s_tmax is not None and s_tmax > 1e308:
payload['s_tmax'] = 1e308 payload['s_tmax'] = 1e308
# remove unserializable caches # remove unserializable caches
payload.pop('cached_uc', None) payload.pop('cached_uc', None)
@ -423,7 +423,6 @@ class Worker:
sampler_name = payload.get('sampler_name', None) sampler_name = payload.get('sampler_name', None)
if sampler_index is None: if sampler_index is None:
if sampler_name is not None: if sampler_name is not None:
logger.debug("had to substitute sampler index with name")
payload['sampler_index'] = sampler_name payload['sampler_index'] = sampler_name
try: try:
@ -490,18 +489,14 @@ class Worker:
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n") logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
except Exception as e: except Exception as e:
self.state = State.IDLE self.set_state(State.IDLE)
raise InvalidWorkerResponse(e)
if payload['batch_size'] == 0:
raise InvalidWorkerResponse("Tried to request a null amount of images")
else:
raise InvalidWorkerResponse(e)
except requests.RequestException: except requests.RequestException:
self.mark_unreachable() self.set_state(State.UNAVAILABLE)
return return
self.state = State.IDLE self.set_state(State.IDLE)
self.jobs_requested += 1 self.jobs_requested += 1
return return
@ -539,7 +534,6 @@ class Worker:
# this was due to something torch does at startup according to auto and is now done at sdwui startup # this was due to something torch does at startup according to auto and is now done at sdwui startup
for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up" for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up"
if self.state == State.UNAVAILABLE: if self.state == State.UNAVAILABLE:
self.response = None
return 0 return 0
try: # if the worker is unreachable/offline then handle that here try: # if the worker is unreachable/offline then handle that here
@ -574,7 +568,7 @@ class Worker:
self.response = None self.response = None
self.benchmarked = True self.benchmarked = True
self.eta_percent_error = [] # likely inaccurate after rebenching self.eta_percent_error = [] # likely inaccurate after rebenching
self.state = State.IDLE self.set_state(State.IDLE)
return avg_ipm_result return avg_ipm_result
def refresh_checkpoints(self): def refresh_checkpoints(self):
@ -593,17 +587,17 @@ class Worker:
logger.error(msg) logger.error(msg)
# gradio.Warning("Distributed: "+msg) # gradio.Warning("Distributed: "+msg)
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
self.mark_unreachable() self.set_state(State.UNAVAILABLE)
def interrupt(self): def interrupt(self):
try: try:
response = self.session.post(self.full_url('interrupt')) response = self.session.post(self.full_url('interrupt'))
if response.status_code == 200: if response.status_code == 200:
self.state = State.INTERRUPTED self.set_state(State.INTERRUPTED)
logger.debug(f"successfully interrupted worker {self.label}") logger.debug(f"successfully interrupted worker {self.label}")
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
self.mark_unreachable() self.set_state(State.UNAVAILABLE)
def reachable(self) -> bool: def reachable(self) -> bool:
"""returns false if worker is unreachable""" """returns false if worker is unreachable"""
@ -622,18 +616,6 @@ class Worker:
logger.error(e) logger.error(e)
return False return False
def mark_unreachable(self):
if self.state == State.DISABLED:
logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
else:
msg = f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection"
logger.error(msg)
# gradio.Warning("Distributed: "+msg)
self.state = State.UNAVAILABLE
# invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
self.loaded_model = None
self.loaded_vae = None
def available_models(self) -> [List[str]]: def available_models(self) -> [List[str]]:
if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master: if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master:
return [] return []
@ -654,7 +636,7 @@ class Worker:
titles = [model['title'] for model in response.json()] titles = [model['title'] for model in response.json()]
return titles return titles
except requests.RequestException: except requests.RequestException:
self.mark_unreachable() self.set_state(State.UNAVAILABLE)
return [] return []
def load_options(self, model, vae=None): def load_options(self, model, vae=None):
@ -671,14 +653,14 @@ class Worker:
if vae is not None: if vae is not None:
payload['sd_vae'] = vae payload['sd_vae'] = vae
self.state = State.WORKING self.set_state(State.WORKING)
start = time.time() start = time.time()
response = self.session.post( response = self.session.post(
self.full_url("options"), self.full_url("options"),
json=payload json=payload
) )
elapsed = time.time() - start elapsed = time.time() - start
self.state = State.IDLE self.set_state(State.IDLE)
if response.status_code != 200: if response.status_code != 200:
logger.debug(f"failed to load options for worker '{self.label}'") logger.debug(f"failed to load options for worker '{self.label}'")
@ -720,3 +702,44 @@ class Worker:
logger.error(f"{err_msg}: {response}") logger.error(f"{err_msg}: {response}")
return False return False
def set_state(self, state: State, expect_cycle: bool = False):
"""
Updates the state of a worker if considered a valid operation
Args:
state: the new state to try transitioning to
expect_cycle: whether this transition might be a no-op/self-loop
"""
state_cache = self.state
def transition(ns: State):
if ns == self.state and expect_cycle is False:
logger.debug(f"{self.label}: potentially redundant transition {self.state.name} -> {ns.name}")
return
logger.debug(f"{self.label}: {self.state.name} -> {ns.name}")
self.state = ns
transitions = {
State.IDLE: {State.IDLE, State.WORKING},
State.WORKING: {State.WORKING, State.IDLE, State.INTERRUPTED},
State.UNAVAILABLE: {State.IDLE},
State.INTERRUPTED: {State.WORKING},
}
if state in transitions.get(self.state, {}):
transition(state)
if state == State.UNAVAILABLE:
if self.state == State.DISABLED:
logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
else:
logger.error(f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection")
# invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
self.loaded_model = None
self.loaded_vae = None
transition(state)
if self.state == state_cache and self.state != state:
logger.debug(f"{self.label}: invalid transition {self.state.name} -> {state.name}")

View File

@ -18,8 +18,9 @@ from . import shared as sh
from .pmodels import ConfigModel, Benchmark_Payload from .pmodels import ConfigModel, Benchmark_Payload
from .shared import logger, extension_path from .shared import logger, extension_path
from .worker import Worker, State from .worker import Worker, State
from modules.call_queue import wrap_queued_call from modules.call_queue import wrap_queued_call, queue_lock
from modules import processing from modules import processing
from modules import progress
class NotBenchmarked(Exception): class NotBenchmarked(Exception):
@ -44,6 +45,7 @@ class Job:
self.batch_size: int = batch_size self.batch_size: int = batch_size
self.complementary: bool = False self.complementary: bool = False
self.step_override = None self.step_override = None
self.thread = None
def __str__(self): def __str__(self):
prefix = '' prefix = ''
@ -160,13 +162,7 @@ class World:
return new return new
else: else:
for key in kwargs: for key in kwargs:
if hasattr(original, key): setattr(original, key, kwargs[key])
# TODO only necessary because this is skipping Worker.__init__ and the pyd model is saving the state as an int instead of an actual enum
if key == 'state':
original.state = kwargs[key] if type(kwargs[key]) is State else State(kwargs[key])
continue
setattr(original, key, kwargs[key])
return original return original
@ -186,29 +182,22 @@ class World:
Thread(target=worker.refresh_checkpoints, args=()).start() Thread(target=worker.refresh_checkpoints, args=()).start()
def sample_master(self) -> float: def sample_master(self) -> float:
# wrap our benchmark payload p = StableDiffusionProcessingTxt2Img()
master_bench_payload = StableDiffusionProcessingTxt2Img()
d = sh.benchmark_payload.dict() d = sh.benchmark_payload.dict()
for key in d: for key in d:
setattr(master_bench_payload, key, d[key]) setattr(p, key, d[key])
p.do_not_save_samples = True
# Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to.
master_bench_payload.do_not_save_samples = True
# shared.state.begin(job='distributed_master_bench')
wrapped = (wrap_queued_call(process_images))
start = time.time() start = time.time()
wrapped(master_bench_payload) process_images(p)
# wrap_gradio_gpu_call(process_images)(master_bench_payload)
# shared.state.end()
return time.time() - start return time.time() - start
def benchmark(self, rebenchmark: bool = False): def benchmark(self, rebenchmark: bool = False):
""" """
Attempts to benchmark all workers a part of the world. Attempts to benchmark all workers a part of the world.
""" """
local_task_id = 'task(distributed_bench)'
unbenched_workers = [] unbenched_workers = []
if rebenchmark: if rebenchmark:
for worker in self._workers: for worker in self._workers:
@ -247,26 +236,42 @@ class World:
futures.clear() futures.clear()
# benchmark those that haven't been # benchmark those that haven't been
for worker in unbenched_workers: if len(unbenched_workers) > 0:
if worker.state in (State.DISABLED, State.UNAVAILABLE): queue_lock.acquire()
logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") gradio.Info("Distributed: benchmarking in progress, please wait")
continue for worker in unbenched_workers:
if worker.state in (State.DISABLED, State.UNAVAILABLE):
logger.debug(f"worker '{worker.label}' is {worker.state.name}, refusing to benchmark")
continue
if worker.model_override is not None: if worker.model_override is not None:
logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
f"*all workers should be evaluated against the same model") f"*all workers should be evaluated against the same model")
chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master) if worker.master:
futures.append(executor.submit(chosen, worker)) if progress.current_task is None:
logger.info(f"benchmarking worker '{worker.label}'") progress.add_task_to_queue(local_task_id)
progress.start_task(local_task_id)
shared.state.begin(job=local_task_id)
shared.state.job_count = sh.warmup_samples + sh.samples
# wait for all benchmarks to finish and update stats on newly benchmarked workers chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master)
concurrent.futures.wait(futures) futures.append(executor.submit(chosen, worker))
logger.info("benchmarking finished") logger.info(f"benchmarking worker '{worker.label}'")
# save benchmark results to workers.json if len(futures) > 0:
self.save_config() # wait for all benchmarks to finish and update stats on newly benchmarked workers
logger.info(self.speed_summary()) concurrent.futures.wait(futures)
if progress.current_task == local_task_id:
shared.state.end()
progress.finish_task(local_task_id)
queue_lock.release()
logger.info("benchmarking finished")
logger.info(self.speed_summary())
gradio.Info("Distributed: benchmarking complete!")
self.save_config()
def get_current_output_size(self) -> int: def get_current_output_size(self) -> int:
""" """
@ -369,7 +374,7 @@ class World:
batch_size = self.default_batch_size() batch_size = self.default_batch_size()
for worker in self.get_workers(): for worker in self.get_workers():
if worker.state != State.DISABLED and worker.state != State.UNAVAILABLE: if worker.state not in (State.DISABLED, State.UNAVAILABLE):
if worker.avg_ipm is None or worker.avg_ipm <= 0: if worker.avg_ipm is None or worker.avg_ipm <= 0:
logger.debug(f"No recorded speed for worker '{worker.label}, benchmarking'") logger.debug(f"No recorded speed for worker '{worker.label}, benchmarking'")
worker.benchmark() worker.benchmark()
@ -379,16 +384,16 @@ class World:
def update(self, p): def update(self, p):
"""preps world for another run""" """preps world for another run"""
if not self.initialized:
self.benchmark()
self.initialized = True
logger.debug("world initialized!")
else:
logger.debug("world was already initialized")
self.p = p self.p = p
self.benchmark()
self.make_jobs() self.make_jobs()
if not self.initialized:
self.initialized = True
logger.debug("world initialized!")
def get_workers(self): def get_workers(self):
filtered: List[Worker] = [] filtered: List[Worker] = []
for worker in self._workers: for worker in self._workers:
@ -397,7 +402,7 @@ class World:
continue continue
if worker.master and self.thin_client_mode: if worker.master and self.thin_client_mode:
continue continue
if worker.state != State.UNAVAILABLE and worker.state != State.DISABLED: if worker.state not in (State.UNAVAILABLE, State.DISABLED):
filtered.append(worker) filtered.append(worker)
return filtered return filtered
@ -535,7 +540,7 @@ class World:
seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1) seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1)
realtime_samples = slack_time // seconds_per_sample realtime_samples = slack_time // seconds_per_sample
logger.debug( logger.debug(
f"job for '{job.worker.label}' downscaled to {realtime_samples} samples to meet time constraints\n" f"job for '{job.worker.label}' downscaled to {realtime_samples:.0f} samples to meet time constraints\n"
f"{realtime_samples:.0f} samples = {slack_time:.2f}s slack ÷ {seconds_per_sample:.2f}s/sample\n" f"{realtime_samples:.0f} samples = {slack_time:.2f}s slack ÷ {seconds_per_sample:.2f}s/sample\n"
f" step reduction: {payload['steps']} -> {realtime_samples:.0f}" f" step reduction: {payload['steps']} -> {realtime_samples:.0f}"
) )
@ -546,17 +551,7 @@ class World:
else: else:
logger.debug("complementary image production is disabled") logger.debug("complementary image production is disabled")
iterations = payload['n_iter'] logger.info(self.distro_summary(payload))
num_returning = self.get_current_output_size()
num_complementary = num_returning - self.p.batch_size
distro_summary = "Job distribution:\n"
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
if num_complementary > 0:
distro_summary += f" + {num_complementary} complementary"
distro_summary += f": {num_returning} images total\n"
for job in self.jobs:
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
logger.info(distro_summary)
if self.thin_client_mode is True or self.master_job().batch_size == 0: if self.thin_client_mode is True or self.master_job().batch_size == 0:
# save original process_images_inner for later so we can restore once we're done # save original process_images_inner for later so we can restore once we're done
@ -581,6 +576,20 @@ class World:
del self.jobs[last] del self.jobs[last]
last -= 1 last -= 1
def distro_summary(self, payload):
# iterations = dict(payload)['n_iter']
iterations = self.p.n_iter
num_returning = self.get_current_output_size()
num_complementary = num_returning - self.p.batch_size
distro_summary = "Job distribution:\n"
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
if num_complementary > 0:
distro_summary += f" + {num_complementary} complementary"
distro_summary += f": {num_returning} images total\n"
for job in self.jobs:
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
return distro_summary
def config(self) -> dict: def config(self) -> dict:
""" """
{ {
@ -654,6 +663,10 @@ class World:
fields['label'] = label fields['label'] = label
# TODO must be overridden everytime here or later converted to a config file variable at some point # TODO must be overridden everytime here or later converted to a config file variable at some point
fields['verify_remotes'] = self.verify_remotes fields['verify_remotes'] = self.verify_remotes
# cast enum id to actual enum type and then prime state
fields['state'] = State(fields['state'])
if fields['state'] not in (State.DISABLED, State.UNAVAILABLE):
fields['state'] = State.IDLE
self.add_worker(**fields) self.add_worker(**fields)
@ -663,11 +676,11 @@ class World:
self.complement_production = config.complement_production self.complement_production = config.complement_production
self.step_scaling = config.step_scaling self.step_scaling = config.step_scaling
logger.debug("config loaded") logger.debug(f"config loaded from '{os.path.abspath(self.config_path)}'")
def save_config(self): def save_config(self):
""" """
Saves the config file. Saves current state to the config file.
""" """
config = ConfigModel( config = ConfigModel(
@ -727,11 +740,14 @@ class World:
msg = f"worker '{worker.label}' is online" msg = f"worker '{worker.label}' is online"
logger.info(msg) logger.info(msg)
gradio.Info("Distributed: "+msg) gradio.Info("Distributed: "+msg)
worker.state = State.IDLE worker.set_state(State.IDLE, expect_cycle=True)
else: else:
msg = f"worker '{worker.label}' is unreachable" msg = f"worker '{worker.label}' is unreachable"
logger.info(msg) logger.info(msg)
gradio.Warning("Distributed: "+msg) gradio.Warning("Distributed: "+msg)
worker.set_state(State.UNAVAILABLE)
self.save_config()
def restart_all(self): def restart_all(self):
for worker in self._workers: for worker in self._workers: