Add ability to remove workers from within the UI. Fix regression preventing benchmarking when config exists but recorded ipm values are invalid.

pull/17/head
unknown 2023-07-08 22:01:19 -05:00
parent 80852253c9
commit 3f38d463ac
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
2 changed files with 79 additions and 46 deletions

View File

@ -4,12 +4,15 @@ import subprocess
from pathlib import Path
import gradio
from scripts.spartan.shared import logger, log_level
from scripts.spartan.Worker import State
from scripts.spartan.Worker import State, Worker
from modules.shared import state as webui_state
import json
from typing import List
worker_select_dropdown = None
class UI:
def __init__(self, script, world):
self.script = script
self.world = world
@ -73,23 +76,36 @@ class UI:
logger.debug(f"job timeout is now {job_timeout} seconds")
def save_worker_btn(self, name, address, port, tls):
worker = self.world.add_worker(name, address, port, tls)
self.world.add_worker(name, address, port, tls)
self.world.save_config()
workers_info = {}
with open(self.world.worker_info_path, 'r', encoding='utf-8') as worker_info_file:
try:
workers_info = json.load(worker_info_file)
except json.decoder.JSONDecodeError:
logger.error(f"corrupt or invalid config file... ignoring")
except io.UnsupportedOperation:
pass
# visibly update which workers can be selected
labels = [x.uuid for x in self.selectable_remote_workers()]
return gradio.Dropdown.update(choices=labels)
with open(self.world.worker_info_path, 'w', encoding='utf-8') as worker_info_file:
inf: dict = worker.info()
workers_info[name] = inf[name]
def selectable_remote_workers(self) -> List[Worker]:
remote_workers = []
json.dump(workers_info, worker_info_file, indent=3)
for worker in self.world.get_workers():
if worker.master:
continue
remote_workers.append(worker)
remote_workers = sorted(remote_workers, key=lambda x: x.uuid)
return remote_workers
def remove_worker_btn(self, worker_label):
    """
    Remove the worker whose uuid equals ``worker_label`` and refresh the dropdown.

    The worker is dropped from the in-memory worker list, the configuration is
    re-saved so the removal persists, and the dropdown choices are rebuilt from
    the remaining selectable remote workers.

    :param worker_label: uuid of the worker to remove, as chosen in the dropdown
    :return: a gradio Dropdown update carrying the remaining worker labels
    """
    workers = self.world._workers

    # remove worker from memory
    # Filter in a single pass instead of calling remove() while looping over the
    # same list: removing during iteration shifts the remaining items and makes
    # the loop skip the element that follows each removal.
    # The master worker is always kept: it is never offered as a selectable
    # worker, but the dropdown accepts custom values, so a typed label could
    # otherwise match it.
    # Slice assignment keeps the same list object, so any other reference to the
    # worker list observes the removal too.
    workers[:] = [
        worker for worker in workers
        if worker.master or worker.uuid != worker_label
    ]

    # remove worker from disk
    self.world.save_config()

    # visibly update which workers can be selected
    labels = [x.uuid for x in self.selectable_remote_workers()]
    return gradio.Dropdown.update(choices=labels)
# end handlers
@ -135,16 +151,25 @@ class UI:
clear_queue_btn.click(self.clear_queue_btn)
with gradio.Tab('Worker Config'):
worker_name_field = gradio.Textbox(label='Name')
worker_select_dropdown = None
worker_select_dropdown = gradio.Dropdown(
[x.uuid for x in self.selectable_remote_workers()],
info='Select a pre-existing worker or enter a label for a new one',
label='Label',
allow_custom_value=True
)
worker_address_field = gradio.Textbox(label='Address')
worker_port_field = gradio.Textbox(label='Port', value='7860')
worker_tls_cbx = gradio.Checkbox(
label='connect to worker using https'
)
save_worker_btn = gradio.Button(
value='Add Worker'
)
save_worker_btn.click(self.save_worker_btn, inputs=[worker_name_field, worker_address_field, worker_port_field, worker_tls_cbx])
with gradio.Row():
save_worker_btn = gradio.Button(value='Add/Update Worker')
save_worker_btn.click(self.save_worker_btn, inputs=[worker_select_dropdown, worker_address_field, worker_port_field, worker_tls_cbx], outputs=[worker_select_dropdown])
remove_worker_btn = gradio.Button(value='Remove Worker', variant='stop')
remove_worker_btn.click(self.remove_worker_btn, inputs=worker_select_dropdown, outputs=[worker_select_dropdown])
with gradio.Tab('Settings'):
thin_client_cbx = gradio.Checkbox(

View File

@ -75,7 +75,7 @@ class World:
def __init__(self, initial_payload, verify_remotes: bool = True):
self.master_worker = Worker(master=True)
self.total_batch_size: int = 0
self.__workers: List[Worker] = [self.master_worker]
self._workers: List[Worker] = [self.master_worker]
self.jobs: List[Job] = []
self.job_timeout: int = 6 # seconds
self.initialized: bool = False
@ -149,10 +149,10 @@ class World:
worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes, tls=tls)
for w in self.__workers:
for w in self._workers:
if w.uuid == uuid:
self.__workers.remove(w)
self.__workers.append(worker)
self._workers.remove(w)
self._workers.append(worker)
return worker
@ -220,18 +220,17 @@ class World:
if saved and not rebenchmark:
logger.debug(f"loaded saved configuration: \n{workers_info}")
for worker in self.__workers:
for worker in self._workers:
try:
worker.avg_ipm = workers_info[worker.uuid]['avg_ipm']
if worker.avg_ipm <= 0:
logger.debug(f"{worker.uuid} has recorded ipm of 0... marking as unbenched")
if worker.avg_ipm is None or worker.avg_ipm <= 0:
logger.debug(f"{worker.uuid} recorded ipm is invalid... marking as unbenched")
unbenched_workers.append(worker)
else:
worker.benchmarked = True
except KeyError:
logger.debug(f"worker '{worker.uuid}' not found in workers.json")
unbenched_workers.append(worker)
return
else:
unbenched_workers = self.get_workers()
@ -244,19 +243,13 @@ class World:
# wait for all benchmarks to finish and update stats on newly benchmarked workers
if len(benchmark_threads) > 0:
with open(self.worker_info_path, 'w') as worker_info_file:
for t in benchmark_threads:
t.join()
logger.info("Benchmarking finished")
for t in benchmark_threads:
t.join()
logger.info("Benchmarking finished")
for worker in unbenched_workers:
workers_info.update(worker.info())
workers_info.update({'benchmark_payload': sh.benchmark_payload})
# save benchmark results to workers.json
json.dump(workers_info, worker_info_file, indent=3)
logger.info(self.speed_summary())
# save benchmark results to workers.json
self.save_config()
logger.info(self.speed_summary())
def get_current_output_size(self) -> int:
"""
@ -274,7 +267,7 @@ class World:
"""
Returns string listing workers by their ipm in descending order.
"""
workers_copy = copy.deepcopy(self.__workers)
workers_copy = copy.deepcopy(self._workers)
workers_copy.sort(key=lambda w: w.avg_ipm, reverse=True)
total_ipm = 0
@ -393,7 +386,7 @@ class World:
def get_workers(self):
filtered = []
for worker in self.__workers:
for worker in self._workers:
if worker.avg_ipm is not None and worker.avg_ipm <= 0:
logger.warning(f"config reports invalid speed (0 ipm) for worker '{worker.uuid}', setting default of 1 ipm.\nplease re-benchmark")
worker.avg_ipm = 1
@ -505,7 +498,7 @@ class World:
del self.jobs[last]
last -= 1
def load_config(self):
def config(self) -> json:
if not os.path.exists(self.worker_info_path):
logger.debug(f"Config was not found at '{self.worker_info_path}'")
return
@ -513,16 +506,19 @@ class World:
with open(self.worker_info_path, 'r') as config:
try:
config_json = json.load(config)
return json.load(config)
except json.decoder.JSONDecodeError:
logger.debug(f"config is corrupt or invalid JSON, unable to load")
return
for key in config_json:
def load_config(self):
config = self.config()
if config is not None:
for key in config:
if key == "benchmark_payload" or key == "master":
continue
w = config_json[key]
w = config[key]
try:
worker = self.add_worker(
uuid=key,
@ -536,3 +532,15 @@ class World:
except KeyError:
logger.error(f"invalid configuration in file for worker {key}... ignoring")
continue
def save_config(self):
    """
    Write the current configuration to ``self.worker_info_path`` as JSON.

    The saved document starts with the shared benchmark payload under the
    'benchmark_payload' key and is then extended with the dict returned by
    ``info()`` for every worker currently held in ``self._workers``.
    """
    snapshot = {'benchmark_payload': sh.benchmark_payload}
    for entry in self._workers:
        snapshot.update(entry.info())

    # 'w+' truncates an existing file or creates a missing one before writing
    with open(self.worker_info_path, 'w+') as config_file:
        json.dump(snapshot, config_file, indent=3)
        logger.debug(f"config saved")