Add button under script dropdown for interrupting remote workers.

pull/2/head
unknown 2023-03-21 15:14:10 -05:00
parent c64114806d
commit 476cd4e60f
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
2 changed files with 56 additions and 8 deletions

View File

@ -7,12 +7,13 @@ import io
import json import json
import re import re
import gradio
from modules import scripts, script_callbacks from modules import scripts, script_callbacks
from modules import processing from modules import processing
from threading import Thread from threading import Thread
from PIL import Image from PIL import Image
from typing import List from typing import List
from modules.processing import StableDiffusionProcessingTxt2Img
import urllib3 import urllib3
import copy import copy
from modules.images import save_image from modules.images import save_image
@ -53,6 +54,23 @@ class Script(scripts.Script):
# return scripts.AlwaysVisible # return scripts.AlwaysVisible
return True return True
def ui(self, is_img2img):
    """Build this script's UI components: one button that interrupts all remote workers.

    Returns the list of created gradio components, as the webui scripts API expects.
    """
    # The Box wrapper adds padding so the button doesn't look out of place
    # among the other script components.
    with gradio.Box():
        interrupt_all_btn = gradio.Button(value="Interrupt all remote workers")
        interrupt_all_btn.style(full_width=False)
        # Route clicks through the static intermediary (World may not exist yet).
        interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
    return [interrupt_all_btn]
# World is not constructed until the first generation job, so the UI button
# binds to this intermediary instead of binding World.interrupt_remotes directly.
@staticmethod
def ui_connect_interrupt_btn():
    """Forward an interrupt request to the distributed World, if it exists yet."""
    try:
        Script.world.interrupt_remotes()
    except AttributeError:
        # Script.world is not set before the first job — nothing to interrupt.
        print("Nothing to interrupt, Distributed system not initialized")
@staticmethod @staticmethod
def add_to_gallery(processed, p): def add_to_gallery(processed, p):
"""adds generated images to the image gallery after waiting for all workers to finish""" """adds generated images to the image gallery after waiting for all workers to finish"""
@ -167,13 +185,14 @@ class Script(scripts.Script):
# for now we may have to make redundant GET requests to check if actually successful... # for now we may have to make redundant GET requests to check if actually successful...
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146 # https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146
name = re.sub(r'\s?\[[^\]]*\]$', '', opts.data["sd_model_checkpoint"]) name = re.sub(r'\s?\[[^\]]*\]$', '', opts.data["sd_model_checkpoint"])
vae = opts.data["sd_vae"]
option_payload = { option_payload = {
# "sd_model_checkpoint": opts.data["sd_model_checkpoint"], # "sd_model_checkpoint": opts.data["sd_model_checkpoint"],
"sd_model_checkpoint": name, "sd_model_checkpoint": name,
"sd_vae": opts.data["sd_vae"] "sd_vae": vae
} }
sync_model = False # should only really to sync models once per total job sync = False # should only really to sync once per job
Script.world.optimize_jobs(payload) # optimize work assignment before dispatching Script.world.optimize_jobs(payload) # optimize work assignment before dispatching
for job in Script.world.jobs: for job in Script.world.jobs:
if job.batch_size < 1 or job.worker.master: if job.batch_size < 1 or job.worker.master:
@ -182,12 +201,13 @@ class Script(scripts.Script):
new_payload = copy.copy(payload) # prevent race condition instead of sharing the payload object new_payload = copy.copy(payload) # prevent race condition instead of sharing the payload object
new_payload['batch_size'] = job.batch_size new_payload['batch_size'] = job.batch_size
if job.worker.loaded_model != name: if job.worker.loaded_model != name or job.worker.loaded_vae != vae:
sync_model = True sync = True
job.worker.loaded_model = name job.worker.loaded_model = name
job.worker.loaded_vae = vae
# print(f"requesting {new_payload['batch_size']} images from worker '{job.worker.uuid}'\n") # print(f"requesting {new_payload['batch_size']} images from worker '{job.worker.uuid}'\n")
t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync_model,)) t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,))
t.start() t.start()
Script.worker_threads.append(t) Script.worker_threads.append(t)

View File

@ -90,6 +90,8 @@ class Worker:
last_mpe: float = None last_mpe: float = None
response: requests.Response = None response: requests.Response = None
loaded_model: str = None loaded_model: str = None
loaded_vae: str = None
interrupted: bool = False
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A. # Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
# We compare to euler a because that is what we currently benchmark each node with. # We compare to euler a because that is what we currently benchmark each node with.
@ -135,6 +137,7 @@ class Worker:
self.verify_remotes = verify_remotes self.verify_remotes = verify_remotes
self.response_time = None self.response_time = None
self.loaded_model = None self.loaded_model = None
self.loaded_vae = None
if uuid is not None: if uuid is not None:
self.uuid = uuid self.uuid = uuid
@ -262,7 +265,7 @@ class Worker:
correction = eta * (self.eta_mpe() / 100) correction = eta * (self.eta_mpe() / 100)
if cmd_opts.distributed_debug: if cmd_opts.distributed_debug:
print(f"worker '{self.uuid}'s ETA was off by {correction}%") print(f"worker '{self.uuid}'s last ETA was off by {correction}%")
if correction > 0: if correction > 0:
eta += correction eta += correction
@ -325,7 +328,8 @@ class Worker:
) )
self.response = response.json() self.response = response.json()
if self.benchmarked: # update list of ETA accuracy
if self.benchmarked and not self.interrupted:
self.response_time = time.time() - start self.response_time = time.time() - start
variance = ((eta - self.response_time) / self.response_time) * 100 variance = ((eta - self.response_time) / self.response_time) * 100
@ -408,6 +412,17 @@ class Worker:
self.benchmarked = True self.benchmarked = True
return avg_ipm return avg_ipm
def interrupt(self):
    """Best-effort request that this worker's remote API interrupt its current job.

    On success (HTTP 200) the worker is flagged `interrupted` so the
    ETA-accuracy bookkeeping skips the aborted request. Network failures
    are logged and swallowed: this runs on a fire-and-forget thread and an
    unreachable worker must not crash it.
    """
    try:
        response = requests.post(
            self.full_url('interrupt'),
            json={},
            verify=self.verify_remotes,
            # Without a timeout, an interrupt aimed at a hung/dead worker
            # blocks its thread forever.
            timeout=5
        )
    except requests.RequestException as e:
        print(f"failed to interrupt worker {self.uuid}: {e}")
        return
    if response.status_code == 200:
        self.interrupted = True
        if cmd_opts.distributed_debug:
            print(f"successfully interrupted worker {self.uuid}")
class Job: class Job:
""" """
@ -424,6 +439,8 @@ class Job:
self.complementary: bool = False self.complementary: bool = False
class World: class World:
""" """
The frame or "world" which holds all workers (including the local machine). The frame or "world" which holds all workers (including the local machine).
@ -540,6 +557,17 @@ class World:
worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes) worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes)
self.workers.append(worker) self.workers.append(worker)
def interrupt_remotes(self):
    """Interrupt every remote worker in parallel, skipping the local master.

    Fix: the original declared `threads` but never appended to or joined it,
    so the list was dead code and the method returned before any remote had
    actually been signaled. Each interrupt still runs on its own thread so
    a slow worker doesn't delay the others, but we now join them all so the
    caller knows every remote has been contacted when this returns.
    """
    threads: List[Thread] = []
    for worker in self.workers:
        if worker.master:
            # The local machine handles interrupts natively; no remote API call.
            continue
        t = Thread(target=worker.interrupt, args=())
        t.start()
        threads.append(t)
    for t in threads:
        t.join()
def benchmark(self): def benchmark(self):
""" """
Attempts to benchmark all workers a part of the world. Attempts to benchmark all workers a part of the world.