Add buttons for refreshing checkpoints and running the synchronization script. Add a tab showing the state of workers.

pull/2/head
unknown 2023-03-27 01:45:12 -05:00
parent 162b541d4e
commit 6ffef54e36
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
4 changed files with 140 additions and 22 deletions

View File

@ -17,6 +17,9 @@ import copy
from modules.images import save_image
from modules.shared import cmd_opts
import time
from pathlib import Path
import os
import subprocess
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
from scripts.spartan.Worker import Worker
from modules.shared import opts
@ -53,11 +56,53 @@ class Script(scripts.Script):
# Builds the script's Gradio controls: a 'Status' tab whose textbox is filled by
# Script.ui_connect_test, and a 'Remote Utils' tab with buttons that refresh
# checkpoints, run the user sync script, and interrupt the remote workers.
# NOTE(review): this span reads like a rendered diff whose indentation and +/-
# markers were stripped, so removed and added lines may be interleaved (for
# example the early `return [interrupt_all_btn]` right after the Accordion is
# opened) — confirm against the real file before relying on statement order.
def ui(self, is_img2img):
with gradio.Box(): # adds padding so our components don't look out of place
interrupt_all_btn = gradio.Button(value="Interrupt all remote workers")
interrupt_all_btn.style(full_width=False)
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
with gradio.Accordion(label='Distributed', open=False) as main_accordian:
return [interrupt_all_btn]
with gradio.Tab('Status') as status_tab:
status = gradio.Textbox(elem_id='status', show_label=False)
# selecting the tab fills the textbox with the value returned by ui_connect_test
status_tab.select(fn=Script.ui_connect_test, inputs=[], outputs=[status])
refresh_status_btn = gradio.Button(value='Refresh')
refresh_status_btn.style(size='sm')
# the Refresh button re-runs the same handler as selecting the tab
refresh_status_btn.click(Script.ui_connect_test, inputs=[], outputs=[status])
with gradio.Tab('Remote Utils'):
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
refresh_checkpoints_btn.style(full_width=False)
refresh_checkpoints_btn.click(Script.ui_connect_refresh_ckpts_btn, inputs=[], outputs=[])
sync_models_btn = gradio.Button(value='Synchronize models')
sync_models_btn.style(full_width=False)
# runs the user-provided sync script; its boolean result is not shown (outputs=[])
sync_models_btn.click(Script.user_sync_script, inputs=[], outputs=[])
interrupt_all_btn = gradio.Button(value='Interrupt all', variant='stop')
interrupt_all_btn.style(full_width=False)
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
return
@staticmethod
def user_sync_script():
    """Run the user's model synchronization script, if one exists.

    Looks in the ``user`` directory next to this file for a regular file whose
    name starts with ``sync`` (the last match yielded by ``iterdir`` wins).
    A ``.ps1`` file is run through ``powershell``; any other file is run with
    the interpreter named on its ``#!`` line, including any interpreter
    arguments (e.g. ``#!/usr/bin/env bash``).

    Returns:
        True if a script was launched, False otherwise (no ``user`` directory,
        no matching file, or no usable shebang line).
    """
    import shlex  # local import: only needed to split the shebang line

    user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user')
    if not user_scripts.is_dir():
        return False

    user_script = None
    for file in user_scripts.iterdir():
        if file.is_file() and file.name.startswith('sync'):
            user_script = file
    if user_script is None:
        # nothing to run; previously this path hit an unbound local
        return False

    if user_script.suffix[1:] == 'ps1':
        subprocess.call(['powershell', user_script])
        return True

    # only the first line is needed; close the handle before launching anything
    with open(user_script, 'r') as f:
        first_line = f.readline().strip()
    if not first_line.startswith('#!'):
        return False

    # a shebang may carry arguments, so split it into separate argv entries;
    # non-POSIX splitting keeps backslashes in Windows paths intact
    interpreter = shlex.split(first_line[2:], posix=(os.name != 'nt'))
    if not interpreter:
        return False

    subprocess.call([*interpreter, user_script])
    return True
# World is not constructed until the first generation job, so I use an intermediary call
@staticmethod
@ -67,11 +112,38 @@ class Script(scripts.Script):
except AttributeError:
print("Nothing to interrupt, Distributed system not initialized")
@staticmethod
def ui_connect_refresh_ckpts_btn():
    """Click handler that asks the world to refresh checkpoints.

    An AttributeError (for example when Script.world has not been created yet)
    is reported with a printed notice instead of being raised.
    """
    try:
        world = Script.world
        world.refresh_checkpoints()
    except AttributeError:
        print("Distributed system not initialized")
@staticmethod
def ui_connect_test():
    """Return one status line per non-master worker for the 'Status' textbox.

    If Script.world does not exist yet, the distributed system is initialized
    first and the status is then built from the freshly created world.

    Returns:
        A string with a "<uuid> at <address> is <state name>" line for every
        worker that is not the master (empty when there are none).
    """
    try:
        workers = Script.world.workers
    except AttributeError:
        # init system if it isn't already
        # batch size will be clobbered later once an actual request is made anyway,
        # so no payload is passed here
        Script.initialize(initial_payload=None)
        # read the world again rather than recursing: the result must reach the
        # caller, and a failed initialization should raise instead of looping
        workers = Script.world.workers

    return ''.join(
        f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
        for worker in workers
        if not worker.master
    )
@staticmethod
def add_to_gallery(processed, p):
"""adds generated images to the image gallery after waiting for all workers to finish"""
# get master ipm by estimating based on worker speed
global worker
master_elapsed = time.time() - Script.master_start
print(f"Took master {master_elapsed}s")
@ -148,26 +220,37 @@ class Script(scripts.Script):
Script.unregister_callbacks()
return
def run(self, p, *args):
if cmd_opts.distributed_remotes is None:
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
Script.world = World(initial_payload=p, verify_remotes=Script.verify_remotes)
# add workers to the world
for worker in cmd_opts.distributed_remotes:
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
# register gallery callback
script_callbacks.on_after_image_processed(Script.add_to_gallery)
if self.verify_remotes is False:
# Creates Script.world from the given payload, registers every remote listed in
# cmd_opts.distributed_remotes as a worker, then initializes the world (or
# updates it when World.initialize raises WorldAlreadyInitialized).
# initial_payload may be None; any AttributeError while reading its batch_size
# falls back to a batch size of 1.
# NOTE(review): this span reads like a rendered diff with +/- markers stripped.
# The two statements that reference `p` (which is not a parameter of this
# function) look like removed lines sitting next to their replacements —
# confirm against the real file.
@staticmethod
def initialize(initial_payload):
if Script.verify_remotes is False:
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
# silence urllib3's InsecureRequestWarning, since verification was disabled on purpose
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
try:
Script.world.initialize(p.batch_size)
batch_size = initial_payload.batch_size
except AttributeError:
batch_size = 1
Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)
# add workers to the world
for worker in cmd_opts.distributed_remotes:
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
try:
Script.world.initialize(batch_size)
print("World initialized!")
except WorldAlreadyInitialized:
Script.world.update_world(p.batch_size)
Script.world.update_world(total_batch_size=batch_size)
def run(self, p, *args):
if cmd_opts.distributed_remotes is None:
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
# register gallery callback
script_callbacks.on_after_image_processed(Script.add_to_gallery)
Script.initialize(initial_payload=p)
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works
# see test/basic_features/txt2img_test.py for an example

View File

@ -9,6 +9,7 @@ from webui import server_name
from modules.shared import cmd_opts
import gradio as gr
from scripts.spartan.shared import benchmark_payload
from enum import Enum
class InvalidWorkerResponse(Exception):
@ -18,6 +19,12 @@ class InvalidWorkerResponse(Exception):
pass
# Activity state of a Worker, stored in Worker.state.
class State(Enum):
# assigned when a Worker is constructed and again after a request or benchmark finishes
IDLE = 1
# assigned at the start of a request and at the start of a benchmark run
WORKING = 2
# assigned after the worker's 'interrupt' endpoint answers with HTTP 200
INTERRUPTED = 3
class Worker:
"""
This class represents a worker node in a distributed computing setup.
@ -53,7 +60,7 @@ class Worker:
response: requests.Response = None
loaded_model: str = None
loaded_vae: str = None
interrupted: bool = False
state: State = None
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
# We compare to euler a because that is what we currently benchmark each node with.
@ -100,6 +107,7 @@ class Worker:
self.response_time = None
self.loaded_model = ''
self.loaded_vae = ''
self.state = State.IDLE
if uuid is not None:
self.uuid = uuid
@ -251,6 +259,8 @@ class Worker:
# TODO detect remote out of memory exception and restart or garbage collect instance using api?
try:
self.state = State.WORKING
# query memory available on worker and store for future reference
if self.queried is False:
self.queried = True
@ -290,7 +300,7 @@ class Worker:
self.response = response.json()
# update list of ETA accuracy
if self.benchmarked and not self.interrupted:
if self.benchmarked and not self.state == State.INTERRUPTED:
self.response_time = time.time() - start
variance = ((eta - self.response_time) / self.response_time) * 100
@ -312,6 +322,7 @@ class Worker:
except requests.exceptions.ConnectTimeout:
print(f"\nTimed out waiting for worker '{self.uuid}' at {self}")
self.state = State.IDLE
return
def benchmark(self) -> int:
@ -342,6 +353,7 @@ class Worker:
results: List[float] = []
# it seems to be lower for the first couple of generations
# TODO look into how and why this "warmup" happens
self.state = State.WORKING
for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up"
t = Thread(target=self.request, args=(benchmark_payload, None, False,))
try: # if the worker is unreachable/offline then handle that here
@ -372,8 +384,22 @@ class Worker:
# noinspection PyTypeChecker
self.response = None
self.benchmarked = True
self.state = State.IDLE
return avg_ipm
def refresh_checkpoints(self):
    """Ask this worker to refresh its checkpoints via the 'refresh-checkpoints' endpoint.

    The worker's state is left unchanged: refreshing checkpoints is not an
    interruption, and marking the worker INTERRUPTED here would make later
    state checks (such as the ETA-accuracy update in request) misread it.
    A confirmation is printed on HTTP 200 when cmd_opts.distributed_debug is set.
    """
    response = requests.post(
        self.full_url('refresh-checkpoints'),
        json={},
        verify=self.verify_remotes
    )

    if response.status_code == 200:
        if cmd_opts.distributed_debug:
            print(f"successfully refreshed checkpoints for worker '{self.uuid}'")
def interrupt(self):
response = requests.post(
self.full_url('interrupt'),
@ -382,6 +408,6 @@ class Worker:
)
if response.status_code == 200:
self.interrupted = True
self.state = State.INTERRUPTED
if cmd_opts.distributed_debug:
print(f"successfully interrupted worker {self.uuid}")

View File

@ -178,6 +178,15 @@ class World:
t = Thread(target=worker.interrupt, args=())
t.start()
def refresh_checkpoints(self):
    """Start a checkpoint refresh on every non-master worker.

    Each worker's refresh_checkpoints() is launched on its own thread so the
    requests run concurrently; the threads are started and not joined.
    """
    for worker in self.workers:
        if worker.master:
            continue

        Thread(target=worker.refresh_checkpoints, args=()).start()
def benchmark(self):
"""

View File