Add button under script dropdown for interrupting remote workers.
parent
c64114806d
commit
476cd4e60f
|
|
@ -7,12 +7,13 @@ import io
|
|||
import json
|
||||
import re
|
||||
|
||||
import gradio
|
||||
|
||||
from modules import scripts, script_callbacks
|
||||
from modules import processing
|
||||
from threading import Thread
|
||||
from PIL import Image
|
||||
from typing import List
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img
|
||||
import urllib3
|
||||
import copy
|
||||
from modules.images import save_image
|
||||
|
|
@ -53,6 +54,23 @@ class Script(scripts.Script):
|
|||
# return scripts.AlwaysVisible
|
||||
return True
|
||||
|
||||
def ui(self, is_img2img):
    """Build the script's UI: a single button that interrupts all remote workers.

    `is_img2img` is accepted but not used here.
    Returns a list containing the created button component.
    """
    # the Box adds padding so our components don't look out of place
    with gradio.Box():
        interrupt_button = gradio.Button(value="Interrupt all remote workers")
        interrupt_button.style(full_width=False)
        # The click goes through a static intermediary on Script rather than
        # being bound to a World method directly.
        interrupt_button.click(
            fn=Script.ui_connect_interrupt_btn,
            inputs=[],
            outputs=[],
        )

    components = [interrupt_button]
    return components
|
||||
|
||||
# World is not constructed until the first generation job, so I use an intermediary call
|
||||
@staticmethod
def ui_connect_interrupt_btn():
    """Click handler for the interrupt button.

    Forwards to `Script.world.interrupt_remotes()` when a world exists;
    otherwise prints a notice and does nothing.
    """
    # Check for the world explicitly instead of wrapping the call in
    # `except AttributeError`: the broad handler also swallowed any
    # AttributeError raised *inside* interrupt_remotes() and misreported it
    # as "not initialized". getattr covers both a missing attribute and an
    # attribute that is still None.
    world = getattr(Script, "world", None)
    if world is None:
        print("Nothing to interrupt, Distributed system not initialized")
        return

    world.interrupt_remotes()
|
||||
|
||||
@staticmethod
|
||||
def add_to_gallery(processed, p):
|
||||
"""adds generated images to the image gallery after waiting for all workers to finish"""
|
||||
|
|
@ -167,13 +185,14 @@ class Script(scripts.Script):
|
|||
# for now we may have to make redundant GET requests to check if actually successful...
|
||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146
|
||||
name = re.sub(r'\s?\[[^\]]*\]$', '', opts.data["sd_model_checkpoint"])
|
||||
vae = opts.data["sd_vae"]
|
||||
option_payload = {
|
||||
# "sd_model_checkpoint": opts.data["sd_model_checkpoint"],
|
||||
"sd_model_checkpoint": name,
|
||||
"sd_vae": opts.data["sd_vae"]
|
||||
"sd_vae": vae
|
||||
}
|
||||
|
||||
sync_model = False # should only really to sync models once per total job
|
||||
sync = False # should only really to sync once per job
|
||||
Script.world.optimize_jobs(payload) # optimize work assignment before dispatching
|
||||
for job in Script.world.jobs:
|
||||
if job.batch_size < 1 or job.worker.master:
|
||||
|
|
@ -182,12 +201,13 @@ class Script(scripts.Script):
|
|||
new_payload = copy.copy(payload) # prevent race condition instead of sharing the payload object
|
||||
new_payload['batch_size'] = job.batch_size
|
||||
|
||||
if job.worker.loaded_model != name:
|
||||
sync_model = True
|
||||
if job.worker.loaded_model != name or job.worker.loaded_vae != vae:
|
||||
sync = True
|
||||
job.worker.loaded_model = name
|
||||
job.worker.loaded_vae = vae
|
||||
|
||||
# print(f"requesting {new_payload['batch_size']} images from worker '{job.worker.uuid}'\n")
|
||||
t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync_model,))
|
||||
t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,))
|
||||
|
||||
t.start()
|
||||
Script.worker_threads.append(t)
|
||||
|
|
|
|||
|
|
@ -90,6 +90,8 @@ class Worker:
|
|||
last_mpe: float = None
|
||||
response: requests.Response = None
|
||||
loaded_model: str = None
|
||||
loaded_vae: str = None
|
||||
interrupted: bool = False
|
||||
|
||||
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
|
||||
# We compare to euler a because that is what we currently benchmark each node with.
|
||||
|
|
@ -135,6 +137,7 @@ class Worker:
|
|||
self.verify_remotes = verify_remotes
|
||||
self.response_time = None
|
||||
self.loaded_model = None
|
||||
self.loaded_vae = None
|
||||
|
||||
if uuid is not None:
|
||||
self.uuid = uuid
|
||||
|
|
@ -262,7 +265,7 @@ class Worker:
|
|||
correction = eta * (self.eta_mpe() / 100)
|
||||
|
||||
if cmd_opts.distributed_debug:
|
||||
print(f"worker '{self.uuid}'s ETA was off by {correction}%")
|
||||
print(f"worker '{self.uuid}'s last ETA was off by {correction}%")
|
||||
|
||||
if correction > 0:
|
||||
eta += correction
|
||||
|
|
@ -325,7 +328,8 @@ class Worker:
|
|||
)
|
||||
self.response = response.json()
|
||||
|
||||
if self.benchmarked:
|
||||
# update list of ETA accuracy
|
||||
if self.benchmarked and not self.interrupted:
|
||||
self.response_time = time.time() - start
|
||||
variance = ((eta - self.response_time) / self.response_time) * 100
|
||||
|
||||
|
|
@ -408,6 +412,17 @@ class Worker:
|
|||
self.benchmarked = True
|
||||
return avg_ipm
|
||||
|
||||
def interrupt(self):
    """Ask this worker to interrupt whatever it is currently doing.

    POSTs an empty JSON body to the worker's 'interrupt' endpoint. On an
    HTTP 200 response, `self.interrupted` is set to True. Connection
    problems and non-200 responses are reported with a printed message
    and leave `self.interrupted` untouched; nothing is raised to the caller.
    """
    try:
        response = requests.post(
            self.full_url('interrupt'),
            json={},
            verify=self.verify_remotes,
            # without a timeout, an unreachable worker would block this call forever
            timeout=5,
        )
    except requests.exceptions.RequestException as e:
        # Report the failure here rather than letting it escape as an
        # unhandled exception traceback.
        print(f"failed to interrupt worker {self.uuid}: {e}")
        return

    if response.status_code != 200:
        print(f"failed to interrupt worker {self.uuid}: HTTP {response.status_code}")
        return

    self.interrupted = True
    if cmd_opts.distributed_debug:
        print(f"successfully interrupted worker {self.uuid}")
|
||||
|
||||
class Job:
|
||||
"""
|
||||
|
|
@ -424,6 +439,8 @@ class Job:
|
|||
self.complementary: bool = False
|
||||
|
||||
|
||||
|
||||
|
||||
class World:
|
||||
"""
|
||||
The frame or "world" which holds all workers (including the local machine).
|
||||
|
|
@ -540,6 +557,17 @@ class World:
|
|||
worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes)
|
||||
self.workers.append(worker)
|
||||
|
||||
def interrupt_remotes(self):
    """Send an interrupt request to every non-master worker.

    Each worker's `interrupt()` is started on its own thread and this
    method returns without waiting for any of them to finish.
    """
    for worker in self.workers:
        if worker.master:
            continue

        # Fire-and-forget: the threads are never joined, so no list of them
        # is kept. They are daemon threads so that a request which never
        # returns cannot keep the interpreter alive at shutdown.
        Thread(target=worker.interrupt, daemon=True).start()
|
||||
|
||||
|
||||
def benchmark(self):
|
||||
"""
|
||||
Attempts to benchmark all workers a part of the world.
|
||||
|
|
|
|||
Loading…
Reference in New Issue