add ability to remove workers from within the UI. fix regression preventing benchmarking when config exists but recorded ipm values are invalid.
parent
80852253c9
commit
3f38d463ac
|
|
@ -4,12 +4,15 @@ import subprocess
|
|||
from pathlib import Path
|
||||
import gradio
|
||||
from scripts.spartan.shared import logger, log_level
|
||||
from scripts.spartan.Worker import State
|
||||
from scripts.spartan.Worker import State, Worker
|
||||
from modules.shared import state as webui_state
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
worker_select_dropdown = None
|
||||
|
||||
class UI:
|
||||
|
||||
def __init__(self, script, world):
|
||||
self.script = script
|
||||
self.world = world
|
||||
|
|
@ -73,23 +76,36 @@ class UI:
|
|||
logger.debug(f"job timeout is now {job_timeout} seconds")
|
||||
|
||||
def save_worker_btn(self, name, address, port, tls):
|
||||
worker = self.world.add_worker(name, address, port, tls)
|
||||
self.world.add_worker(name, address, port, tls)
|
||||
self.world.save_config()
|
||||
|
||||
workers_info = {}
|
||||
with open(self.world.worker_info_path, 'r', encoding='utf-8') as worker_info_file:
|
||||
try:
|
||||
workers_info = json.load(worker_info_file)
|
||||
except json.decoder.JSONDecodeError:
|
||||
logger.error(f"corrupt or invalid config file... ignoring")
|
||||
except io.UnsupportedOperation:
|
||||
pass
|
||||
# visibly update which workers can be selected
|
||||
labels = [x.uuid for x in self.selectable_remote_workers()]
|
||||
return gradio.Dropdown.update(choices=labels)
|
||||
|
||||
with open(self.world.worker_info_path, 'w', encoding='utf-8') as worker_info_file:
|
||||
inf: dict = worker.info()
|
||||
workers_info[name] = inf[name]
|
||||
def selectable_remote_workers(self) -> List[Worker]:
|
||||
remote_workers = []
|
||||
|
||||
json.dump(workers_info, worker_info_file, indent=3)
|
||||
for worker in self.world.get_workers():
|
||||
if worker.master:
|
||||
continue
|
||||
remote_workers.append(worker)
|
||||
remote_workers = sorted(remote_workers, key=lambda x: x.uuid)
|
||||
|
||||
return remote_workers
|
||||
|
||||
def remove_worker_btn(self, worker_label):
|
||||
# remove worker from memory
|
||||
for worker in self.world._workers:
|
||||
if worker.uuid == worker_label:
|
||||
self.world._workers.remove(worker)
|
||||
|
||||
# remove worker from disk
|
||||
self.world.save_config()
|
||||
|
||||
# visibly update which workers can be selected
|
||||
labels = [x.uuid for x in self.selectable_remote_workers()]
|
||||
return gradio.Dropdown.update(choices=labels)
|
||||
|
||||
# end handlers
|
||||
|
||||
|
|
@ -135,16 +151,25 @@ class UI:
|
|||
clear_queue_btn.click(self.clear_queue_btn)
|
||||
|
||||
with gradio.Tab('Worker Config'):
|
||||
worker_name_field = gradio.Textbox(label='Name')
|
||||
worker_select_dropdown = None
|
||||
|
||||
worker_select_dropdown = gradio.Dropdown(
|
||||
[x.uuid for x in self.selectable_remote_workers()],
|
||||
info='Select a pre-existing worker or enter a label for a new one',
|
||||
label='Label',
|
||||
allow_custom_value=True
|
||||
)
|
||||
worker_address_field = gradio.Textbox(label='Address')
|
||||
worker_port_field = gradio.Textbox(label='Port', value='7860')
|
||||
worker_tls_cbx = gradio.Checkbox(
|
||||
label='connect to worker using https'
|
||||
)
|
||||
save_worker_btn = gradio.Button(
|
||||
value='Add Worker'
|
||||
)
|
||||
save_worker_btn.click(self.save_worker_btn, inputs=[worker_name_field, worker_address_field, worker_port_field, worker_tls_cbx])
|
||||
|
||||
with gradio.Row():
|
||||
save_worker_btn = gradio.Button(value='Add/Update Worker')
|
||||
save_worker_btn.click(self.save_worker_btn, inputs=[worker_select_dropdown, worker_address_field, worker_port_field, worker_tls_cbx], outputs=[worker_select_dropdown])
|
||||
remove_worker_btn = gradio.Button(value='Remove Worker', variant='stop')
|
||||
remove_worker_btn.click(self.remove_worker_btn, inputs=worker_select_dropdown, outputs=[worker_select_dropdown])
|
||||
|
||||
with gradio.Tab('Settings'):
|
||||
thin_client_cbx = gradio.Checkbox(
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ class World:
|
|||
def __init__(self, initial_payload, verify_remotes: bool = True):
|
||||
self.master_worker = Worker(master=True)
|
||||
self.total_batch_size: int = 0
|
||||
self.__workers: List[Worker] = [self.master_worker]
|
||||
self._workers: List[Worker] = [self.master_worker]
|
||||
self.jobs: List[Job] = []
|
||||
self.job_timeout: int = 6 # seconds
|
||||
self.initialized: bool = False
|
||||
|
|
@ -149,10 +149,10 @@ class World:
|
|||
|
||||
worker = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes, tls=tls)
|
||||
|
||||
for w in self.__workers:
|
||||
for w in self._workers:
|
||||
if w.uuid == uuid:
|
||||
self.__workers.remove(w)
|
||||
self.__workers.append(worker)
|
||||
self._workers.remove(w)
|
||||
self._workers.append(worker)
|
||||
|
||||
return worker
|
||||
|
||||
|
|
@ -220,18 +220,17 @@ class World:
|
|||
if saved and not rebenchmark:
|
||||
logger.debug(f"loaded saved configuration: \n{workers_info}")
|
||||
|
||||
for worker in self.__workers:
|
||||
for worker in self._workers:
|
||||
try:
|
||||
worker.avg_ipm = workers_info[worker.uuid]['avg_ipm']
|
||||
if worker.avg_ipm <= 0:
|
||||
logger.debug(f"{worker.uuid} has recorded ipm of 0... marking as unbenched")
|
||||
if worker.avg_ipm is None or worker.avg_ipm <= 0:
|
||||
logger.debug(f"{worker.uuid} recorded ipm is invalid... marking as unbenched")
|
||||
unbenched_workers.append(worker)
|
||||
else:
|
||||
worker.benchmarked = True
|
||||
except KeyError:
|
||||
logger.debug(f"worker '{worker.uuid}' not found in workers.json")
|
||||
unbenched_workers.append(worker)
|
||||
return
|
||||
else:
|
||||
unbenched_workers = self.get_workers()
|
||||
|
||||
|
|
@ -244,19 +243,13 @@ class World:
|
|||
|
||||
# wait for all benchmarks to finish and update stats on newly benchmarked workers
|
||||
if len(benchmark_threads) > 0:
|
||||
with open(self.worker_info_path, 'w') as worker_info_file:
|
||||
for t in benchmark_threads:
|
||||
t.join()
|
||||
logger.info("Benchmarking finished")
|
||||
for t in benchmark_threads:
|
||||
t.join()
|
||||
logger.info("Benchmarking finished")
|
||||
|
||||
for worker in unbenched_workers:
|
||||
workers_info.update(worker.info())
|
||||
workers_info.update({'benchmark_payload': sh.benchmark_payload})
|
||||
|
||||
# save benchmark results to workers.json
|
||||
json.dump(workers_info, worker_info_file, indent=3)
|
||||
|
||||
logger.info(self.speed_summary())
|
||||
# save benchmark results to workers.json
|
||||
self.save_config()
|
||||
logger.info(self.speed_summary())
|
||||
|
||||
def get_current_output_size(self) -> int:
|
||||
"""
|
||||
|
|
@ -274,7 +267,7 @@ class World:
|
|||
"""
|
||||
Returns string listing workers by their ipm in descending order.
|
||||
"""
|
||||
workers_copy = copy.deepcopy(self.__workers)
|
||||
workers_copy = copy.deepcopy(self._workers)
|
||||
workers_copy.sort(key=lambda w: w.avg_ipm, reverse=True)
|
||||
|
||||
total_ipm = 0
|
||||
|
|
@ -393,7 +386,7 @@ class World:
|
|||
|
||||
def get_workers(self):
|
||||
filtered = []
|
||||
for worker in self.__workers:
|
||||
for worker in self._workers:
|
||||
if worker.avg_ipm is not None and worker.avg_ipm <= 0:
|
||||
logger.warning(f"config reports invalid speed (0 ipm) for worker '{worker.uuid}', setting default of 1 ipm.\nplease re-benchmark")
|
||||
worker.avg_ipm = 1
|
||||
|
|
@ -505,7 +498,7 @@ class World:
|
|||
del self.jobs[last]
|
||||
last -= 1
|
||||
|
||||
def load_config(self):
|
||||
def config(self) -> json:
|
||||
if not os.path.exists(self.worker_info_path):
|
||||
logger.debug(f"Config was not found at '{self.worker_info_path}'")
|
||||
return
|
||||
|
|
@ -513,16 +506,19 @@ class World:
|
|||
with open(self.worker_info_path, 'r') as config:
|
||||
|
||||
try:
|
||||
config_json = json.load(config)
|
||||
return json.load(config)
|
||||
except json.decoder.JSONDecodeError:
|
||||
logger.debug(f"config is corrupt or invalid JSON, unable to load")
|
||||
return
|
||||
|
||||
for key in config_json:
|
||||
def load_config(self):
|
||||
config = self.config()
|
||||
|
||||
if config is not None:
|
||||
for key in config:
|
||||
if key == "benchmark_payload" or key == "master":
|
||||
continue
|
||||
|
||||
w = config_json[key]
|
||||
w = config[key]
|
||||
try:
|
||||
worker = self.add_worker(
|
||||
uuid=key,
|
||||
|
|
@ -536,3 +532,15 @@ class World:
|
|||
except KeyError:
|
||||
logger.error(f"invalid configuration in file for worker {key}... ignoring")
|
||||
continue
|
||||
|
||||
def save_config(self):
|
||||
config = {}
|
||||
|
||||
config.update({'benchmark_payload': sh.benchmark_payload})
|
||||
for worker in self._workers:
|
||||
config.update(worker.info())
|
||||
|
||||
with open(self.worker_info_path, 'w+') as worker_info_file:
|
||||
json.dump(config, worker_info_file, indent=3)
|
||||
logger.debug(f"config saved")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue