diff --git a/scripts/spartan/UI.py b/scripts/spartan/UI.py index 2ef2460..2474354 100644 --- a/scripts/spartan/UI.py +++ b/scripts/spartan/UI.py @@ -4,12 +4,15 @@ import subprocess from pathlib import Path import gradio from scripts.spartan.shared import logger, log_level -from scripts.spartan.Worker import State +from scripts.spartan.Worker import State, Worker from modules.shared import state as webui_state import json +from typing import List +worker_select_dropdown = None class UI: + def __init__(self, script, world): self.script = script self.world = world @@ -73,23 +76,36 @@ class UI: logger.debug(f"job timeout is now {job_timeout} seconds") def save_worker_btn(self, name, address, port, tls): - worker = self.world.add_worker(name, address, port, tls) + self.world.add_worker(name, address, port, tls) + self.world.save_config() - workers_info = {} - with open(self.world.worker_info_path, 'r', encoding='utf-8') as worker_info_file: - try: - workers_info = json.load(worker_info_file) - except json.decoder.JSONDecodeError: - logger.error(f"corrupt or invalid config file... ignoring") - except io.UnsupportedOperation: - pass + # visibly update which workers can be selected + labels = [x.uuid for x in self.selectable_remote_workers()] + return gradio.Dropdown.update(choices=labels) - with open(self.world.worker_info_path, 'w', encoding='utf-8') as worker_info_file: - inf: dict = worker.info() - workers_info[name] = inf[name] + def selectable_remote_workers(self) -> List[Worker]: + remote_workers = [] - json.dump(workers_info, worker_info_file, indent=3) + for worker in self.world.get_workers(): + if worker.master: + continue + remote_workers.append(worker) + remote_workers = sorted(remote_workers, key=lambda x: x.uuid) + return remote_workers + + def remove_worker_btn(self, worker_label): + # remove worker from memory + for worker in self.world._workers: + if worker.uuid == worker_label: + self.world._workers.remove(worker) + + # remove worker from disk + self.world.save_config() + + # visibly update which workers can be selected + labels = [x.uuid for x in self.selectable_remote_workers()] + return gradio.Dropdown.update(choices=labels) # end handlers @@ -135,16 +151,25 @@ class UI: clear_queue_btn.click(self.clear_queue_btn) with gradio.Tab('Worker Config'): - worker_name_field = gradio.Textbox(label='Name') + worker_select_dropdown = None + + worker_select_dropdown = gradio.Dropdown( + [x.uuid for x in self.selectable_remote_workers()], + info='Select a pre-existing worker or enter a label for a new one', + label='Label', + allow_custom_value=True + ) worker_address_field = gradio.Textbox(label='Address') worker_port_field = gradio.Textbox(label='Port', value='7860') worker_tls_cbx = gradio.Checkbox( label='connect to worker using https' ) - save_worker_btn = gradio.Button( - value='Add Worker' - ) - save_worker_btn.click(self.save_worker_btn, inputs=[worker_name_field, worker_address_field, worker_port_field, worker_tls_cbx]) + + with gradio.Row(): + save_worker_btn = gradio.Button(value='Add/Update Worker') + save_worker_btn.click(self.save_worker_btn, inputs=[worker_select_dropdown, worker_address_field, worker_port_field, worker_tls_cbx], outputs=[worker_select_dropdown]) + remove_worker_btn = gradio.Button(value='Remove Worker', variant='stop') + remove_worker_btn.click(self.remove_worker_btn, inputs=worker_select_dropdown, outputs=[worker_select_dropdown]) with gradio.Tab('Settings'): thin_client_cbx = gradio.Checkbox( diff --git a/scripts/spartan/World.py b/scripts/spartan/World.py index 43adcc6..5aafb20 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/World.py @@ -75,7 +75,7 @@ class World: def __init__(self, initial_payload, verify_remotes: bool = True): self.master_worker = Worker(master=True) self.total_batch_size: int = 0 - self.__workers: List[Worker] = [self.master_worker] + self._workers: List[Worker] = [self.master_worker] self.jobs: List[Job] = [] self.job_timeout: int = 6 # seconds self.initialized: bool = False @@ -149,10 +149,10 @@ class World: worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes, tls=tls) - for w in self.__workers: + for w in self._workers: if w.uuid == uuid: - self.__workers.remove(w) - self.__workers.append(worker) + self._workers.remove(w) + self._workers.append(worker) return worker @@ -220,18 +220,17 @@ class World: if saved and not rebenchmark: logger.debug(f"loaded saved configuration: \n{workers_info}") - for worker in self.__workers: + for worker in self._workers: try: worker.avg_ipm = workers_info[worker.uuid]['avg_ipm'] - if worker.avg_ipm <= 0: - logger.debug(f"{worker.uuid} has recorded ipm of 0... marking as unbenched") + if worker.avg_ipm is None or worker.avg_ipm <= 0: + logger.debug(f"{worker.uuid} recorded ipm is invalid... marking as unbenched") unbenched_workers.append(worker) else: worker.benchmarked = True except KeyError: logger.debug(f"worker '{worker.uuid}' not found in workers.json") unbenched_workers.append(worker) - return else: unbenched_workers = self.get_workers() @@ -244,19 +243,13 @@ class World: # wait for all benchmarks to finish and update stats on newly benchmarked workers if len(benchmark_threads) > 0: - with open(self.worker_info_path, 'w') as worker_info_file: - for t in benchmark_threads: - t.join() - logger.info("Benchmarking finished") + for t in benchmark_threads: + t.join() + logger.info("Benchmarking finished") - for worker in unbenched_workers: - workers_info.update(worker.info()) - workers_info.update({'benchmark_payload': sh.benchmark_payload}) - - # save benchmark results to workers.json - json.dump(workers_info, worker_info_file, indent=3) - - logger.info(self.speed_summary()) + # save benchmark results to workers.json + self.save_config() + logger.info(self.speed_summary()) def get_current_output_size(self) -> int: """ @@ -274,7 +267,7 @@ class World: """ Returns string listing workers by their ipm in descending order. """ - workers_copy = copy.deepcopy(self.__workers) + workers_copy = copy.deepcopy(self._workers) workers_copy.sort(key=lambda w: w.avg_ipm, reverse=True) total_ipm = 0 @@ -393,7 +386,7 @@ class World: def get_workers(self): filtered = [] - for worker in self.__workers: + for worker in self._workers: if worker.avg_ipm is not None and worker.avg_ipm <= 0: logger.warning(f"config reports invalid speed (0 ipm) for worker '{worker.uuid}', setting default of 1 ipm.\nplease re-benchmark") worker.avg_ipm = 1 @@ -505,7 +498,7 @@ class World: del self.jobs[last] last -= 1 - def load_config(self): + def config(self) -> json: if not os.path.exists(self.worker_info_path): logger.debug(f"Config was not found at '{self.worker_info_path}'") return @@ -513,16 +506,19 @@ class World: with open(self.worker_info_path, 'r') as config: try: - config_json = json.load(config) + return json.load(config) except json.decoder.JSONDecodeError: logger.debug(f"config is corrupt or invalid JSON, unable to load") - return - for key in config_json: + def load_config(self): + config = self.config() + + if config is not None: + for key in config: if key == "benchmark_payload" or key == "master": continue - w = config_json[key] + w = config[key] try: worker = self.add_worker( uuid=key, @@ -536,3 +532,15 @@ class World: except KeyError: logger.error(f"invalid configuration in file for worker {key}... ignoring") continue + + def save_config(self): + config = {} + + config.update({'benchmark_payload': sh.benchmark_payload}) + for worker in self._workers: + config.update(worker.info()) + + with open(self.worker_info_path, 'w+') as worker_info_file: + json.dump(config, worker_info_file, indent=3) + logger.debug(f"config saved") +