merge dev making 2.2.2

master v2.2.2
papuSpartan 2024-08-30 12:11:38 -05:00
commit 65b0eb75e0
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
5 changed files with 157 additions and 102 deletions

View File

@ -1,6 +1,12 @@
# Change Log
Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
## [2.2.2] - 2024-08-30
### Fixed
- Unavailable state sometimes being ignored
## [2.2.1] - 2024-05-16
### Fixed

View File

@ -24,7 +24,7 @@ from modules.shared import state as webui_state
from scripts.spartan.control_net import pack_control_net
from scripts.spartan.shared import logger
from scripts.spartan.ui import UI
from scripts.spartan.world import World
from scripts.spartan.world import World, State
old_sigint_handler = signal.getsignal(signal.SIGINT)
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
@ -33,7 +33,6 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM)
# noinspection PyMissingOrEmptyDocstring
class DistributedScript(scripts.Script):
# global old_sigterm_handler, old_sigterm_handler
worker_threads: List[Thread] = []
# Whether to verify worker certificates. Skipping verification can be useful if your remotes use self-signed certificates.
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
master_start = None
@ -150,17 +149,19 @@ class DistributedScript(scripts.Script):
# wait for response from all workers
webui_state.textinfo = "Distributed - receiving results"
for thread in self.worker_threads:
logger.debug(f"waiting for worker thread '{thread.name}'")
thread.join()
self.worker_threads.clear()
for job in self.world.jobs:
if job.thread is None:
continue
logger.debug(f"waiting for worker thread '{job.thread.name}'")
job.thread.join()
logger.debug("all worker request threads returned")
webui_state.textinfo = "Distributed - injecting images"
# some worker which we know has a good response that we can use for generating the grid
donor_worker = None
for job in self.world.jobs:
if job.batch_size < 1 or job.worker.master:
if job.worker.response is None or job.batch_size < 1 or job.worker.master:
continue
try:
@ -197,7 +198,7 @@ class DistributedScript(scripts.Script):
return
# generate and inject grid
if opts.return_grid:
if opts.return_grid and len(processed.images) > 1:
grid = image_grid(processed.images, len(processed.images))
processed_inject_image(
image=grid,
@ -304,6 +305,9 @@ class DistributedScript(scripts.Script):
return
for job in self.world.jobs:
if job.worker.state in (State.UNAVAILABLE, State.DISABLED):
continue
payload_temp = copy.copy(payload)
del payload_temp['scripts_value']
payload_temp = copy.deepcopy(payload_temp)
@ -332,11 +336,9 @@ class DistributedScript(scripts.Script):
job.worker.loaded_model = name
job.worker.loaded_vae = vae
t = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
job.thread = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
name=f"{job.worker.label}_request")
t.start()
self.worker_threads.append(t)
job.thread.start()
started_jobs.append(job)
# if master batch size was changed again due to optimization change it to the updated value
@ -358,6 +360,8 @@ class DistributedScript(scripts.Script):
# restore process_images_inner if it was monkey-patched
processing.process_images_inner = self.original_process_images_inner
# save any dangling state to prevent load_config in next iteration overwriting it
self.world.save_config()
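The restore above assumes process_images_inner was swapped out earlier in the run (the patching side isn't shown in this hunk). A minimal sketch of the monkey-patch/restore pattern, with a hypothetical replacement function standing in for the extension's actual patch:

    # before the run (not shown here): keep a handle to the original, then swap in the patch
    self.original_process_images_inner = processing.process_images_inner
    processing.process_images_inner = distributed_process_images_inner  # hypothetical replacement

    # once the run finishes: put the original back so later, non-distributed jobs behave normally
    processing.process_images_inner = self.original_process_images_inner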
@staticmethod
def signal_handler(sig, frame):

View File

@ -7,6 +7,8 @@ from modules.shared import opts
from modules.shared import state as webui_state
from .shared import logger, LOG_LEVEL, gui_handler
from .worker import State
from modules.call_queue import queue_lock
from modules import progress
worker_select_dropdown = None
@ -61,6 +63,10 @@ class UI:
"""debug utility that will clear the internal webui queue. sometimes good for jams"""
logger.debug(webui_state.__dict__)
webui_state.end()
progress.pending_tasks.clear()
progress.current_task = None
if queue_lock._lock.locked():
queue_lock.release()
def status_btn(self):
"""updates a simplified overview of registered workers and their jobs"""

View File

@ -306,7 +306,7 @@ class Worker:
if waited >= (0.85 * max_wait):
logger.warning("this seems long, so if you see this message often, consider reporting an issue")
self.state = State.WORKING
self.set_state(State.WORKING)
# query memory available on worker and store for future reference
if self.queried is False:
@ -344,7 +344,7 @@ class Worker:
# remove anything that is not serializable
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
s_tmax = payload.get('s_tmax', 0.0)
if s_tmax > 1e308:
if s_tmax is not None and s_tmax > 1e308:
payload['s_tmax'] = 1e308
# remove unserializable caches
payload.pop('cached_uc', None)
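For context on the s_tmax clamp above: Python's default json encoder emits float('inf') as the bare token Infinity, which is not valid JSON and is typically rejected by the receiving API. A minimal standalone sketch of the guard (hypothetical payload, not the extension's exact one):

    import json

    payload = {'s_tmax': float('inf')}

    s_tmax = payload.get('s_tmax', 0.0)
    if s_tmax is not None and s_tmax > 1e308:
        # clamp to the largest finite double so the payload serializes cleanly
        payload['s_tmax'] = 1e308

    json.dumps(payload)  # '{"s_tmax": 1e+308}' rather than the invalid token Infinity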
@ -423,7 +423,6 @@ class Worker:
sampler_name = payload.get('sampler_name', None)
if sampler_index is None:
if sampler_name is not None:
logger.debug("had to substitute sampler index with name")
payload['sampler_index'] = sampler_name
try:
@ -490,18 +489,14 @@ class Worker:
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
except Exception as e:
self.state = State.IDLE
if payload['batch_size'] == 0:
raise InvalidWorkerResponse("Tried to request a null amount of images")
else:
self.set_state(State.IDLE)
raise InvalidWorkerResponse(e)
except requests.RequestException:
self.mark_unreachable()
self.set_state(State.UNAVAILABLE)
return
self.state = State.IDLE
self.set_state(State.IDLE)
self.jobs_requested += 1
return
@ -539,7 +534,6 @@ class Worker:
# this was due to something torch does at startup according to auto and is now done at sdwui startup
for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up"
if self.state == State.UNAVAILABLE:
self.response = None
return 0
try: # if the worker is unreachable/offline then handle that here
@ -574,7 +568,7 @@ class Worker:
self.response = None
self.benchmarked = True
self.eta_percent_error = [] # likely inaccurate after rebenching
self.state = State.IDLE
self.set_state(State.IDLE)
return avg_ipm_result
def refresh_checkpoints(self):
@ -593,17 +587,17 @@ class Worker:
logger.error(msg)
# gradio.Warning("Distributed: "+msg)
except requests.exceptions.ConnectionError:
self.mark_unreachable()
self.set_state(State.UNAVAILABLE)
def interrupt(self):
try:
response = self.session.post(self.full_url('interrupt'))
if response.status_code == 200:
self.state = State.INTERRUPTED
self.set_state(State.INTERRUPTED)
logger.debug(f"successfully interrupted worker {self.label}")
except requests.exceptions.ConnectionError:
self.mark_unreachable()
self.set_state(State.UNAVAILABLE)
def reachable(self) -> bool:
"""returns false if worker is unreachable"""
@ -622,18 +616,6 @@ class Worker:
logger.error(e)
return False
def mark_unreachable(self):
if self.state == State.DISABLED:
logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
else:
msg = f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection"
logger.error(msg)
# gradio.Warning("Distributed: "+msg)
self.state = State.UNAVAILABLE
# invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
self.loaded_model = None
self.loaded_vae = None
def available_models(self) -> [List[str]]:
if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master:
return []
@ -654,7 +636,7 @@ class Worker:
titles = [model['title'] for model in response.json()]
return titles
except requests.RequestException:
self.mark_unreachable()
self.set_state(State.UNAVAILABLE)
return []
def load_options(self, model, vae=None):
@ -671,14 +653,14 @@ class Worker:
if vae is not None:
payload['sd_vae'] = vae
self.state = State.WORKING
self.set_state(State.WORKING)
start = time.time()
response = self.session.post(
self.full_url("options"),
json=payload
)
elapsed = time.time() - start
self.state = State.IDLE
self.set_state(State.IDLE)
if response.status_code != 200:
logger.debug(f"failed to load options for worker '{self.label}'")
@ -720,3 +702,44 @@ class Worker:
logger.error(f"{err_msg}: {response}")
return False
def set_state(self, state: State, expect_cycle: bool = False):
"""
Updates the state of a worker if the requested transition is considered valid
Args:
state: the new state to try transitioning to
expect_cycle: whether this transition might be a no-op/self-loop
"""
state_cache = self.state
def transition(ns: State):
if ns == self.state and expect_cycle is False:
logger.debug(f"{self.label}: potentially redundant transition {self.state.name} -> {ns.name}")
return
logger.debug(f"{self.label}: {self.state.name} -> {ns.name}")
self.state = ns
transitions = {
State.IDLE: {State.IDLE, State.WORKING},
State.WORKING: {State.WORKING, State.IDLE, State.INTERRUPTED},
State.UNAVAILABLE: {State.IDLE},
State.INTERRUPTED: {State.WORKING},
}
if state in transitions.get(self.state, {}):
transition(state)
if state == State.UNAVAILABLE:
if self.state == State.DISABLED:
logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
else:
logger.error(f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection")
# invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
self.loaded_model = None
self.loaded_vae = None
transition(state)
if self.state == state_cache and self.state != state:
logger.debug(f"{self.label}: invalid transition {self.state.name} -> {state.name}")

View File

@ -18,8 +18,9 @@ from . import shared as sh
from .pmodels import ConfigModel, Benchmark_Payload
from .shared import logger, extension_path
from .worker import Worker, State
from modules.call_queue import wrap_queued_call
from modules.call_queue import wrap_queued_call, queue_lock
from modules import processing
from modules import progress
class NotBenchmarked(Exception):
@ -44,6 +45,7 @@ class Job:
self.batch_size: int = batch_size
self.complementary: bool = False
self.step_override = None
self.thread = None
def __str__(self):
prefix = ''
@ -160,12 +162,6 @@ class World:
return new
else:
for key in kwargs:
if hasattr(original, key):
# TODO only necessary because this is skipping Worker.__init__ and the pyd model is saving the state as an int instead of an actual enum
if key == 'state':
original.state = kwargs[key] if type(kwargs[key]) is State else State(kwargs[key])
continue
setattr(original, key, kwargs[key])
return original
@ -186,29 +182,22 @@ class World:
Thread(target=worker.refresh_checkpoints, args=()).start()
def sample_master(self) -> float:
# wrap our benchmark payload
master_bench_payload = StableDiffusionProcessingTxt2Img()
p = StableDiffusionProcessingTxt2Img()
d = sh.benchmark_payload.dict()
for key in d:
setattr(master_bench_payload, key, d[key])
setattr(p, key, d[key])
p.do_not_save_samples = True
# Prevents saving the images when we don't know the path; there's no real reason to save them anyway.
master_bench_payload.do_not_save_samples = True
# shared.state.begin(job='distributed_master_bench')
wrapped = (wrap_queued_call(process_images))
start = time.time()
wrapped(master_bench_payload)
# wrap_gradio_gpu_call(process_images)(master_bench_payload)
# shared.state.end()
process_images(p)
return time.time() - start
def benchmark(self, rebenchmark: bool = False):
"""
Attempts to benchmark all workers that are part of the world.
"""
local_task_id = 'task(distributed_bench)'
unbenched_workers = []
if rebenchmark:
for worker in self._workers:
@ -247,26 +236,42 @@ class World:
futures.clear()
# benchmark those that haven't been
if len(unbenched_workers) > 0:
queue_lock.acquire()
gradio.Info("Distributed: benchmarking in progress, please wait")
for worker in unbenched_workers:
if worker.state in (State.DISABLED, State.UNAVAILABLE):
logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark")
logger.debug(f"worker '{worker.label}' is {worker.state.name}, refusing to benchmark")
continue
if worker.model_override is not None:
logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n"
f"*all workers should be evaluated against the same model")
if worker.master:
if progress.current_task is None:
progress.add_task_to_queue(local_task_id)
progress.start_task(local_task_id)
shared.state.begin(job=local_task_id)
shared.state.job_count = sh.warmup_samples + sh.samples
chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master)
futures.append(executor.submit(chosen, worker))
logger.info(f"benchmarking worker '{worker.label}'")
if len(futures) > 0:
# wait for all benchmarks to finish and update stats on newly benchmarked workers
concurrent.futures.wait(futures)
logger.info("benchmarking finished")
# save benchmark results to workers.json
self.save_config()
if progress.current_task == local_task_id:
shared.state.end()
progress.finish_task(local_task_id)
queue_lock.release()
logger.info("benchmarking finished")
logger.info(self.speed_summary())
gradio.Info("Distributed: benchmarking complete!")
self.save_config()
def get_current_output_size(self) -> int:
"""
@ -369,7 +374,7 @@ class World:
batch_size = self.default_batch_size()
for worker in self.get_workers():
if worker.state != State.DISABLED and worker.state != State.UNAVAILABLE:
if worker.state not in (State.DISABLED, State.UNAVAILABLE):
if worker.avg_ipm is None or worker.avg_ipm <= 0:
logger.debug(f"No recorded speed for worker '{worker.label}, benchmarking'")
worker.benchmark()
@ -379,16 +384,16 @@ class World:
def update(self, p):
"""preps world for another run"""
if not self.initialized:
self.benchmark()
self.initialized = True
logger.debug("world initialized!")
else:
logger.debug("world was already initialized")
self.p = p
self.benchmark()
self.make_jobs()
if not self.initialized:
self.initialized = True
logger.debug("world initialized!")
def get_workers(self):
filtered: List[Worker] = []
for worker in self._workers:
@ -397,7 +402,7 @@ class World:
continue
if worker.master and self.thin_client_mode:
continue
if worker.state != State.UNAVAILABLE and worker.state != State.DISABLED:
if worker.state not in (State.UNAVAILABLE, State.DISABLED):
filtered.append(worker)
return filtered
@ -535,7 +540,7 @@ class World:
seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1)
realtime_samples = slack_time // seconds_per_sample
logger.debug(
f"job for '{job.worker.label}' downscaled to {realtime_samples} samples to meet time constraints\n"
f"job for '{job.worker.label}' downscaled to {realtime_samples:.0f} samples to meet time constraints\n"
f"{realtime_samples:.0f} samples = {slack_time:.2f}s slack ÷ {seconds_per_sample:.2f}s/sample\n"
f" step reduction: {payload['steps']} -> {realtime_samples:.0f}"
)
@ -546,17 +551,7 @@ class World:
else:
logger.debug("complementary image production is disabled")
iterations = payload['n_iter']
num_returning = self.get_current_output_size()
num_complementary = num_returning - self.p.batch_size
distro_summary = "Job distribution:\n"
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
if num_complementary > 0:
distro_summary += f" + {num_complementary} complementary"
distro_summary += f": {num_returning} images total\n"
for job in self.jobs:
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
logger.info(distro_summary)
logger.info(self.distro_summary(payload))
if self.thin_client_mode is True or self.master_job().batch_size == 0:
# save original process_images_inner for later so we can restore once we're done
@ -581,6 +576,20 @@ class World:
del self.jobs[last]
last -= 1
def distro_summary(self, payload):
# iterations = dict(payload)['n_iter']
iterations = self.p.n_iter
num_returning = self.get_current_output_size()
num_complementary = num_returning - self.p.batch_size
distro_summary = "Job distribution:\n"
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
if num_complementary > 0:
distro_summary += f" + {num_complementary} complementary"
distro_summary += f": {num_returning} images total\n"
for job in self.jobs:
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
return distro_summary
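For illustration, with a requested batch size of 4, a single iteration, and 2 complementary images split across two workers, the summary returned by distro_summary would read roughly as follows (worker labels and ipm figures are hypothetical):

    Job distribution:
    4 * 1 iteration(s) + 2 complementary: 6 images total
    'master' - 2 image(s) @ 14.50 ipm
    'laptop' - 4 image(s) @ 7.25 ipm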
def config(self) -> dict:
"""
{
@ -654,6 +663,10 @@ class World:
fields['label'] = label
# TODO must be overridden every time here or later converted to a config file variable at some point
fields['verify_remotes'] = self.verify_remotes
# cast enum id to actual enum type and then prime state
fields['state'] = State(fields['state'])
if fields['state'] not in (State.DISABLED, State.UNAVAILABLE):
fields['state'] = State.IDLE
self.add_worker(**fields)
@ -663,11 +676,11 @@ class World:
self.complement_production = config.complement_production
self.step_scaling = config.step_scaling
logger.debug("config loaded")
logger.debug(f"config loaded from '{os.path.abspath(self.config_path)}'")
def save_config(self):
"""
Saves the config file.
Saves current state to the config file.
"""
config = ConfigModel(
@ -727,11 +740,14 @@ class World:
msg = f"worker '{worker.label}' is online"
logger.info(msg)
gradio.Info("Distributed: "+msg)
worker.state = State.IDLE
worker.set_state(State.IDLE, expect_cycle=True)
else:
msg = f"worker '{worker.label}' is unreachable"
logger.info(msg)
gradio.Warning("Distributed: "+msg)
worker.set_state(State.UNAVAILABLE)
self.save_config()
def restart_all(self):
for worker in self._workers: