diff --git a/scripts/distributed.py b/scripts/distributed.py index e5f84a9..2fe90a7 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -18,7 +18,7 @@ from PIL import Image from modules import processing from modules import scripts from modules.images import save_image -from modules.processing import fix_seed, Processed +from modules.processing import fix_seed from modules.shared import opts, cmd_opts from modules.shared import state as webui_state from scripts.spartan.control_net import pack_control_net @@ -31,33 +31,29 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM) # noinspection PyMissingOrEmptyDocstring -class Script(scripts.Script): +class DistributedScript(scripts.Script): + # global old_sigterm_handler, old_sigterm_handler + worker_threads: List[Thread] = [] + # Whether to verify worker certificates. Can be useful if your remotes are self-signed. + verify_remotes = not cmd_opts.distributed_skip_verify_remotes + master_start = None + runs_since_init = 0 + name = "distributed" + is_dropdown_handler_injected = False + if verify_remotes is False: + logger.warning(f"You have chosen to forego the verification of worker TLS certificates") + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + # build world + world = World(verify_remotes=verify_remotes) + world.load_config() + logger.info("doing initial ping sweep to see which workers are reachable") + world.ping_remotes(indiscriminate=True) + + # constructed for both txt2img and img2img def __init__(self): super().__init__() - self.worker_threads: List[Thread] = [] - # Whether to verify worker certificates. Can be useful if your remotes are self-signed. - self.verify_remotes = not cmd_opts.distributed_skip_verify_remotes - self.is_img2img = True - self.is_txt2img = True - self.alwayson = True - self.master_start = None - self.runs_since_init = 0 - self.name = "distributed" - self.is_dropdown_handler_injected = False - - if self.verify_remotes is False: - logger.warning(f"You have chosen to forego the verification of worker TLS certificates") - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - - # build world - self.world = World(verify_remotes=self.verify_remotes) - self.world.load_config() - logger.info("doing initial ping sweep to see which workers are reachable") - self.world.ping_remotes(indiscriminate=True) - - signal.signal(signal.SIGINT, self.signal_handler) - signal.signal(signal.SIGTERM, self.signal_handler) def title(self): return "Distribute" @@ -67,7 +63,7 @@ class Script(scripts.Script): def ui(self, is_img2img): self.world.load_config() - extension_ui = UI(script=Script, world=self.world) + extension_ui = UI(world=self.world) # root, api_exposed = extension_ui.create_ui() components = extension_ui.create_ui() @@ -203,7 +199,7 @@ class Script(scripts.Script): # generate and inject grid if opts.return_grid: - grid = processing.images.image_grid(processed.images, len(processed.images)) + grid = images.image_grid(processed.images, len(processed.images)) processed_inject_image( image=grid, info_index=0, @@ -219,7 +215,6 @@ class Script(scripts.Script): p.batch_size = len(processed.images) return - # p's type is # "modules.processing.StableDiffusionProcessing*" def before_process(self, p, *args): @@ -365,10 +360,11 @@ class Script(scripts.Script): # restore process_images_inner if it was monkey-patched processing.process_images_inner = self.original_process_images_inner - def signal_handler(self, sig, frame): + @staticmethod + def signal_handler(sig, frame): logger.debug("handling interrupt signal") # do cleanup - self.world.save_config() + DistributedScript.world.save_config() if sig == signal.SIGINT: if callable(old_sigint_handler): @@ -378,3 +374,6 @@ class Script(scripts.Script): old_sigterm_handler(sig, frame) else: sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index d0f4d27..cd0b87a 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -14,8 +14,7 @@ worker_select_dropdown = None class UI: """extension user interface related things""" - def __init__(self, script, world): - self.script = script + def __init__(self, world): self.world = world self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 1b2bfcb..2c72a28 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -16,7 +16,7 @@ import modules.shared as shared from modules.processing import process_images, StableDiffusionProcessingTxt2Img from . import shared as sh from .pmodels import ConfigModel, Benchmark_Payload -from .shared import logger, warmup_samples, extension_path +from .shared import logger, extension_path from .worker import Worker, State from modules.call_queue import wrap_queued_call from modules import processing