refactoring, state fix
parent
9cd7c7c351
commit
bff6d16e42
|
|
@ -24,7 +24,7 @@ from modules.shared import state as webui_state
|
||||||
from scripts.spartan.control_net import pack_control_net
|
from scripts.spartan.control_net import pack_control_net
|
||||||
from scripts.spartan.shared import logger
|
from scripts.spartan.shared import logger
|
||||||
from scripts.spartan.ui import UI
|
from scripts.spartan.ui import UI
|
||||||
from scripts.spartan.world import World
|
from scripts.spartan.world import World, State
|
||||||
|
|
||||||
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
||||||
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||||
|
|
@ -33,7 +33,6 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||||
# noinspection PyMissingOrEmptyDocstring
|
# noinspection PyMissingOrEmptyDocstring
|
||||||
class DistributedScript(scripts.Script):
|
class DistributedScript(scripts.Script):
|
||||||
# global old_sigterm_handler, old_sigterm_handler
|
# global old_sigterm_handler, old_sigterm_handler
|
||||||
worker_threads: List[Thread] = []
|
|
||||||
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
||||||
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
|
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
|
||||||
master_start = None
|
master_start = None
|
||||||
|
|
@ -150,17 +149,19 @@ class DistributedScript(scripts.Script):
|
||||||
|
|
||||||
# wait for response from all workers
|
# wait for response from all workers
|
||||||
webui_state.textinfo = "Distributed - receiving results"
|
webui_state.textinfo = "Distributed - receiving results"
|
||||||
for thread in self.worker_threads:
|
for job in self.world.jobs:
|
||||||
logger.debug(f"waiting for worker thread '{thread.name}'")
|
if job.thread is None:
|
||||||
thread.join()
|
continue
|
||||||
self.worker_threads.clear()
|
|
||||||
|
logger.debug(f"waiting for worker thread '{job.thread.name}'")
|
||||||
|
job.thread.join()
|
||||||
logger.debug("all worker request threads returned")
|
logger.debug("all worker request threads returned")
|
||||||
webui_state.textinfo = "Distributed - injecting images"
|
webui_state.textinfo = "Distributed - injecting images"
|
||||||
|
|
||||||
# some worker which we know has a good response that we can use for generating the grid
|
# some worker which we know has a good response that we can use for generating the grid
|
||||||
donor_worker = None
|
donor_worker = None
|
||||||
for job in self.world.jobs:
|
for job in self.world.jobs:
|
||||||
if job.batch_size < 1 or job.worker.master:
|
if job.worker.response is None or job.batch_size < 1 or job.worker.master:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -304,6 +305,9 @@ class DistributedScript(scripts.Script):
|
||||||
return
|
return
|
||||||
|
|
||||||
for job in self.world.jobs:
|
for job in self.world.jobs:
|
||||||
|
if job.worker.state in (State.UNAVAILABLE, State.DISABLED):
|
||||||
|
continue
|
||||||
|
|
||||||
payload_temp = copy.copy(payload)
|
payload_temp = copy.copy(payload)
|
||||||
del payload_temp['scripts_value']
|
del payload_temp['scripts_value']
|
||||||
payload_temp = copy.deepcopy(payload_temp)
|
payload_temp = copy.deepcopy(payload_temp)
|
||||||
|
|
@ -332,11 +336,9 @@ class DistributedScript(scripts.Script):
|
||||||
job.worker.loaded_model = name
|
job.worker.loaded_model = name
|
||||||
job.worker.loaded_vae = vae
|
job.worker.loaded_vae = vae
|
||||||
|
|
||||||
t = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
|
job.thread = Thread(target=job.worker.request, args=(payload_temp, option_payload, sync,),
|
||||||
name=f"{job.worker.label}_request")
|
name=f"{job.worker.label}_request")
|
||||||
|
job.thread.start()
|
||||||
t.start()
|
|
||||||
self.worker_threads.append(t)
|
|
||||||
started_jobs.append(job)
|
started_jobs.append(job)
|
||||||
|
|
||||||
# if master batch size was changed again due to optimization change it to the updated value
|
# if master batch size was changed again due to optimization change it to the updated value
|
||||||
|
|
@ -358,6 +360,8 @@ class DistributedScript(scripts.Script):
|
||||||
|
|
||||||
# restore process_images_inner if it was monkey-patched
|
# restore process_images_inner if it was monkey-patched
|
||||||
processing.process_images_inner = self.original_process_images_inner
|
processing.process_images_inner = self.original_process_images_inner
|
||||||
|
# save any dangling state to prevent load_config in next iteration overwriting it
|
||||||
|
self.world.save_config()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def signal_handler(sig, frame):
|
def signal_handler(sig, frame):
|
||||||
|
|
|
||||||
|
|
@ -344,7 +344,7 @@ class Worker:
|
||||||
# remove anything that is not serializable
|
# remove anything that is not serializable
|
||||||
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
|
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
|
||||||
s_tmax = payload.get('s_tmax', 0.0)
|
s_tmax = payload.get('s_tmax', 0.0)
|
||||||
if s_tmax > 1e308:
|
if s_tmax is not None and s_tmax > 1e308:
|
||||||
payload['s_tmax'] = 1e308
|
payload['s_tmax'] = 1e308
|
||||||
# remove unserializable caches
|
# remove unserializable caches
|
||||||
payload.pop('cached_uc', None)
|
payload.pop('cached_uc', None)
|
||||||
|
|
@ -490,10 +490,6 @@ class Worker:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.set_state(State.IDLE)
|
self.set_state(State.IDLE)
|
||||||
|
|
||||||
if payload['batch_size'] == 0:
|
|
||||||
raise InvalidWorkerResponse("Tried to request a null amount of images")
|
|
||||||
else:
|
|
||||||
raise InvalidWorkerResponse(e)
|
raise InvalidWorkerResponse(e)
|
||||||
|
|
||||||
except requests.RequestException:
|
except requests.RequestException:
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,7 @@ class Job:
|
||||||
self.batch_size: int = batch_size
|
self.batch_size: int = batch_size
|
||||||
self.complementary: bool = False
|
self.complementary: bool = False
|
||||||
self.step_override = None
|
self.step_override = None
|
||||||
|
self.thread = None
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
prefix = ''
|
prefix = ''
|
||||||
|
|
@ -373,7 +374,7 @@ class World:
|
||||||
|
|
||||||
batch_size = self.default_batch_size()
|
batch_size = self.default_batch_size()
|
||||||
for worker in self.get_workers():
|
for worker in self.get_workers():
|
||||||
if worker.state != State.DISABLED and worker.state != State.UNAVAILABLE:
|
if worker.state not in (State.DISABLED, State.UNAVAILABLE):
|
||||||
if worker.avg_ipm is None or worker.avg_ipm <= 0:
|
if worker.avg_ipm is None or worker.avg_ipm <= 0:
|
||||||
logger.debug(f"No recorded speed for worker '{worker.label}, benchmarking'")
|
logger.debug(f"No recorded speed for worker '{worker.label}, benchmarking'")
|
||||||
worker.benchmark()
|
worker.benchmark()
|
||||||
|
|
@ -401,7 +402,7 @@ class World:
|
||||||
continue
|
continue
|
||||||
if worker.master and self.thin_client_mode:
|
if worker.master and self.thin_client_mode:
|
||||||
continue
|
continue
|
||||||
if worker.state != State.UNAVAILABLE and worker.state != State.DISABLED:
|
if worker.state not in (State.UNAVAILABLE, State.DISABLED):
|
||||||
filtered.append(worker)
|
filtered.append(worker)
|
||||||
|
|
||||||
return filtered
|
return filtered
|
||||||
|
|
@ -550,17 +551,7 @@ class World:
|
||||||
else:
|
else:
|
||||||
logger.debug("complementary image production is disabled")
|
logger.debug("complementary image production is disabled")
|
||||||
|
|
||||||
iterations = payload['n_iter']
|
logger.info(self.distro_summary(payload))
|
||||||
num_returning = self.get_current_output_size()
|
|
||||||
num_complementary = num_returning - self.p.batch_size
|
|
||||||
distro_summary = "Job distribution:\n"
|
|
||||||
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
|
|
||||||
if num_complementary > 0:
|
|
||||||
distro_summary += f" + {num_complementary} complementary"
|
|
||||||
distro_summary += f": {num_returning} images total\n"
|
|
||||||
for job in self.jobs:
|
|
||||||
distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n"
|
|
||||||
logger.info(distro_summary)
|
|
||||||
|
|
||||||
if self.thin_client_mode is True or self.master_job().batch_size == 0:
|
if self.thin_client_mode is True or self.master_job().batch_size == 0:
|
||||||
# save original process_images_inner for later so we can restore once we're done
|
# save original process_images_inner for later so we can restore once we're done
|
||||||
|
|
@ -585,6 +576,20 @@ class World:
|
||||||
del self.jobs[last]
|
del self.jobs[last]
|
||||||
last -= 1
|
last -= 1
|
||||||
|
|
||||||
|
def distro_summary(self, payload):
    """
    Build the multi-line "Job distribution" report for the current request.

    The report starts with a header giving the batch size, the iteration count,
    the number of complementary images (only when positive) and the total number
    of images, followed by one line per job with that job's image count and the
    recorded speed of its worker.

    Args:
        payload: Request payload. Not read here; the iteration count comes from
            ``self.p.n_iter``. Kept so existing callers that pass it still work.

    Returns:
        str: The formatted summary, every line terminated by a newline.
    """
    iterations = self.p.n_iter
    num_returning = self.get_current_output_size()
    num_complementary = num_returning - self.p.batch_size

    header = f"{self.p.batch_size} * {iterations} iteration(s)"
    if num_complementary > 0:
        header += f" + {num_complementary} complementary"
    header += f": {num_returning} images total\n"

    parts = ["Job distribution:\n", header]
    for job in self.jobs:
        # avg_ipm is None for a worker with no recorded speed; applying ':.2f'
        # to None raises TypeError, so fall back to a placeholder instead of
        # letting a log line abort the request.
        speed = job.worker.avg_ipm
        speed_text = f"{speed:.2f}" if speed is not None else "n/a"
        parts.append(
            f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {speed_text} ipm\n"
        )

    return "".join(parts)
|
||||||
|
|
||||||
def config(self) -> dict:
|
def config(self) -> dict:
|
||||||
"""
|
"""
|
||||||
{
|
{
|
||||||
|
|
@ -660,7 +665,7 @@ class World:
|
||||||
fields['verify_remotes'] = self.verify_remotes
|
fields['verify_remotes'] = self.verify_remotes
|
||||||
# cast enum id to actual enum type and then prime state
|
# cast enum id to actual enum type and then prime state
|
||||||
fields['state'] = State(fields['state'])
|
fields['state'] = State(fields['state'])
|
||||||
if fields['state'] != State.DISABLED:
|
if fields['state'] not in (State.DISABLED, State.UNAVAILABLE):
|
||||||
fields['state'] = State.IDLE
|
fields['state'] = State.IDLE
|
||||||
|
|
||||||
self.add_worker(**fields)
|
self.add_worker(**fields)
|
||||||
|
|
@ -740,6 +745,9 @@ class World:
|
||||||
msg = f"worker '{worker.label}' is unreachable"
|
msg = f"worker '{worker.label}' is unreachable"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
gradio.Warning("Distributed: "+msg)
|
gradio.Warning("Distributed: "+msg)
|
||||||
|
worker.set_state(State.UNAVAILABLE)
|
||||||
|
|
||||||
|
self.save_config()
|
||||||
|
|
||||||
def restart_all(self):
|
def restart_all(self):
|
||||||
for worker in self._workers:
|
for worker in self._workers:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue