refactoring, state fix
parent
9cd7c7c351
commit
bff6d16e42
|
|
@ -24,7 +24,7 @@ from modules.shared import state as webui_state
|
|||
from scripts.spartan.control_net import pack_control_net
|
||||
from scripts.spartan.shared import logger
|
||||
from scripts.spartan.ui import UI
|
||||
from scripts.spartan.world import World
|
||||
from scripts.spartan.world import World, State
|
||||
|
||||
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
||||
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||
|
|
@ -33,7 +33,6 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
|||
# noinspection PyMissingOrEmptyDocstring
|
||||
class DistributedScript(scripts.Script):
|
||||
# global old_sigterm_handler, old_sigterm_handler
|
||||
worker_threads: List[Thread] = []
|
||||
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
||||
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
|
||||
master_start = None
|
||||
|
|
@ -150,17 +149,19 @@ class DistributedScript(scripts.Script):
|
|||
|
||||
# wait for response from all workers
|
||||
webui_state.textinfo = "Distributed - receiving results"
|
||||
for thread in self.worker_threads:
|
||||
logger.debug(f"waiting for worker thread '{thread.name}'")
|
||||
thread.join()
|
||||
self.worker_threads.clear()
|
||||
for job in self.world.jobs:
|
||||
if job.thread is None:
|
||||
continue
|
||||
|
||||
logger.debug(f"waiting for worker thread '{job.thread.name}'")
|
||||
job.thread.join()
|
||||
logger.debug("all worker request threads returned")
|
||||
webui_state.textinfo = "Distributed - injecting images"
|
||||
|
||||
# some worker which we know has a good response that we can use for generating the grid
|
||||
donor_worker = None
|
||||
for job in self.world.jobs:
|
||||
if job.batch_size < 1 or job.worker.master:
|
||||
if job.worker.response is None or job.batch_size < 1 or job.worker.master:
|
||||
continue
|
||||
|
||||
try:
|
||||
|
|
@ -304,6 +305,9 @@ class DistributedScript(scripts.Script):
|
|||
return
|
||||
|
||||
for job in self.world.jobs:
|
||||
if job.worker.state in (State.UNAVAILABLE, State.DISABLED):
|
||||
continue
|
||||
|
||||
payload_temp = copy.copy(payload)
|
||||
del payload_temp['scripts_value']
|
||||
payload_temp = copy.deepcopy(payload_temp)
|
||||
|
|
@ -332,11 +336,9 @@ class DistributedScript(scripts.Script):
|
|||
job.worker.loaded_model = name
|
||||
job.worker.loaded_vae = vae
|
||||
|
||||
t = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
|
||||
job.thread = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
|
||||
name=f"{job.worker.label}_request")
|
||||
|
||||
t.start()
|
||||
self.worker_threads.append(t)
|
||||
job.thread.start()
|
||||
started_jobs.append(job)
|
||||
|
||||
# if master batch size was changed again due to optimization change it to the updated value
|
||||
|
|
@ -358,6 +360,8 @@ class DistributedScript(scripts.Script):
|
|||
|
||||
# restore process_images_inner if it was monkey-patched
|
||||
processing.process_images_inner = self.original_process_images_inner
|
||||
# save any dangling state to prevent load_config in next iteration overwriting it
|
||||
self.world.save_config()
|
||||
|
||||
@staticmethod
|
||||
def signal_handler(sig, frame):
|
||||
|
|
|
|||
|
|
@ -344,7 +344,7 @@ class Worker:
|
|||
# remove anything that is not serializable
|
||||
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
|
||||
s_tmax = payload.get('s_tmax', 0.0)
|
||||
if s_tmax > 1e308:
|
||||
if s_tmax is not None and s_tmax > 1e308:
|
||||
payload['s_tmax'] = 1e308
|
||||
# remove unserializable caches
|
||||
payload.pop('cached_uc', None)
|
||||
|
|
@ -490,11 +490,7 @@ class Worker:
|
|||
|
||||
except Exception as e:
|
||||
self.set_state(State.IDLE)
|
||||
|
||||
if payload['batch_size'] == 0:
|
||||
raise InvalidWorkerResponse("Tried to request a null amount of images")
|
||||
else:
|
||||
raise InvalidWorkerResponse(e)
|
||||
raise InvalidWorkerResponse(e)
|
||||
|
||||
except requests.RequestException:
|
||||
self.set_state(State.UNAVAILABLE)
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ class Job:
|
|||
self.batch_size: int = batch_size
|
||||
self.complementary: bool = False
|
||||
self.step_override = None
|
||||
self.thread = None
|
||||
|
||||
def __str__(self):
|
||||
prefix = ''
|
||||
|
|
@ -373,7 +374,7 @@ class World:
|
|||
|
||||
batch_size = self.default_batch_size()
|
||||
for worker in self.get_workers():
|
||||
if worker.state != State.DISABLED and worker.state != State.UNAVAILABLE:
|
||||
if worker.state not in (State.DISABLED, State.UNAVAILABLE):
|
||||
if worker.avg_ipm is None or worker.avg_ipm <= 0:
|
||||
logger.debug(f"No recorded speed for worker '{worker.label}, benchmarking'")
|
||||
worker.benchmark()
|
||||
|
|
@ -401,7 +402,7 @@ class World:
|
|||
continue
|
||||
if worker.master and self.thin_client_mode:
|
||||
continue
|
||||
if worker.state != State.UNAVAILABLE and worker.state != State.DISABLED:
|
||||
if worker.state not in (State.UNAVAILABLE, State.DISABLED):
|
||||
filtered.append(worker)
|
||||
|
||||
return filtered
|
||||
|
|
@ -550,17 +551,7 @@ class World:
|
|||
else:
|
||||
logger.debug("complementary image production is disabled")
|
||||
|
||||
iterations = payload['n_iter']
|
||||
num_returning = self.get_current_output_size()
|
||||
num_complementary = num_returning - self.p.batch_size
|
||||
distro_summary = "Job distribution:\n"
|
||||
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
|
||||
if num_complementary > 0:
|
||||
distro_summary += f" + {num_complementary} complementary"
|
||||
distro_summary += f": {num_returning} images total\n"
|
||||
for job in self.jobs:
|
||||
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
|
||||
logger.info(distro_summary)
|
||||
logger.info(self.distro_summary(payload))
|
||||
|
||||
if self.thin_client_mode is True or self.master_job().batch_size == 0:
|
||||
# save original process_images_inner for later so we can restore once we're done
|
||||
|
|
@ -585,6 +576,20 @@ class World:
|
|||
del self.jobs[last]
|
||||
last -= 1
|
||||
|
||||
def distro_summary(self, payload):
    """
    Build a human-readable, multi-line summary of how the current request is
    split across the jobs in ``self.jobs``.

    Args:
        payload: The request payload. Currently unused: the iteration count is
            read from ``self.p.n_iter``. Kept so existing callers keep working.

    Returns:
        str: A newline-terminated summary. It starts with a "Job distribution:"
        header, followed by the totals line and one line per job.
    """
    iterations = self.p.n_iter
    num_returning = self.get_current_output_size()
    # images returned beyond self.p.batch_size are reported as complementary
    num_complementary = num_returning - self.p.batch_size

    totals = f"{self.p.batch_size} * {iterations} iteration(s)"
    if num_complementary > 0:
        totals += f" + {num_complementary} complementary"
    totals += f": {num_returning} images total"

    lines = ["Job distribution:", totals]
    for job in self.jobs:
        avg_ipm = job.worker.avg_ipm
        # avg_ipm may be None when no speed has been recorded for the worker;
        # formatting None with ':.2f' would raise TypeError and abort the log call
        speed = "unknown" if avg_ipm is None else f"{avg_ipm:.2f}"
        lines.append(
            f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {speed} ipm"
        )

    # every line, including the last, is newline-terminated
    return "\n".join(lines) + "\n"
|
||||
|
||||
def config(self) -> dict:
|
||||
"""
|
||||
{
|
||||
|
|
@ -660,8 +665,8 @@ class World:
|
|||
fields['verify_remotes'] = self.verify_remotes
|
||||
# cast enum id to actual enum type and then prime state
|
||||
fields['state'] = State(fields['state'])
|
||||
if fields['state'] != State.DISABLED:
|
||||
fields['state'] = State.IDLE
|
||||
if fields['state'] not in (State.DISABLED, State.UNAVAILABLE):
|
||||
fields['state'] = State.IDLE
|
||||
|
||||
self.add_worker(**fields)
|
||||
|
||||
|
|
@ -740,6 +745,9 @@ class World:
|
|||
msg = f"worker '{worker.label}' is unreachable"
|
||||
logger.info(msg)
|
||||
gradio.Warning("Distributed: "+msg)
|
||||
worker.set_state(State.UNAVAILABLE)
|
||||
|
||||
self.save_config()
|
||||
|
||||
def restart_all(self):
|
||||
for worker in self._workers:
|
||||
|
|
|
|||
Loading…
Reference in New Issue