refactoring, state fix

master
papuSpartan 2024-08-30 11:51:46 -05:00
parent 9cd7c7c351
commit bff6d16e42
3 changed files with 40 additions and 32 deletions

View File

@ -24,7 +24,7 @@ from modules.shared import state as webui_state
from scripts.spartan.control_net import pack_control_net
from scripts.spartan.shared import logger
from scripts.spartan.ui import UI
from scripts.spartan.world import World
from scripts.spartan.world import World, State
old_sigint_handler = signal.getsignal(signal.SIGINT)
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
@ -33,7 +33,6 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM)
# noinspection PyMissingOrEmptyDocstring
class DistributedScript(scripts.Script):
# global old_sigterm_handler, old_sigterm_handler
worker_threads: List[Thread] = []
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
master_start = None
@ -150,17 +149,19 @@ class DistributedScript(scripts.Script):
# wait for response from all workers
webui_state.textinfo = "Distributed - receiving results"
for thread in self.worker_threads:
logger.debug(f"waiting for worker thread '{thread.name}'")
thread.join()
self.worker_threads.clear()
for job in self.world.jobs:
if job.thread is None:
continue
logger.debug(f"waiting for worker thread '{job.thread.name}'")
job.thread.join()
logger.debug("all worker request threads returned")
webui_state.textinfo = "Distributed - injecting images"
# some worker which we know has a good response that we can use for generating the grid
donor_worker = None
for job in self.world.jobs:
if job.batch_size < 1 or job.worker.master:
if job.worker.response is None or job.batch_size < 1 or job.worker.master:
continue
try:
@ -304,6 +305,9 @@ class DistributedScript(scripts.Script):
return
for job in self.world.jobs:
if job.worker.state in (State.UNAVAILABLE, State.DISABLED):
continue
payload_temp = copy.copy(payload)
del payload_temp['scripts_value']
payload_temp = copy.deepcopy(payload_temp)
@ -332,11 +336,9 @@ class DistributedScript(scripts.Script):
job.worker.loaded_model = name
job.worker.loaded_vae = vae
t = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
job.thread = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
name=f"{job.worker.label}_request")
t.start()
self.worker_threads.append(t)
job.thread.start()
started_jobs.append(job)
# if master batch size was changed again due to optimization change it to the updated value
@ -358,6 +360,8 @@ class DistributedScript(scripts.Script):
# restore process_images_inner if it was monkey-patched
processing.process_images_inner = self.original_process_images_inner
# save any dangling state to prevent load_config in next iteration overwriting it
self.world.save_config()
@staticmethod
def signal_handler(sig, frame):

View File

@ -344,7 +344,7 @@ class Worker:
# remove anything that is not serializable
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
s_tmax = payload.get('s_tmax', 0.0)
if s_tmax > 1e308:
if s_tmax is not None and s_tmax > 1e308:
payload['s_tmax'] = 1e308
# remove unserializable caches
payload.pop('cached_uc', None)
@ -490,11 +490,7 @@ class Worker:
except Exception as e:
self.set_state(State.IDLE)
if payload['batch_size'] == 0:
raise InvalidWorkerResponse("Tried to request a null amount of images")
else:
raise InvalidWorkerResponse(e)
raise InvalidWorkerResponse(e)
except requests.RequestException:
self.set_state(State.UNAVAILABLE)

View File

@ -45,6 +45,7 @@ class Job:
self.batch_size: int = batch_size
self.complementary: bool = False
self.step_override = None
self.thread = None
def __str__(self):
prefix = ''
@ -373,7 +374,7 @@ class World:
batch_size = self.default_batch_size()
for worker in self.get_workers():
if worker.state != State.DISABLED and worker.state != State.UNAVAILABLE:
if worker.state not in (State.DISABLED, State.UNAVAILABLE):
if worker.avg_ipm is None or worker.avg_ipm <= 0:
logger.debug(f"No recorded speed for worker '{worker.label}, benchmarking'")
worker.benchmark()
@ -401,7 +402,7 @@ class World:
continue
if worker.master and self.thin_client_mode:
continue
if worker.state != State.UNAVAILABLE and worker.state != State.DISABLED:
if worker.state not in (State.UNAVAILABLE, State.DISABLED):
filtered.append(worker)
return filtered
@ -550,17 +551,7 @@ class World:
else:
logger.debug("complementary image production is disabled")
iterations = payload['n_iter']
num_returning = self.get_current_output_size()
num_complementary = num_returning - self.p.batch_size
distro_summary = "Job distribution:\n"
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
if num_complementary > 0:
distro_summary += f" + {num_complementary} complementary"
distro_summary += f": {num_returning} images total\n"
for job in self.jobs:
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
logger.info(distro_summary)
logger.info(self.distro_summary(payload))
if self.thin_client_mode is True or self.master_job().batch_size == 0:
# save original process_images_inner for later so we can restore once we're done
@ -585,6 +576,20 @@ class World:
del self.jobs[last]
last -= 1
def distro_summary(self, payload):
# iterations = dict(payload)['n_iter']
iterations = self.p.n_iter
num_returning = self.get_current_output_size()
num_complementary = num_returning - self.p.batch_size
distro_summary = "Job distribution:\n"
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
if num_complementary > 0:
distro_summary += f" + {num_complementary} complementary"
distro_summary += f": {num_returning} images total\n"
for job in self.jobs:
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
return distro_summary
def config(self) -> dict:
"""
{
@ -660,8 +665,8 @@ class World:
fields['verify_remotes'] = self.verify_remotes
# cast enum id to actual enum type and then prime state
fields['state'] = State(fields['state'])
if fields['state'] != State.DISABLED:
fields['state'] = State.IDLE
if fields['state'] not in (State.DISABLED, State.UNAVAILABLE):
fields['state'] = State.IDLE
self.add_worker(**fields)
@ -740,6 +745,9 @@ class World:
msg = f"worker '{worker.label}' is unreachable"
logger.info(msg)
gradio.Warning("Distributed: "+msg)
worker.set_state(State.UNAVAILABLE)
self.save_config()
def restart_all(self):
for worker in self._workers: