add buttons for refreshing checkpoints, running synchronization script. add tab for showing state of workers
parent
162b541d4e
commit
6ffef54e36
|
|
@ -17,6 +17,9 @@ import copy
|
|||
from modules.images import save_image
|
||||
from modules.shared import cmd_opts
|
||||
import time
|
||||
from pathlib import Path
|
||||
import os
|
||||
import subprocess
|
||||
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
|
||||
from scripts.spartan.Worker import Worker
|
||||
from modules.shared import opts
|
||||
|
|
@ -53,11 +56,53 @@ class Script(scripts.Script):
|
|||
def ui(self, is_img2img):
|
||||
|
||||
with gradio.Box(): # adds padding so our components don't look out of place
|
||||
interrupt_all_btn = gradio.Button(value="Interrupt all remote workers")
|
||||
interrupt_all_btn.style(full_width=False)
|
||||
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
|
||||
with gradio.Accordion(label='Distributed', open=False) as main_accordian:
|
||||
|
||||
return [interrupt_all_btn]
|
||||
with gradio.Tab('Status') as status_tab:
|
||||
status = gradio.Textbox(elem_id='status', show_label=False)
|
||||
status_tab.select(fn=Script.ui_connect_test, inputs=[], outputs=[status])
|
||||
|
||||
refresh_status_btn = gradio.Button(value='Refresh')
|
||||
refresh_status_btn.style(size='sm')
|
||||
refresh_status_btn.click(Script.ui_connect_test, inputs=[], outputs=[status])
|
||||
|
||||
with gradio.Tab('Remote Utils'):
|
||||
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
|
||||
refresh_checkpoints_btn.style(full_width=False)
|
||||
refresh_checkpoints_btn.click(Script.ui_connect_refresh_ckpts_btn, inputs=[], outputs=[])
|
||||
|
||||
sync_models_btn = gradio.Button(value='Synchronize models')
|
||||
sync_models_btn.style(full_width=False)
|
||||
sync_models_btn.click(Script.user_sync_script, inputs=[], outputs=[])
|
||||
|
||||
interrupt_all_btn = gradio.Button(value='Interrupt all', variant='stop')
|
||||
interrupt_all_btn.style(full_width=False)
|
||||
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
def user_sync_script():
    """Run the user-supplied synchronization script, if one exists.

    Looks in the ``user`` directory next to this file for a regular file whose
    name starts with ``sync``. A ``.ps1`` script is handed to ``powershell``;
    any other script is launched with the interpreter named on its ``#!`` line.

    Returns:
        True if a script was launched (its exit status is not inspected),
        False if no usable script was found.
    """
    user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user')
    if not user_scripts.is_dir():
        return False

    # sorted() makes the choice deterministic when several files match
    candidates = sorted(
        entry for entry in user_scripts.iterdir()
        if entry.is_file() and entry.name.startswith('sync')
    )
    if not candidates:
        return False
    user_script = candidates[0]

    if user_script.suffix[1:] == 'ps1':
        subprocess.call(['powershell', str(user_script)])
        return True

    # only the first line is needed; the context manager closes the handle
    with open(user_script, 'r', encoding='utf-8', errors='replace') as f:
        first_line = f.readline().strip()

    if not first_line.startswith('#!'):
        return False

    # a shebang may carry arguments, e.g. "#!/usr/bin/env bash"
    interpreter = first_line[2:].split()
    if not interpreter:
        return False

    subprocess.call([*interpreter, str(user_script)])
    return True
|
||||
|
||||
# World is not constructed until the first generation job, so I use an intermediary call
|
||||
@staticmethod
|
||||
|
|
@ -67,11 +112,38 @@ class Script(scripts.Script):
|
|||
except AttributeError:
|
||||
print("Nothing to interrupt, Distributed system not initialized")
|
||||
|
||||
@staticmethod
def ui_connect_refresh_ckpts_btn():
    """Click handler that forwards to ``Script.world.refresh_checkpoints()``.

    An AttributeError raised along the way (for example when no usable world
    is available yet) is reported on stdout instead of propagating.
    """
    try:
        world = Script.world
        world.refresh_checkpoints()
    except AttributeError:
        print("Distributed system not initialized")
|
||||
|
||||
@staticmethod
def ui_connect_test():
    """Build the status text: one line per non-master worker.

    If no world is available yet, it is initialized once with no payload and
    the status is then built from that world.

    Returns:
        A string made of "<uuid> at <address> is <STATE>" lines, each ending
        in a newline; empty when there are no non-master workers.
    """
    if getattr(Script, 'world', None) is None:
        # no payload is available here; per the original note, the batch size
        # is overwritten once an actual request is made
        Script.initialize(initial_payload=None)

    return ''.join(
        f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
        for worker in Script.world.workers
        if not worker.master
    )
|
||||
|
||||
|
||||
@staticmethod
|
||||
def add_to_gallery(processed, p):
|
||||
"""adds generated images to the image gallery after waiting for all workers to finish"""
|
||||
|
||||
# get master ipm by estimating based on worker speed
|
||||
global worker
|
||||
master_elapsed = time.time() - Script.master_start
|
||||
print(f"Took master {master_elapsed}s")
|
||||
|
||||
|
|
@ -148,26 +220,37 @@ class Script(scripts.Script):
|
|||
Script.unregister_callbacks()
|
||||
return
|
||||
|
||||
def run(self, p, *args):
|
||||
if cmd_opts.distributed_remotes is None:
|
||||
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
|
||||
|
||||
Script.world = World(initial_payload=p, verify_remotes=Script.verify_remotes)
|
||||
# add workers to the world
|
||||
for worker in cmd_opts.distributed_remotes:
|
||||
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
|
||||
# register gallery callback
|
||||
script_callbacks.on_after_image_processed(Script.add_to_gallery)
|
||||
|
||||
if self.verify_remotes is False:
|
||||
@staticmethod
|
||||
def initialize(initial_payload):
|
||||
if Script.verify_remotes is False:
|
||||
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
try:
|
||||
Script.world.initialize(p.batch_size)
|
||||
batch_size = initial_payload.batch_size
|
||||
except AttributeError:
|
||||
batch_size = 1
|
||||
|
||||
Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)
|
||||
|
||||
# add workers to the world
|
||||
for worker in cmd_opts.distributed_remotes:
|
||||
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
|
||||
|
||||
try:
|
||||
Script.world.initialize(batch_size)
|
||||
print("World initialized!")
|
||||
except WorldAlreadyInitialized:
|
||||
Script.world.update_world(p.batch_size)
|
||||
Script.world.update_world(total_batch_size=batch_size)
|
||||
|
||||
def run(self, p, *args):
|
||||
if cmd_opts.distributed_remotes is None:
|
||||
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
|
||||
|
||||
# register gallery callback
|
||||
script_callbacks.on_after_image_processed(Script.add_to_gallery)
|
||||
|
||||
Script.initialize(initial_payload=p)
|
||||
|
||||
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works
|
||||
# see test/basic_features/txt2img_test.py for an example
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from webui import server_name
|
|||
from modules.shared import cmd_opts
|
||||
import gradio as gr
|
||||
from scripts.spartan.shared import benchmark_payload
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class InvalidWorkerResponse(Exception):
|
||||
|
|
@ -18,6 +19,12 @@ class InvalidWorkerResponse(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class State(Enum):
    """Activity status of a worker, stored in its ``state`` attribute."""

    IDLE = 1  # not busy; assigned during worker setup and when a benchmark ends
    WORKING = 2  # a request or benchmark is in progress
    INTERRUPTED = 3  # the worker acknowledged an interrupt request
|
||||
|
||||
|
||||
class Worker:
|
||||
"""
|
||||
This class represents a worker node in a distributed computing setup.
|
||||
|
|
@ -53,7 +60,7 @@ class Worker:
|
|||
response: requests.Response = None
|
||||
loaded_model: str = None
|
||||
loaded_vae: str = None
|
||||
interrupted: bool = False
|
||||
state: State = None
|
||||
|
||||
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
|
||||
# We compare to euler a because that is what we currently benchmark each node with.
|
||||
|
|
@ -100,6 +107,7 @@ class Worker:
|
|||
self.response_time = None
|
||||
self.loaded_model = ''
|
||||
self.loaded_vae = ''
|
||||
self.state = State.IDLE
|
||||
|
||||
if uuid is not None:
|
||||
self.uuid = uuid
|
||||
|
|
@ -251,6 +259,8 @@ class Worker:
|
|||
|
||||
# TODO detect remote out of memory exception and restart or garbage collect instance using api?
|
||||
try:
|
||||
self.state = State.WORKING
|
||||
|
||||
# query memory available on worker and store for future reference
|
||||
if self.queried is False:
|
||||
self.queried = True
|
||||
|
|
@ -290,7 +300,7 @@ class Worker:
|
|||
self.response = response.json()
|
||||
|
||||
# update list of ETA accuracy
|
||||
if self.benchmarked and not self.interrupted:
|
||||
if self.benchmarked and not self.state == State.INTERRUPTED:
|
||||
self.response_time = time.time() - start
|
||||
variance = ((eta - self.response_time) / self.response_time) * 100
|
||||
|
||||
|
|
@ -312,6 +322,7 @@ class Worker:
|
|||
except requests.exceptions.ConnectTimeout:
|
||||
print(f"\nTimed out waiting for worker '{self.uuid}' at {self}")
|
||||
|
||||
self.state = State.IDLE
|
||||
return
|
||||
|
||||
def benchmark(self) -> int:
|
||||
|
|
@ -342,6 +353,7 @@ class Worker:
|
|||
results: List[float] = []
|
||||
# it seems to be lower for the first couple of generations
|
||||
# TODO look into how and why this "warmup" happens
|
||||
self.state = State.WORKING
|
||||
for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
||||
t = Thread(target=self.request, args=(benchmark_payload, None, False,))
|
||||
try: # if the worker is unreachable/offline then handle that here
|
||||
|
|
@ -372,8 +384,22 @@ class Worker:
|
|||
# noinspection PyTypeChecker
|
||||
self.response = None
|
||||
self.benchmarked = True
|
||||
self.state = State.IDLE
|
||||
return avg_ipm
|
||||
|
||||
def refresh_checkpoints(self):
    """POST to this worker's 'refresh-checkpoints' endpoint.

    The worker's state is deliberately left untouched: a checkpoint refresh
    is not an interruption, so it must not mark the worker as INTERRUPTED.
    """
    response = requests.post(
        self.full_url('refresh-checkpoints'),
        json={},
        verify=self.verify_remotes
    )

    if response.status_code == 200 and cmd_opts.distributed_debug:
        print(f"successfully refreshed checkpoints for worker '{self.uuid}'")
|
||||
|
||||
def interrupt(self):
|
||||
response = requests.post(
|
||||
self.full_url('interrupt'),
|
||||
|
|
@ -382,6 +408,6 @@ class Worker:
|
|||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
self.interrupted = True
|
||||
self.state = State.INTERRUPTED
|
||||
if cmd_opts.distributed_debug:
|
||||
print(f"successfully interrupted worker {self.uuid}")
|
||||
|
|
|
|||
|
|
@ -178,6 +178,15 @@ class World:
|
|||
t = Thread(target=worker.interrupt, args=())
|
||||
t.start()
|
||||
|
||||
def refresh_checkpoints(self):
    """Call ``refresh_checkpoints`` on every non-master worker, one thread each.

    The threads are started but not joined, so this returns without waiting
    for any worker to respond.
    """
    for worker in self.workers:
        if worker.master:
            continue

        Thread(target=worker.refresh_checkpoints, args=()).start()
|
||||
|
||||
def benchmark(self):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue