add buttons for refreshing checkpoints, running synchronization script. add tab for showing state of workers

pull/2/head
unknown 2023-03-27 01:45:12 -05:00
parent 162b541d4e
commit 6ffef54e36
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
4 changed files with 140 additions and 22 deletions

View File

@ -17,6 +17,9 @@ import copy
from modules.images import save_image from modules.images import save_image
from modules.shared import cmd_opts from modules.shared import cmd_opts
import time import time
from pathlib import Path
import os
import subprocess
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
from scripts.spartan.Worker import Worker from scripts.spartan.Worker import Worker
from modules.shared import opts from modules.shared import opts
@ -53,11 +56,53 @@ class Script(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
with gradio.Box(): # adds padding so our components don't look out of place with gradio.Box(): # adds padding so our components don't look out of place
interrupt_all_btn = gradio.Button(value="Interrupt all remote workers") with gradio.Accordion(label='Distributed', open=False) as main_accordian:
with gradio.Tab('Status') as status_tab:
status = gradio.Textbox(elem_id='status', show_label=False)
status_tab.select(fn=Script.ui_connect_test, inputs=[], outputs=[status])
refresh_status_btn = gradio.Button(value='Refresh')
refresh_status_btn.style(size='sm')
refresh_status_btn.click(Script.ui_connect_test, inputs=[], outputs=[status])
with gradio.Tab('Remote Utils'):
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
refresh_checkpoints_btn.style(full_width=False)
refresh_checkpoints_btn.click(Script.ui_connect_refresh_ckpts_btn, inputs=[], outputs=[])
sync_models_btn = gradio.Button(value='Synchronize models')
sync_models_btn.style(full_width=False)
sync_models_btn.click(Script.user_sync_script, inputs=[], outputs=[])
interrupt_all_btn = gradio.Button(value='Interrupt all', variant='stop')
interrupt_all_btn.style(full_width=False) interrupt_all_btn.style(full_width=False)
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[]) interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
return [interrupt_all_btn] return
@staticmethod
def user_sync_script():
    """Run the user-supplied model synchronization script, if one exists.

    Looks in the ``user`` directory next to this file for a regular file whose name
    starts with ``sync``. A ``.ps1`` script is run through ``powershell``; any other
    script is run with the interpreter named on its ``#!`` line.

    Returns:
        True if a script was found and launched, False otherwise (no ``user``
        directory, no matching file, or no usable shebang line).
    """
    user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user')
    if not user_scripts.is_dir():
        return False

    # iterdir() yields entries in arbitrary order, so sort to make the pick deterministic
    candidates = sorted(
        entry for entry in user_scripts.iterdir()
        if entry.is_file() and entry.name.startswith('sync')
    )
    if not candidates:
        return False
    user_script = candidates[-1]

    if user_script.suffix[1:] == 'ps1':
        subprocess.call(['powershell', str(user_script)])
        return True

    # only the first line is needed; the context manager closes the handle right away
    with open(user_script, 'r') as script:
        first_line = script.readline().strip()
    if not first_line.startswith('#!'):
        return False

    # a shebang may carry arguments (e.g. "#!/usr/bin/env bash"), so split it into argv
    interpreter = first_line[2:].split()
    if not interpreter:
        return False

    subprocess.call(interpreter + [str(user_script)])
    return True
# World is not constructed until the first generation job, so I use an intermediary call # World is not constructed until the first generation job, so I use an intermediary call
@staticmethod @staticmethod
@ -67,11 +112,38 @@ class Script(scripts.Script):
except AttributeError: except AttributeError:
print("Nothing to interrupt, Distributed system not initialized") print("Nothing to interrupt, Distributed system not initialized")
@staticmethod
def ui_connect_refresh_ckpts_btn():
    """Click handler for the 'Refresh checkpoints' button.

    Forwards to the world's refresh_checkpoints(). An AttributeError (e.g. the world
    has not been set up yet) is reported on stdout instead of being raised.
    """
    try:
        world = Script.world
        world.refresh_checkpoints()
    except AttributeError:
        print("Distributed system not initialized")
@staticmethod
def ui_connect_test():
    """Build the text shown in the 'Status' tab.

    Returns:
        One line per non-master worker in the form
        ``"<uuid> at <address> is <STATE>\\n"``, concatenated into a single string.
    """
    # init system if it isn't already; no generation request exists yet, so there is
    # no payload to hand over
    if getattr(Script, 'world', None) is None:
        Script.initialize(initial_payload=None)

    return ''.join(
        f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
        for worker in Script.world.workers
        if not worker.master
    )
@staticmethod @staticmethod
def add_to_gallery(processed, p): def add_to_gallery(processed, p):
"""adds generated images to the image gallery after waiting for all workers to finish""" """adds generated images to the image gallery after waiting for all workers to finish"""
# get master ipm by estimating based on worker speed # get master ipm by estimating based on worker speed
global worker
master_elapsed = time.time() - Script.master_start master_elapsed = time.time() - Script.master_start
print(f"Took master {master_elapsed}s") print(f"Took master {master_elapsed}s")
@ -148,26 +220,37 @@ class Script(scripts.Script):
Script.unregister_callbacks() Script.unregister_callbacks()
return return
def run(self, p, *args): @staticmethod
if cmd_opts.distributed_remotes is None: def initialize(initial_payload):
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)") if Script.verify_remotes is False:
Script.world = World(initial_payload=p, verify_remotes=Script.verify_remotes)
# add workers to the world
for worker in cmd_opts.distributed_remotes:
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
# register gallery callback
script_callbacks.on_after_image_processed(Script.add_to_gallery)
if self.verify_remotes is False:
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates") print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
try: try:
Script.world.initialize(p.batch_size) batch_size = initial_payload.batch_size
except AttributeError:
batch_size = 1
Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)
# add workers to the world
for worker in cmd_opts.distributed_remotes:
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
try:
Script.world.initialize(batch_size)
print("World initialized!") print("World initialized!")
except WorldAlreadyInitialized: except WorldAlreadyInitialized:
Script.world.update_world(p.batch_size) Script.world.update_world(total_batch_size=batch_size)
def run(self, p, *args):
if cmd_opts.distributed_remotes is None:
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
# register gallery callback
script_callbacks.on_after_image_processed(Script.add_to_gallery)
Script.initialize(initial_payload=p)
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works # encapsulating the request object within a txt2imgreq object is deprecated and no longer works
# see test/basic_features/txt2img_test.py for an example # see test/basic_features/txt2img_test.py for an example

View File

@ -9,6 +9,7 @@ from webui import server_name
from modules.shared import cmd_opts from modules.shared import cmd_opts
import gradio as gr import gradio as gr
from scripts.spartan.shared import benchmark_payload from scripts.spartan.shared import benchmark_payload
from enum import Enum
class InvalidWorkerResponse(Exception): class InvalidWorkerResponse(Exception):
@ -18,6 +19,12 @@ class InvalidWorkerResponse(Exception):
pass pass
class State(Enum):
    """What a worker is currently doing; stored in ``Worker.state``."""

    # not processing anything (the value a freshly constructed worker starts with)
    IDLE = 1
    # a generation request or benchmark run is in flight
    WORKING = 2
    # the worker acknowledged an interrupt request
    INTERRUPTED = 3
class Worker: class Worker:
""" """
This class represents a worker node in a distributed computing setup. This class represents a worker node in a distributed computing setup.
@ -53,7 +60,7 @@ class Worker:
response: requests.Response = None response: requests.Response = None
loaded_model: str = None loaded_model: str = None
loaded_vae: str = None loaded_vae: str = None
interrupted: bool = False state: State = None
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A. # Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
# We compare to euler a because that is what we currently benchmark each node with. # We compare to euler a because that is what we currently benchmark each node with.
@ -100,6 +107,7 @@ class Worker:
self.response_time = None self.response_time = None
self.loaded_model = '' self.loaded_model = ''
self.loaded_vae = '' self.loaded_vae = ''
self.state = State.IDLE
if uuid is not None: if uuid is not None:
self.uuid = uuid self.uuid = uuid
@ -251,6 +259,8 @@ class Worker:
# TODO detect remote out of memory exception and restart or garbage collect instance using api? # TODO detect remote out of memory exception and restart or garbage collect instance using api?
try: try:
self.state = State.WORKING
# query memory available on worker and store for future reference # query memory available on worker and store for future reference
if self.queried is False: if self.queried is False:
self.queried = True self.queried = True
@ -290,7 +300,7 @@ class Worker:
self.response = response.json() self.response = response.json()
# update list of ETA accuracy # update list of ETA accuracy
if self.benchmarked and not self.interrupted: if self.benchmarked and not self.state == State.INTERRUPTED:
self.response_time = time.time() - start self.response_time = time.time() - start
variance = ((eta - self.response_time) / self.response_time) * 100 variance = ((eta - self.response_time) / self.response_time) * 100
@ -312,6 +322,7 @@ class Worker:
except requests.exceptions.ConnectTimeout: except requests.exceptions.ConnectTimeout:
print(f"\nTimed out waiting for worker '{self.uuid}' at {self}") print(f"\nTimed out waiting for worker '{self.uuid}' at {self}")
self.state = State.IDLE
return return
def benchmark(self) -> int: def benchmark(self) -> int:
@ -342,6 +353,7 @@ class Worker:
results: List[float] = [] results: List[float] = []
# it seems to be lower for the first couple of generations # it seems to be lower for the first couple of generations
# TODO look into how and why this "warmup" happens # TODO look into how and why this "warmup" happens
self.state = State.WORKING
for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up" for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up"
t = Thread(target=self.request, args=(benchmark_payload, None, False,)) t = Thread(target=self.request, args=(benchmark_payload, None, False,))
try: # if the worker is unreachable/offline then handle that here try: # if the worker is unreachable/offline then handle that here
@ -372,8 +384,22 @@ class Worker:
# noinspection PyTypeChecker # noinspection PyTypeChecker
self.response = None self.response = None
self.benchmarked = True self.benchmarked = True
self.state = State.IDLE
return avg_ipm return avg_ipm
def refresh_checkpoints(self):
    """Post an empty JSON body to this worker's ``refresh-checkpoints`` endpoint.

    The worker's state is deliberately left alone: refreshing checkpoints is not an
    interruption, so it must not be reported as one. Connection problems are printed
    rather than raised.
    """
    try:
        response = requests.post(
            self.full_url('refresh-checkpoints'),
            json={},
            verify=self.verify_remotes
        )
    except requests.exceptions.RequestException as e:
        print(f"failed to refresh checkpoints for worker '{self.uuid}': {e}")
        return

    if response.status_code == 200:
        if cmd_opts.distributed_debug:
            print(f"successfully refreshed checkpoints for worker '{self.uuid}'")
    else:
        print(f"worker '{self.uuid}' answered checkpoint refresh with HTTP {response.status_code}")
def interrupt(self): def interrupt(self):
response = requests.post( response = requests.post(
self.full_url('interrupt'), self.full_url('interrupt'),
@ -382,6 +408,6 @@ class Worker:
) )
if response.status_code == 200: if response.status_code == 200:
self.interrupted = True self.state = State.INTERRUPTED
if cmd_opts.distributed_debug: if cmd_opts.distributed_debug:
print(f"successfully interrupted worker {self.uuid}") print(f"successfully interrupted worker {self.uuid}")

View File

@ -178,6 +178,15 @@ class World:
t = Thread(target=worker.interrupt, args=()) t = Thread(target=worker.interrupt, args=())
t.start() t.start()
def refresh_checkpoints(self):
    """Tell every non-master worker to refresh its checkpoints.

    One thread is started per worker and none are joined, so this returns without
    waiting for the refreshes to complete.
    """
    for worker in self.workers:
        if worker.master:
            continue

        Thread(target=worker.refresh_checkpoints, args=()).start()
def benchmark(self): def benchmark(self):
""" """

View File