def ui(self, is_img2img):
    """Build the collapsible 'Distributed' panel for the script UI.

    The panel holds two tabs:

    * 'Status' - a textbox filled by ``Script.ui_connect_test`` whenever the
      tab is selected or its 'Refresh' button is clicked.
    * 'Remote Utils' - buttons that refresh checkpoints, run the user sync
      script, and interrupt all workers.

    ``is_img2img`` is accepted but not used here.
    Returns None: no components are handed back to the caller.
    """
    with gradio.Box():  # adds padding so our components don't look out of place
        with gradio.Accordion(label='Distributed', open=False) as main_accordian:

            with gradio.Tab('Status') as status_tab:
                status = gradio.Textbox(elem_id='status', show_label=False)
                status_tab.select(fn=Script.ui_connect_test, inputs=[], outputs=[status])

                refresh_status_btn = gradio.Button(value='Refresh')
                refresh_status_btn.style(size='sm')
                refresh_status_btn.click(Script.ui_connect_test, inputs=[], outputs=[status])

            with gradio.Tab('Remote Utils'):
                refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
                refresh_checkpoints_btn.style(full_width=False)
                refresh_checkpoints_btn.click(Script.ui_connect_refresh_ckpts_btn, inputs=[], outputs=[])

                sync_models_btn = gradio.Button(value='Synchronize models')
                sync_models_btn.style(full_width=False)
                sync_models_btn.click(Script.user_sync_script, inputs=[], outputs=[])

                interrupt_all_btn = gradio.Button(value='Interrupt all', variant='stop')
                interrupt_all_btn.style(full_width=False)
                interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])

    return


@staticmethod
def user_sync_script(script_dir=None):
    """Run the user-supplied synchronization script, if one is present.

    Searches ``script_dir`` for a regular file whose name starts with ``sync``.
    A ``.ps1`` file is run through ``powershell``; any other file is run with
    the interpreter named on its ``#!`` first line.

    Args:
        script_dir: directory to search. Defaults to the ``user`` directory
            that sits next to this file.

    Returns:
        True if a script was launched, False if the directory is missing, no
        matching file was found, or the file has no usable shebang.
    """
    if script_dir is None:
        user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user')
    else:
        user_scripts = Path(script_dir)

    # a missing directory simply means there is nothing to run
    if not user_scripts.is_dir():
        return False

    # when several files match, the last one yielded by iterdir() wins
    user_script = None
    for file in user_scripts.iterdir():
        if file.is_file() and file.name.startswith('sync'):
            user_script = file

    if user_script is None:
        return False

    if user_script.suffix[1:] == 'ps1':
        subprocess.call(['powershell', str(user_script)])
        return True

    # only the first line is needed; close the handle before spawning anything
    with open(user_script, 'r', errors='replace') as f:
        first_line = f.readline().strip()

    if first_line.startswith('#!'):
        # split so that "#!/usr/bin/env bash" becomes ['/usr/bin/env', 'bash']
        interpreter = first_line[2:].split()
        if interpreter:
            subprocess.call([*interpreter, str(user_script)])
            return True

    return False


# World is not constructed until the first generation job, so I use an intermediary call
@staticmethod
def ui_connect_refresh_ckpts_btn():
    """Ask the world to refresh checkpoints on its workers, if a world exists yet."""
    try:
        Script.world.refresh_checkpoints()
    except AttributeError:
        print("Distributed system not initialized")


@staticmethod
def ui_connect_test():
    """Return one status line per non-master worker for the 'Status' textbox.

    If the world has not been created yet, it is initialized first with no
    payload; ``Script.initialize`` then falls back to a batch size of 1.
    """
    try:
        workers = Script.world.workers
    except AttributeError:
        # init system if it isn't already
        Script.initialize(initial_payload=None)
        workers = Script.world.workers

    return ''.join(
        f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
        for worker in workers
        if not worker.master
    )
@staticmethod
def initialize(initial_payload):
    """Create ``Script.world``, register the configured remotes and initialize it.

    Args:
        initial_payload: the processing object for the first job, or None when
            called before any job exists. When it has no ``batch_size``
            attribute, a batch size of 1 is used.

    Raises:
        RuntimeError: if no remotes were configured on the command line.
    """
    # This can be reached without going through run() (e.g. from the status tab),
    # so the remote list has to be validated here too.
    if cmd_opts.distributed_remotes is None:
        raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")

    if Script.verify_remotes is False:
        print("WARNING: you have chosen to forego the verification of worker TLS certificates")
        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    try:
        batch_size = initial_payload.batch_size
    except AttributeError:
        batch_size = 1

    Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)

    # add workers to the world
    for worker in cmd_opts.distributed_remotes:
        Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])

    try:
        Script.world.initialize(batch_size)
        print("World initialized!")
    except WorldAlreadyInitialized:
        Script.world.update_world(total_batch_size=batch_size)
class State(Enum):
    """What a worker is currently doing.

    IDLE is assigned when a worker is constructed and when a benchmark ends,
    WORKING when a request or benchmark begins, and INTERRUPTED once the
    remote acknowledges an interrupt.
    """

    IDLE, WORKING, INTERRUPTED = 1, 2, 3
def refresh_checkpoints(self):
    """Ask this worker to refresh its checkpoint list.

    Sends a POST to the worker's 'refresh-checkpoints' route. The worker's
    ``state`` is left untouched: a refresh is not an interruption, so it must
    not be reported as one.
    """
    response = requests.post(
        self.full_url('refresh-checkpoints'),
        json={},
        verify=self.verify_remotes
    )

    if response.status_code == 200:
        if cmd_opts.distributed_debug:
            print(f"successfully refreshed checkpoints for worker '{self.uuid}'")
def refresh_checkpoints(self):
    """Start a checkpoint refresh on every non-master worker.

    Each refresh runs in its own thread. The threads are started and not
    joined, so this returns without waiting for the workers to respond.
    """
    for worker in self.workers:
        if worker.master:
            continue

        Thread(target=worker.refresh_checkpoints, args=()).start()