Add button under script dropdown for interrupting remote workers.

pull/2/head
unknown 2023-03-21 15:14:10 -05:00
parent c64114806d
commit 476cd4e60f
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
2 changed files with 56 additions and 8 deletions

View File

@ -7,12 +7,13 @@ import io
import json
import re
import gradio
from modules import scripts, script_callbacks
from modules import processing
from threading import Thread
from PIL import Image
from typing import List
from modules.processing import StableDiffusionProcessingTxt2Img
import urllib3
import copy
from modules.images import save_image
@ -53,6 +54,23 @@ class Script(scripts.Script):
# return scripts.AlwaysVisible
return True
def ui(self, is_img2img):
with gradio.Box(): # adds padding so our components don't look out of place
interrupt_all_btn = gradio.Button(value="Interrupt all remote workers")
interrupt_all_btn.style(full_width=False)
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
return [interrupt_all_btn]
# World is not constructed until the first generation job, so I use an intermediary call
@staticmethod
def ui_connect_interrupt_btn():
try:
Script.world.interrupt_remotes()
except AttributeError:
print("Nothing to interrupt, Distributed system not initialized")
@staticmethod
def add_to_gallery(processed, p):
"""adds generated images to the image gallery after waiting for all workers to finish"""
@ -167,13 +185,14 @@ class Script(scripts.Script):
# for now we may have to make redundant GET requests to check if actually successful...
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146
name = re.sub(r'\s?\[[^\]]*\]$', '', opts.data["sd_model_checkpoint"])
vae = opts.data["sd_vae"]
option_payload = {
# "sd_model_checkpoint": opts.data["sd_model_checkpoint"],
"sd_model_checkpoint": name,
"sd_vae": opts.data["sd_vae"]
"sd_vae": vae
}
sync_model = False # should only really to sync models once per total job
sync = False # should only really to sync once per job
Script.world.optimize_jobs(payload) # optimize work assignment before dispatching
for job in Script.world.jobs:
if job.batch_size < 1 or job.worker.master:
@ -182,12 +201,13 @@ class Script(scripts.Script):
new_payload = copy.copy(payload) # prevent race condition instead of sharing the payload object
new_payload['batch_size'] = job.batch_size
if job.worker.loaded_model != name:
sync_model = True
if job.worker.loaded_model != name or job.worker.loaded_vae != vae:
sync = True
job.worker.loaded_model = name
job.worker.loaded_vae = vae
# print(f"requesting {new_payload['batch_size']} images from worker '{job.worker.uuid}'\n")
t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync_model,))
t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,))
t.start()
Script.worker_threads.append(t)

View File

@ -90,6 +90,8 @@ class Worker:
last_mpe: float = None
response: requests.Response = None
loaded_model: str = None
loaded_vae: str = None
interrupted: bool = False
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
# We compare to euler a because that is what we currently benchmark each node with.
@ -135,6 +137,7 @@ class Worker:
self.verify_remotes = verify_remotes
self.response_time = None
self.loaded_model = None
self.loaded_vae = None
if uuid is not None:
self.uuid = uuid
@ -262,7 +265,7 @@ class Worker:
correction = eta * (self.eta_mpe() / 100)
if cmd_opts.distributed_debug:
print(f"worker '{self.uuid}'s ETA was off by {correction}%")
print(f"worker '{self.uuid}'s last ETA was off by {correction}%")
if correction > 0:
eta += correction
@ -325,7 +328,8 @@ class Worker:
)
self.response = response.json()
if self.benchmarked:
# update list of ETA accuracy
if self.benchmarked and not self.interrupted:
self.response_time = time.time() - start
variance = ((eta - self.response_time) / self.response_time) * 100
@ -408,6 +412,17 @@ class Worker:
self.benchmarked = True
return avg_ipm
def interrupt(self):
response = requests.post(
self.full_url('interrupt'),
json={},
verify=self.verify_remotes
)
if response.status_code == 200:
self.interrupted = True
if cmd_opts.distributed_debug:
print(f"successfully interrupted worker {self.uuid}")
class Job:
"""
@ -424,6 +439,8 @@ class Job:
self.complementary: bool = False
class World:
"""
The frame or "world" which holds all workers (including the local machine).
@ -540,6 +557,17 @@ class World:
worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes)
self.workers.append(worker)
def interrupt_remotes(self):
threads: List[Thread] = []
for worker in self.workers:
if worker.master:
continue
t = Thread(target=worker.interrupt, args=())
t.start()
def benchmark(self):
"""
Attempts to benchmark all workers a part of the world.