From 476cd4e60f2d4ea0aec777cd2cec2b2eadb4081e Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 21 Mar 2023 15:14:10 -0500 Subject: [PATCH] Add button under script dropdown for interrupting remote workers. --- scripts/extension.py | 32 ++++++++++++++++++++++++++------ scripts/spartan/World.py | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 8 deletions(-) diff --git a/scripts/extension.py b/scripts/extension.py index 4377f86..693d578 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -7,12 +7,13 @@ import io import json import re +import gradio + from modules import scripts, script_callbacks from modules import processing from threading import Thread from PIL import Image from typing import List -from modules.processing import StableDiffusionProcessingTxt2Img import urllib3 import copy from modules.images import save_image @@ -53,6 +54,23 @@ class Script(scripts.Script): # return scripts.AlwaysVisible return True + def ui(self, is_img2img): + + with gradio.Box(): # adds padding so our components don't look out of place + interrupt_all_btn = gradio.Button(value="Interrupt all remote workers") + interrupt_all_btn.style(full_width=False) + interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[]) + + return [interrupt_all_btn] + + # World is not constructed until the first generation job, so I use an intermediary call + @staticmethod + def ui_connect_interrupt_btn(): + try: + Script.world.interrupt_remotes() + except AttributeError: + print("Nothing to interrupt, Distributed system not initialized") + @staticmethod def add_to_gallery(processed, p): """adds generated images to the image gallery after waiting for all workers to finish""" @@ -167,13 +185,14 @@ class Script(scripts.Script): # for now we may have to make redundant GET requests to check if actually successful... # https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146 name = re.sub(r'\s?\[[^\]]*\]$', '', opts.data["sd_model_checkpoint"]) + vae = opts.data["sd_vae"] option_payload = { # "sd_model_checkpoint": opts.data["sd_model_checkpoint"], "sd_model_checkpoint": name, - "sd_vae": opts.data["sd_vae"] + "sd_vae": vae } - sync_model = False # should only really to sync models once per total job + sync = False # should only really to sync once per job Script.world.optimize_jobs(payload) # optimize work assignment before dispatching for job in Script.world.jobs: if job.batch_size < 1 or job.worker.master: @@ -182,12 +201,13 @@ class Script(scripts.Script): new_payload = copy.copy(payload) # prevent race condition instead of sharing the payload object new_payload['batch_size'] = job.batch_size - if job.worker.loaded_model != name: - sync_model = True + if job.worker.loaded_model != name or job.worker.loaded_vae != vae: + sync = True job.worker.loaded_model = name + job.worker.loaded_vae = vae # print(f"requesting {new_payload['batch_size']} images from worker '{job.worker.uuid}'\n") - t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync_model,)) + t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,)) t.start() Script.worker_threads.append(t) diff --git a/scripts/spartan/World.py b/scripts/spartan/World.py index 766ae02..bd45f9a 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/World.py @@ -90,6 +90,8 @@ class Worker: last_mpe: float = None response: requests.Response = None loaded_model: str = None + loaded_vae: str = None + interrupted: bool = False # Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A. # We compare to euler a because that is what we currently benchmark each node with. @@ -135,6 +137,7 @@ class Worker: self.verify_remotes = verify_remotes self.response_time = None self.loaded_model = None + self.loaded_vae = None if uuid is not None: self.uuid = uuid @@ -262,7 +265,7 @@ class Worker: correction = eta * (self.eta_mpe() / 100) if cmd_opts.distributed_debug: - print(f"worker '{self.uuid}'s ETA was off by {correction}%") + print(f"worker '{self.uuid}'s last ETA was off by {correction}%") if correction > 0: eta += correction @@ -325,7 +328,8 @@ class Worker: ) self.response = response.json() - if self.benchmarked: + # update list of ETA accuracy + if self.benchmarked and not self.interrupted: self.response_time = time.time() - start variance = ((eta - self.response_time) / self.response_time) * 100 @@ -408,6 +412,17 @@ class Worker: self.benchmarked = True return avg_ipm + def interrupt(self): + response = requests.post( + self.full_url('interrupt'), + json={}, + verify=self.verify_remotes + ) + + if response.status_code == 200: + self.interrupted = True + if cmd_opts.distributed_debug: + print(f"successfully interrupted worker {self.uuid}") class Job: """ @@ -424,6 +439,8 @@ class Job: self.complementary: bool = False + + class World: """ The frame or "world" which holds all workers (including the local machine). @@ -540,6 +557,17 @@ class World: worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes) self.workers.append(worker) + def interrupt_remotes(self): + threads: List[Thread] = [] + + for worker in self.workers: + if worker.master: + continue + + t = Thread(target=worker.interrupt, args=()) + t.start() + + def benchmark(self): """ Attempts to benchmark all workers a part of the world.