Add button under script dropdown for interrupting remote workers.
parent
c64114806d
commit
476cd4e60f
|
|
@ -7,12 +7,13 @@ import io
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
import gradio
|
||||||
|
|
||||||
from modules import scripts, script_callbacks
|
from modules import scripts, script_callbacks
|
||||||
from modules import processing
|
from modules import processing
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing import List
|
from typing import List
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img
|
|
||||||
import urllib3
|
import urllib3
|
||||||
import copy
|
import copy
|
||||||
from modules.images import save_image
|
from modules.images import save_image
|
||||||
|
|
@ -53,6 +54,23 @@ class Script(scripts.Script):
|
||||||
# return scripts.AlwaysVisible
|
# return scripts.AlwaysVisible
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def ui(self, is_img2img):
    """Build this script's UI: one button that interrupts all remote workers.

    ``is_img2img`` is not used by this method. Returns a one-element list
    containing the button component.
    """
    # The Box adds padding so our components don't look out of place.
    with gradio.Box():
        stop_remotes_btn = gradio.Button(value="Interrupt all remote workers")
        stop_remotes_btn.style(full_width=False)
        # No inputs/outputs: the click only triggers the interrupt side effect.
        stop_remotes_btn.click(
            Script.ui_connect_interrupt_btn,
            inputs=[],
            outputs=[],
        )

    components = [stop_remotes_btn]
    return components
|
||||||
|
|
||||||
|
# World is not constructed until the first generation job, so the button is wired to this
# intermediary call rather than directly to the world's interrupt_remotes method.
@staticmethod
def ui_connect_interrupt_btn():
    """Click handler for the "Interrupt all remote workers" button.

    Forwards to ``Script.world.interrupt_remotes()`` when a world exists;
    otherwise prints a notice and does nothing.
    """
    # Look the world up explicitly instead of wrapping the whole call in
    # ``except AttributeError``: that would also swallow AttributeErrors raised
    # *inside* interrupt_remotes() and misreport them as "not initialized".
    world = getattr(Script, "world", None)
    if world is None:
        print("Nothing to interrupt, Distributed system not initialized")
        return

    world.interrupt_remotes()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_to_gallery(processed, p):
|
def add_to_gallery(processed, p):
|
||||||
"""adds generated images to the image gallery after waiting for all workers to finish"""
|
"""adds generated images to the image gallery after waiting for all workers to finish"""
|
||||||
|
|
@ -167,13 +185,14 @@ class Script(scripts.Script):
|
||||||
# for now we may have to make redundant GET requests to check if actually successful...
|
# for now we may have to make redundant GET requests to check if actually successful...
|
||||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146
|
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146
|
||||||
name = re.sub(r'\s?\[[^\]]*\]$', '', opts.data["sd_model_checkpoint"])
|
name = re.sub(r'\s?\[[^\]]*\]$', '', opts.data["sd_model_checkpoint"])
|
||||||
|
vae = opts.data["sd_vae"]
|
||||||
option_payload = {
|
option_payload = {
|
||||||
# "sd_model_checkpoint": opts.data["sd_model_checkpoint"],
|
# "sd_model_checkpoint": opts.data["sd_model_checkpoint"],
|
||||||
"sd_model_checkpoint": name,
|
"sd_model_checkpoint": name,
|
||||||
"sd_vae": opts.data["sd_vae"]
|
"sd_vae": vae
|
||||||
}
|
}
|
||||||
|
|
||||||
sync_model = False # should only really to sync models once per total job
|
sync = False # should only really to sync once per job
|
||||||
Script.world.optimize_jobs(payload) # optimize work assignment before dispatching
|
Script.world.optimize_jobs(payload) # optimize work assignment before dispatching
|
||||||
for job in Script.world.jobs:
|
for job in Script.world.jobs:
|
||||||
if job.batch_size < 1 or job.worker.master:
|
if job.batch_size < 1 or job.worker.master:
|
||||||
|
|
@ -182,12 +201,13 @@ class Script(scripts.Script):
|
||||||
new_payload = copy.copy(payload) # prevent race condition instead of sharing the payload object
|
new_payload = copy.copy(payload) # prevent race condition instead of sharing the payload object
|
||||||
new_payload['batch_size'] = job.batch_size
|
new_payload['batch_size'] = job.batch_size
|
||||||
|
|
||||||
if job.worker.loaded_model != name:
|
if job.worker.loaded_model != name or job.worker.loaded_vae != vae:
|
||||||
sync_model = True
|
sync = True
|
||||||
job.worker.loaded_model = name
|
job.worker.loaded_model = name
|
||||||
|
job.worker.loaded_vae = vae
|
||||||
|
|
||||||
# print(f"requesting {new_payload['batch_size']} images from worker '{job.worker.uuid}'\n")
|
# print(f"requesting {new_payload['batch_size']} images from worker '{job.worker.uuid}'\n")
|
||||||
t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync_model,))
|
t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,))
|
||||||
|
|
||||||
t.start()
|
t.start()
|
||||||
Script.worker_threads.append(t)
|
Script.worker_threads.append(t)
|
||||||
|
|
|
||||||
|
|
@ -90,6 +90,8 @@ class Worker:
|
||||||
last_mpe: float = None
|
last_mpe: float = None
|
||||||
response: requests.Response = None
|
response: requests.Response = None
|
||||||
loaded_model: str = None
|
loaded_model: str = None
|
||||||
|
loaded_vae: str = None
|
||||||
|
interrupted: bool = False
|
||||||
|
|
||||||
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
|
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
|
||||||
# We compare to euler a because that is what we currently benchmark each node with.
|
# We compare to euler a because that is what we currently benchmark each node with.
|
||||||
|
|
@ -135,6 +137,7 @@ class Worker:
|
||||||
self.verify_remotes = verify_remotes
|
self.verify_remotes = verify_remotes
|
||||||
self.response_time = None
|
self.response_time = None
|
||||||
self.loaded_model = None
|
self.loaded_model = None
|
||||||
|
self.loaded_vae = None
|
||||||
|
|
||||||
if uuid is not None:
|
if uuid is not None:
|
||||||
self.uuid = uuid
|
self.uuid = uuid
|
||||||
|
|
@ -262,7 +265,7 @@ class Worker:
|
||||||
correction = eta * (self.eta_mpe() / 100)
|
correction = eta * (self.eta_mpe() / 100)
|
||||||
|
|
||||||
if cmd_opts.distributed_debug:
|
if cmd_opts.distributed_debug:
|
||||||
print(f"worker '{self.uuid}'s ETA was off by {correction}%")
|
print(f"worker '{self.uuid}'s last ETA was off by {correction}%")
|
||||||
|
|
||||||
if correction > 0:
|
if correction > 0:
|
||||||
eta += correction
|
eta += correction
|
||||||
|
|
@ -325,7 +328,8 @@ class Worker:
|
||||||
)
|
)
|
||||||
self.response = response.json()
|
self.response = response.json()
|
||||||
|
|
||||||
if self.benchmarked:
|
# update list of ETA accuracy
|
||||||
|
if self.benchmarked and not self.interrupted:
|
||||||
self.response_time = time.time() - start
|
self.response_time = time.time() - start
|
||||||
variance = ((eta - self.response_time) / self.response_time) * 100
|
variance = ((eta - self.response_time) / self.response_time) * 100
|
||||||
|
|
||||||
|
|
@ -408,6 +412,17 @@ class Worker:
|
||||||
self.benchmarked = True
|
self.benchmarked = True
|
||||||
return avg_ipm
|
return avg_ipm
|
||||||
|
|
||||||
|
def interrupt(self):
    """Ask this worker to stop its current job.

    Sends an empty JSON POST to the worker's 'interrupt' endpoint. On an
    HTTP 200 response, sets ``self.interrupted`` to True. If the request
    itself fails (e.g. connection error), the failure is printed and the
    flag is left untouched.
    """
    # NOTE(review): no timeout is passed, so an unresponsive worker can block
    # this call indefinitely — consider adding one once a sensible value is agreed.
    try:
        response = requests.post(
            self.full_url('interrupt'),
            json={},
            verify=self.verify_remotes
        )
    except requests.exceptions.RequestException as e:
        # This usually runs in a bare thread; report cleanly instead of
        # letting the exception escape as an unhandled thread traceback.
        print(f"failed to interrupt worker {self.uuid}: {e}")
        return

    if response.status_code == 200:
        self.interrupted = True
        if cmd_opts.distributed_debug:
            print(f"successfully interrupted worker {self.uuid}")
|
||||||
|
|
||||||
class Job:
|
class Job:
|
||||||
"""
|
"""
|
||||||
|
|
@ -424,6 +439,8 @@ class Job:
|
||||||
self.complementary: bool = False
|
self.complementary: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class World:
|
class World:
|
||||||
"""
|
"""
|
||||||
The frame or "world" which holds all workers (including the local machine).
|
The frame or "world" which holds all workers (including the local machine).
|
||||||
|
|
@ -540,6 +557,17 @@ class World:
|
||||||
worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes)
|
worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes)
|
||||||
self.workers.append(worker)
|
self.workers.append(worker)
|
||||||
|
|
||||||
|
def interrupt_remotes(self):
    """Interrupt every non-master worker, one thread per worker.

    Starts a thread running ``worker.interrupt`` for each worker whose
    ``master`` flag is falsy. The threads are not joined here, so this
    returns without waiting for the requests to finish.

    Returns:
        The list of started threads, so a caller may join them if it needs
        to wait for completion. Callers that ignore the return value are
        unaffected.
    """
    threads: List[Thread] = []

    for worker in self.workers:
        # The master is the local machine; only remote workers are interrupted here.
        if worker.master:
            continue

        t = Thread(target=worker.interrupt, args=())
        t.start()
        threads.append(t)

    return threads
|
||||||
|
|
||||||
|
|
||||||
def benchmark(self):
|
def benchmark(self):
|
||||||
"""
|
"""
|
||||||
Attempts to benchmark all workers a part of the world.
|
Attempts to benchmark all workers a part of the world.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue