add buttons for refreshing checkpoints, running synchronization script. add tab for showing state of workers
parent
162b541d4e
commit
6ffef54e36
|
|
@ -17,6 +17,9 @@ import copy
|
||||||
from modules.images import save_image
|
from modules.images import save_image
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
import time
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
|
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
|
||||||
from scripts.spartan.Worker import Worker
|
from scripts.spartan.Worker import Worker
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
@ -53,11 +56,53 @@ class Script(scripts.Script):
|
||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
|
|
||||||
with gradio.Box(): # adds padding so our components don't look out of place
|
with gradio.Box(): # adds padding so our components don't look out of place
|
||||||
interrupt_all_btn = gradio.Button(value="Interrupt all remote workers")
|
with gradio.Accordion(label='Distributed', open=False) as main_accordian:
|
||||||
interrupt_all_btn.style(full_width=False)
|
|
||||||
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
|
|
||||||
|
|
||||||
return [interrupt_all_btn]
|
with gradio.Tab('Status') as status_tab:
|
||||||
|
status = gradio.Textbox(elem_id='status', show_label=False)
|
||||||
|
status_tab.select(fn=Script.ui_connect_test, inputs=[], outputs=[status])
|
||||||
|
|
||||||
|
refresh_status_btn = gradio.Button(value='Refresh')
|
||||||
|
refresh_status_btn.style(size='sm')
|
||||||
|
refresh_status_btn.click(Script.ui_connect_test, inputs=[], outputs=[status])
|
||||||
|
|
||||||
|
with gradio.Tab('Remote Utils'):
|
||||||
|
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
|
||||||
|
refresh_checkpoints_btn.style(full_width=False)
|
||||||
|
refresh_checkpoints_btn.click(Script.ui_connect_refresh_ckpts_btn, inputs=[], outputs=[])
|
||||||
|
|
||||||
|
sync_models_btn = gradio.Button(value='Synchronize models')
|
||||||
|
sync_models_btn.style(full_width=False)
|
||||||
|
sync_models_btn.click(Script.user_sync_script, inputs=[], outputs=[])
|
||||||
|
|
||||||
|
interrupt_all_btn = gradio.Button(value='Interrupt all', variant='stop')
|
||||||
|
interrupt_all_btn.style(full_width=False)
|
||||||
|
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def user_sync_script():
|
||||||
|
user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user')
|
||||||
|
# user_script = user_scripts.joinpath('example.sh')
|
||||||
|
for file in user_scripts.iterdir():
|
||||||
|
if file.is_file() and file.name.startswith('sync'):
|
||||||
|
user_script = file
|
||||||
|
|
||||||
|
suffix = user_script.suffix[1:]
|
||||||
|
|
||||||
|
if suffix == 'ps1':
|
||||||
|
subprocess.call(['powershell', user_script])
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
f = open(user_script, 'r')
|
||||||
|
first_line = f.readline().strip()
|
||||||
|
if first_line.startswith('#!'):
|
||||||
|
shebang = first_line[2:]
|
||||||
|
subprocess.call([shebang, user_script])
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
# World is not constructed until the first generation job, so I use an intermediary call
|
# World is not constructed until the first generation job, so I use an intermediary call
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -67,11 +112,38 @@ class Script(scripts.Script):
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
print("Nothing to interrupt, Distributed system not initialized")
|
print("Nothing to interrupt, Distributed system not initialized")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ui_connect_refresh_ckpts_btn():
|
||||||
|
try:
|
||||||
|
Script.world.refresh_checkpoints()
|
||||||
|
except AttributeError:
|
||||||
|
print("Distributed system not initialized")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ui_connect_test():
|
||||||
|
try:
|
||||||
|
temp = ''
|
||||||
|
|
||||||
|
for worker in Script.world.workers:
|
||||||
|
if worker.master:
|
||||||
|
continue
|
||||||
|
|
||||||
|
temp += f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
|
||||||
|
|
||||||
|
return temp
|
||||||
|
|
||||||
|
# init system if it isn't already
|
||||||
|
except AttributeError:
|
||||||
|
# batch size will be clobbered later once an actual request is made anyway so I just pass 1
|
||||||
|
Script.initialize(initial_payload=None)
|
||||||
|
Script.ui_connect_test()
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_to_gallery(processed, p):
|
def add_to_gallery(processed, p):
|
||||||
"""adds generated images to the image gallery after waiting for all workers to finish"""
|
"""adds generated images to the image gallery after waiting for all workers to finish"""
|
||||||
|
|
||||||
# get master ipm by estimating based on worker speed
|
# get master ipm by estimating based on worker speed
|
||||||
global worker
|
|
||||||
master_elapsed = time.time() - Script.master_start
|
master_elapsed = time.time() - Script.master_start
|
||||||
print(f"Took master {master_elapsed}s")
|
print(f"Took master {master_elapsed}s")
|
||||||
|
|
||||||
|
|
@ -148,26 +220,37 @@ class Script(scripts.Script):
|
||||||
Script.unregister_callbacks()
|
Script.unregister_callbacks()
|
||||||
return
|
return
|
||||||
|
|
||||||
def run(self, p, *args):
|
@staticmethod
|
||||||
if cmd_opts.distributed_remotes is None:
|
def initialize(initial_payload):
|
||||||
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
|
if Script.verify_remotes is False:
|
||||||
|
|
||||||
Script.world = World(initial_payload=p, verify_remotes=Script.verify_remotes)
|
|
||||||
# add workers to the world
|
|
||||||
for worker in cmd_opts.distributed_remotes:
|
|
||||||
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
|
|
||||||
# register gallery callback
|
|
||||||
script_callbacks.on_after_image_processed(Script.add_to_gallery)
|
|
||||||
|
|
||||||
if self.verify_remotes is False:
|
|
||||||
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
|
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
|
||||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
Script.world.initialize(p.batch_size)
|
batch_size = initial_payload.batch_size
|
||||||
|
except AttributeError:
|
||||||
|
batch_size = 1
|
||||||
|
|
||||||
|
Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)
|
||||||
|
|
||||||
|
# add workers to the world
|
||||||
|
for worker in cmd_opts.distributed_remotes:
|
||||||
|
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
|
||||||
|
|
||||||
|
try:
|
||||||
|
Script.world.initialize(batch_size)
|
||||||
print("World initialized!")
|
print("World initialized!")
|
||||||
except WorldAlreadyInitialized:
|
except WorldAlreadyInitialized:
|
||||||
Script.world.update_world(p.batch_size)
|
Script.world.update_world(total_batch_size=batch_size)
|
||||||
|
|
||||||
|
def run(self, p, *args):
|
||||||
|
if cmd_opts.distributed_remotes is None:
|
||||||
|
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
|
||||||
|
|
||||||
|
# register gallery callback
|
||||||
|
script_callbacks.on_after_image_processed(Script.add_to_gallery)
|
||||||
|
|
||||||
|
Script.initialize(initial_payload=p)
|
||||||
|
|
||||||
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works
|
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works
|
||||||
# see test/basic_features/txt2img_test.py for an example
|
# see test/basic_features/txt2img_test.py for an example
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from webui import server_name
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from scripts.spartan.shared import benchmark_payload
|
from scripts.spartan.shared import benchmark_payload
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class InvalidWorkerResponse(Exception):
|
class InvalidWorkerResponse(Exception):
|
||||||
|
|
@ -18,6 +19,12 @@ class InvalidWorkerResponse(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class State(Enum):
|
||||||
|
IDLE = 1
|
||||||
|
WORKING = 2
|
||||||
|
INTERRUPTED = 3
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
class Worker:
|
||||||
"""
|
"""
|
||||||
This class represents a worker node in a distributed computing setup.
|
This class represents a worker node in a distributed computing setup.
|
||||||
|
|
@ -53,7 +60,7 @@ class Worker:
|
||||||
response: requests.Response = None
|
response: requests.Response = None
|
||||||
loaded_model: str = None
|
loaded_model: str = None
|
||||||
loaded_vae: str = None
|
loaded_vae: str = None
|
||||||
interrupted: bool = False
|
state: State = None
|
||||||
|
|
||||||
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
|
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
|
||||||
# We compare to euler a because that is what we currently benchmark each node with.
|
# We compare to euler a because that is what we currently benchmark each node with.
|
||||||
|
|
@ -100,6 +107,7 @@ class Worker:
|
||||||
self.response_time = None
|
self.response_time = None
|
||||||
self.loaded_model = ''
|
self.loaded_model = ''
|
||||||
self.loaded_vae = ''
|
self.loaded_vae = ''
|
||||||
|
self.state = State.IDLE
|
||||||
|
|
||||||
if uuid is not None:
|
if uuid is not None:
|
||||||
self.uuid = uuid
|
self.uuid = uuid
|
||||||
|
|
@ -251,6 +259,8 @@ class Worker:
|
||||||
|
|
||||||
# TODO detect remote out of memory exception and restart or garbage collect instance using api?
|
# TODO detect remote out of memory exception and restart or garbage collect instance using api?
|
||||||
try:
|
try:
|
||||||
|
self.state = State.WORKING
|
||||||
|
|
||||||
# query memory available on worker and store for future reference
|
# query memory available on worker and store for future reference
|
||||||
if self.queried is False:
|
if self.queried is False:
|
||||||
self.queried = True
|
self.queried = True
|
||||||
|
|
@ -290,7 +300,7 @@ class Worker:
|
||||||
self.response = response.json()
|
self.response = response.json()
|
||||||
|
|
||||||
# update list of ETA accuracy
|
# update list of ETA accuracy
|
||||||
if self.benchmarked and not self.interrupted:
|
if self.benchmarked and not self.state == State.INTERRUPTED:
|
||||||
self.response_time = time.time() - start
|
self.response_time = time.time() - start
|
||||||
variance = ((eta - self.response_time) / self.response_time) * 100
|
variance = ((eta - self.response_time) / self.response_time) * 100
|
||||||
|
|
||||||
|
|
@ -312,6 +322,7 @@ class Worker:
|
||||||
except requests.exceptions.ConnectTimeout:
|
except requests.exceptions.ConnectTimeout:
|
||||||
print(f"\nTimed out waiting for worker '{self.uuid}' at {self}")
|
print(f"\nTimed out waiting for worker '{self.uuid}' at {self}")
|
||||||
|
|
||||||
|
self.state = State.IDLE
|
||||||
return
|
return
|
||||||
|
|
||||||
def benchmark(self) -> int:
|
def benchmark(self) -> int:
|
||||||
|
|
@ -342,6 +353,7 @@ class Worker:
|
||||||
results: List[float] = []
|
results: List[float] = []
|
||||||
# it's seems to be lower for the first couple of generations
|
# it's seems to be lower for the first couple of generations
|
||||||
# TODO look into how and why this "warmup" happens
|
# TODO look into how and why this "warmup" happens
|
||||||
|
self.state = State.WORKING
|
||||||
for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up"
|
||||||
t = Thread(target=self.request, args=(benchmark_payload, None, False,))
|
t = Thread(target=self.request, args=(benchmark_payload, None, False,))
|
||||||
try: # if the worker is unreachable/offline then handle that here
|
try: # if the worker is unreachable/offline then handle that here
|
||||||
|
|
@ -372,8 +384,22 @@ class Worker:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
self.response = None
|
self.response = None
|
||||||
self.benchmarked = True
|
self.benchmarked = True
|
||||||
|
self.state = State.IDLE
|
||||||
return avg_ipm
|
return avg_ipm
|
||||||
|
|
||||||
|
def refresh_checkpoints(self):
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.full_url('refresh-checkpoints'),
|
||||||
|
json={},
|
||||||
|
verify=self.verify_remotes
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
self.state = State.INTERRUPTED
|
||||||
|
if cmd_opts.distributed_debug:
|
||||||
|
print(f"successfully refreshed checkpoints for worker '{self.uuid}'")
|
||||||
|
|
||||||
def interrupt(self):
|
def interrupt(self):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.full_url('interrupt'),
|
self.full_url('interrupt'),
|
||||||
|
|
@ -382,6 +408,6 @@ class Worker:
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
self.interrupted = True
|
self.state = State.INTERRUPTED
|
||||||
if cmd_opts.distributed_debug:
|
if cmd_opts.distributed_debug:
|
||||||
print(f"successfully interrupted worker {self.uuid}")
|
print(f"successfully interrupted worker {self.uuid}")
|
||||||
|
|
|
||||||
|
|
@ -178,6 +178,15 @@ class World:
|
||||||
t = Thread(target=worker.interrupt, args=())
|
t = Thread(target=worker.interrupt, args=())
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
|
def refresh_checkpoints(self):
|
||||||
|
threads: List[Thread] = []
|
||||||
|
|
||||||
|
for worker in self.workers:
|
||||||
|
if worker.master:
|
||||||
|
continue
|
||||||
|
|
||||||
|
t = Thread(target=worker.refresh_checkpoints, args=())
|
||||||
|
t.start()
|
||||||
|
|
||||||
def benchmark(self):
|
def benchmark(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue