refactoring
parent
de18217aaf
commit
0ac23cbd73
|
|
@ -32,27 +32,32 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
|||
|
||||
# noinspection PyMissingOrEmptyDocstring
|
||||
class Script(scripts.Script):
|
||||
worker_threads: List[Thread] = []
|
||||
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
||||
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
|
||||
|
||||
is_img2img = True
|
||||
is_txt2img = True
|
||||
alwayson = True
|
||||
master_start = None
|
||||
runs_since_init = 0
|
||||
name = "distributed"
|
||||
is_dropdown_handler_injected = False
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.worker_threads: List[Thread] = []
|
||||
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
||||
self.verify_remotes = not cmd_opts.distributed_skip_verify_remotes
|
||||
self.is_img2img = True
|
||||
self.is_txt2img = True
|
||||
self.alwayson = True
|
||||
self.master_start = None
|
||||
self.runs_since_init = 0
|
||||
self.name = "distributed"
|
||||
self.is_dropdown_handler_injected = False
|
||||
|
||||
if verify_remotes is False:
|
||||
logger.warning(f"You have chosen to forego the verification of worker TLS certificates")
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
if self.verify_remotes is False:
|
||||
logger.warning(f"You have chosen to forego the verification of worker TLS certificates")
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
# build world
|
||||
world = World(verify_remotes=verify_remotes)
|
||||
world.load_config()
|
||||
logger.info("doing initial ping sweep to see which workers are reachable")
|
||||
world.ping_remotes(indiscriminate=True)
|
||||
# build world
|
||||
self.world = World(verify_remotes=self.verify_remotes)
|
||||
self.world.load_config()
|
||||
logger.info("doing initial ping sweep to see which workers are reachable")
|
||||
self.world.ping_remotes(indiscriminate=True)
|
||||
|
||||
signal.signal(signal.SIGINT, self.signal_handler)
|
||||
signal.signal(signal.SIGTERM, self.signal_handler)
|
||||
|
||||
def title(self):
|
||||
return "Distribute"
|
||||
|
|
@ -61,19 +66,18 @@ class Script(scripts.Script):
|
|||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
Script.world.load_config()
|
||||
extension_ui = UI(script=Script, world=Script.world)
|
||||
self.world.load_config()
|
||||
extension_ui = UI(script=Script, world=self.world)
|
||||
# root, api_exposed = extension_ui.create_ui()
|
||||
components = extension_ui.create_ui()
|
||||
|
||||
# The first injection of handler for the models dropdown(sd_model_checkpoint) which is often present
|
||||
# in the quick-settings bar of a user. Helps ensure model swaps propagate to all nodes ASAP.
|
||||
Script.world.inject_model_dropdown_handler()
|
||||
self.world.inject_model_dropdown_handler()
|
||||
# return some components that should be exposed to the api
|
||||
return components
|
||||
|
||||
@staticmethod
|
||||
def add_to_gallery(processed, p):
|
||||
def add_to_gallery(self, processed, p):
|
||||
"""adds generated images to the image gallery after waiting for all workers to finish"""
|
||||
|
||||
def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None):
|
||||
|
|
@ -118,10 +122,10 @@ class Script(scripts.Script):
|
|||
if p.n_iter > 1: # if splitting by batch count
|
||||
num_remote_images *= p.n_iter - 1
|
||||
|
||||
logger.debug(f"image {true_image_pos + 1}/{Script.world.p.batch_size * p.n_iter}, "
|
||||
logger.debug(f"image {true_image_pos + 1}/{self.world.p.batch_size * p.n_iter}, "
|
||||
f"info-index: {info_index}")
|
||||
|
||||
if Script.world.thin_client_mode:
|
||||
if self.world.thin_client_mode:
|
||||
p.all_negative_prompts = processed.all_negative_prompts
|
||||
|
||||
try:
|
||||
|
|
@ -146,21 +150,21 @@ class Script(scripts.Script):
|
|||
)
|
||||
|
||||
# get master ipm by estimating based on worker speed
|
||||
master_elapsed = time.time() - Script.master_start
|
||||
master_elapsed = time.time() - self.master_start
|
||||
logger.debug(f"Took master {master_elapsed:.2f}s")
|
||||
|
||||
# wait for response from all workers
|
||||
webui_state.textinfo = "Distributed - receiving results"
|
||||
for thread in Script.worker_threads:
|
||||
for thread in self.worker_threads:
|
||||
logger.debug(f"waiting for worker thread '{thread.name}'")
|
||||
thread.join()
|
||||
Script.worker_threads.clear()
|
||||
self.worker_threads.clear()
|
||||
logger.debug("all worker request threads returned")
|
||||
webui_state.textinfo = "Distributed - injecting images"
|
||||
|
||||
# some worker which we know has a good response that we can use for generating the grid
|
||||
donor_worker = None
|
||||
for job in Script.world.jobs:
|
||||
for job in self.world.jobs:
|
||||
if job.batch_size < 1 or job.worker.master:
|
||||
continue
|
||||
|
||||
|
|
@ -209,7 +213,7 @@ class Script(scripts.Script):
|
|||
)
|
||||
|
||||
# cleanup after we're doing using all the responses
|
||||
for worker in Script.world.get_workers():
|
||||
for worker in self.world.get_workers():
|
||||
worker.response = None
|
||||
|
||||
p.batch_size = len(processed.images)
|
||||
|
|
@ -218,15 +222,14 @@ class Script(scripts.Script):
|
|||
|
||||
# p's type is
|
||||
# "modules.processing.StableDiffusionProcessing*"
|
||||
@staticmethod
|
||||
def before_process(p, *args):
|
||||
if not Script.world.enabled:
|
||||
def before_process(self, p, *args):
|
||||
if not self.world.enabled:
|
||||
logger.debug("extension is disabled")
|
||||
return
|
||||
Script.world.update(p)
|
||||
self.world.update(p)
|
||||
|
||||
# save original process_images_inner function for later if we monkeypatch it
|
||||
Script.original_process_images_inner = processing.process_images_inner
|
||||
self.original_process_images_inner = processing.process_images_inner
|
||||
|
||||
# strip scripts that aren't yet supported and warn user
|
||||
packed_script_args: List[dict] = [] # list of api formatted per-script argument objects
|
||||
|
|
@ -262,7 +265,7 @@ class Script(scripts.Script):
|
|||
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works
|
||||
# see test/basic_features/txt2img_test.py for an example
|
||||
payload = copy.copy(p.__dict__)
|
||||
payload['batch_size'] = Script.world.default_batch_size()
|
||||
payload['batch_size'] = self.world.default_batch_size()
|
||||
payload['scripts'] = None
|
||||
try:
|
||||
del payload['script_args']
|
||||
|
|
@ -291,11 +294,11 @@ class Script(scripts.Script):
|
|||
|
||||
# start generating images assigned to remote machines
|
||||
sync = False # should only really need to sync once per job
|
||||
Script.world.optimize_jobs(payload) # optimize work assignment before dispatching
|
||||
self.world.optimize_jobs(payload) # optimize work assignment before dispatching
|
||||
started_jobs = []
|
||||
|
||||
# check if anything even needs to be done
|
||||
if len(Script.world.jobs) == 1 and Script.world.jobs[0].worker.master:
|
||||
if len(self.world.jobs) == 1 and self.world.jobs[0].worker.master:
|
||||
|
||||
if payload['batch_size'] >= 2:
|
||||
msg = f"all remote workers are offline or unreachable"
|
||||
|
|
@ -306,7 +309,7 @@ class Script(scripts.Script):
|
|||
|
||||
return
|
||||
|
||||
for job in Script.world.jobs:
|
||||
for job in self.world.jobs:
|
||||
payload_temp = copy.copy(payload)
|
||||
del payload_temp['scripts_value']
|
||||
payload_temp = copy.deepcopy(payload_temp)
|
||||
|
|
@ -338,35 +341,33 @@ class Script(scripts.Script):
|
|||
name=f"{job.worker.label}_request")
|
||||
|
||||
t.start()
|
||||
Script.worker_threads.append(t)
|
||||
self.worker_threads.append(t)
|
||||
started_jobs.append(job)
|
||||
|
||||
# if master batch size was changed again due to optimization change it to the updated value
|
||||
if not Script.world.thin_client_mode:
|
||||
p.batch_size = Script.world.master_job().batch_size
|
||||
Script.master_start = time.time()
|
||||
if not self.world.thin_client_mode:
|
||||
p.batch_size = self.world.master_job().batch_size
|
||||
self.master_start = time.time()
|
||||
|
||||
# generate images assigned to local machine
|
||||
p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
|
||||
Script.runs_since_init += 1
|
||||
self.runs_since_init += 1
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def postprocess(p, processed, *args):
|
||||
if not Script.world.enabled:
|
||||
def postprocess(self, p, processed, *args):
|
||||
if not self.world.enabled:
|
||||
return
|
||||
|
||||
if Script.master_start is not None:
|
||||
Script.add_to_gallery(p=p, processed=processed)
|
||||
if self.master_start is not None:
|
||||
self.add_to_gallery(p=p, processed=processed)
|
||||
|
||||
# restore process_images_inner if it was monkey-patched
|
||||
processing.process_images_inner = Script.original_process_images_inner
|
||||
processing.process_images_inner = self.original_process_images_inner
|
||||
|
||||
@staticmethod
|
||||
def signal_handler(sig, frame):
|
||||
def signal_handler(self, sig, frame):
|
||||
logger.debug("handling interrupt signal")
|
||||
# do cleanup
|
||||
Script.world.save_config()
|
||||
self.world.save_config()
|
||||
|
||||
if sig == signal.SIGINT:
|
||||
if callable(old_sigint_handler):
|
||||
|
|
@ -376,6 +377,3 @@ class Script(scripts.Script):
|
|||
old_sigterm_handler(sig, frame)
|
||||
else:
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
|
|
|||
Loading…
Reference in New Issue