Fix redundantly init'ing World. Fix typo. Fix a couple vars not being in scope. Add some internal __str__ overrides for debugging. Improve regression by limiting the number of samples to 5. Filter out some influential outliers by checking whether the variance is astronomically high.

pull/2/head
papuSpartan 2023-05-17 06:50:53 -05:00
parent 6ffef54e36
commit b4c24d75d8
4 changed files with 121 additions and 40 deletions

View File

@ -23,4 +23,3 @@ def preload(parser):
help="Enable debug information", help="Enable debug information",
action="store_true" action="store_true"
) )
# args = parser.parse_args()

View File

@ -21,7 +21,7 @@ from pathlib import Path
import os import os
import subprocess import subprocess
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
from scripts.spartan.Worker import Worker from scripts.spartan.Worker import Worker, State
from modules.shared import opts from modules.shared import opts
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers? # TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
@ -60,13 +60,16 @@ class Script(scripts.Script):
with gradio.Tab('Status') as status_tab: with gradio.Tab('Status') as status_tab:
status = gradio.Textbox(elem_id='status', show_label=False) status = gradio.Textbox(elem_id='status', show_label=False)
status_tab.select(fn=Script.ui_connect_test, inputs=[], outputs=[status]) status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[status])
jobs = gradio.Textbox(elem_id='jobs', label='Jobs', show_label=True)
# status_tab.select(fn=Script.world.__str__, inputs=[], outputs=[jobs, status]),
refresh_status_btn = gradio.Button(value='Refresh') refresh_status_btn = gradio.Button(value='Refresh')
refresh_status_btn.style(size='sm') refresh_status_btn.style(size='sm')
refresh_status_btn.click(Script.ui_connect_test, inputs=[], outputs=[status]) refresh_status_btn.click(Script.ui_connect_status, inputs=[], outputs=[jobs, status])
with gradio.Tab('Remote Utils'): with gradio.Tab('Utils'):
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints') refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
refresh_checkpoints_btn.style(full_width=False) refresh_checkpoints_btn.style(full_width=False)
refresh_checkpoints_btn.click(Script.ui_connect_refresh_ckpts_btn, inputs=[], outputs=[]) refresh_checkpoints_btn.click(Script.ui_connect_refresh_ckpts_btn, inputs=[], outputs=[])
@ -79,8 +82,19 @@ class Script(scripts.Script):
interrupt_all_btn.style(full_width=False) interrupt_all_btn.style(full_width=False)
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[]) interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
# redo benchmarks button
redo_benchmarks_btn = gradio.Button(value='Redo benchmarks')
redo_benchmarks_btn.style(full_width=False)
redo_benchmarks_btn.click(Script.ui_connect_benchmark_button, inputs=[], outputs=[])
return return
@staticmethod
def ui_connect_benchmark_button():
print("Redoing benchmarks...")
Script.world.benchmark(rebenchmark=True)
@staticmethod @staticmethod
def user_sync_script(): def user_sync_script():
user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user') user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user')
@ -120,23 +134,29 @@ class Script(scripts.Script):
print("Distributed system not initialized") print("Distributed system not initialized")
@staticmethod @staticmethod
def ui_connect_test(): def ui_connect_status():
try: try:
temp = '' worker_status = ''
for worker in Script.world.workers: for worker in Script.world.workers:
if worker.master: if worker.master:
continue continue
temp += f"{worker.uuid} at {worker.address} is {worker.state.name}\n" worker_status += f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
return temp # TODO replace this with a single check to a state flag that we should make in the world class
for worker in Script.world.workers:
if worker.state == State.WORKING:
return Script.world.__str__(), worker_status
return 'No active jobs!', worker_status
# init system if it isn't already # init system if it isn't already
except AttributeError: except AttributeError as e:
# batch size will be clobbered later once an actual request is made anyway so I just pass 1 print(e)
# batch size will be clobbered later once an actual request is made anyway
Script.initialize(initial_payload=None) Script.initialize(initial_payload=None)
Script.ui_connect_test() return 'refresh!', 'refresh!'
@staticmethod @staticmethod
@ -222,33 +242,38 @@ class Script(scripts.Script):
@staticmethod @staticmethod
def initialize(initial_payload): def initialize(initial_payload):
if Script.verify_remotes is False: # get default batch size
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
try: try:
batch_size = initial_payload.batch_size batch_size = initial_payload.batch_size
except AttributeError: except AttributeError:
batch_size = 1 batch_size = 1
Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes) if Script.world is None:
if Script.verify_remotes is False:
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# add workers to the world # construct World
for worker in cmd_opts.distributed_remotes: Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
try: # add workers to the world
Script.world.initialize(batch_size) for worker in cmd_opts.distributed_remotes:
print("World initialized!") Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
except WorldAlreadyInitialized:
Script.world.update_world(total_batch_size=batch_size) else:
# update world or initialize and update if necessary
try:
Script.world.initialize(batch_size)
print("World initialized!")
except WorldAlreadyInitialized:
Script.world.update_world(total_batch_size=batch_size)
def run(self, p, *args): def run(self, p, *args):
if cmd_opts.distributed_remotes is None: if cmd_opts.distributed_remotes is None:
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)") raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
# register gallery callback # register gallery callback
script_callbacks.on_after_image_processed(Script.add_to_gallery) script_callbacks.on_after_batch_processed(Script.add_to_gallery)
Script.initialize(initial_payload=p) Script.initialize(initial_payload=p)

View File

@ -121,7 +121,7 @@ class Worker:
These things are used to draw certain conclusions after the first session. These things are used to draw certain conclusions after the first session.
Args: Args:
benchmark_payload (dict): The payload used the benchmark. benchmark_payload (dict): The payload used in the benchmark.
Returns: Returns:
dict: Worker info, including how it was benchmarked. dict: Worker info, including how it was benchmarked.
@ -137,9 +137,18 @@ class Worker:
d[self.uuid] = data d[self.uuid] = data
return d return d
def treat_mpe(self):
"""
In collecting percent errors to calculate the MPE, there may be influential outliers that skew the results.
Here we cull those outliers from them in order to get a more accurate end result.
"""
# TODO implement this
pass
def eta_mpe(self): def eta_mpe(self):
""" """
Returns the mean absolute percent error using all the currently stored eta percent errors. Returns the mean percent error using all the currently stored eta percent errors.
Returns: Returns:
mpe (float): The mean percent error of a worker's calculation estimates. mpe (float): The mean percent error of a worker's calculation estimates.
@ -165,7 +174,7 @@ class Worker:
""" """
# TODO check if using http or https # TODO check if using http or https
return f"https://{self.__str__()}/sdapi/v1/{route}" return f"http://{self.__str__()}/sdapi/v1/{route}"
def batch_eta_hr(self, payload: dict) -> float: def batch_eta_hr(self, payload: dict) -> float:
""" """
@ -233,13 +242,21 @@ class Worker:
if len(self.eta_percent_error) > 0: if len(self.eta_percent_error) > 0:
correction = eta * (self.eta_mpe() / 100) correction = eta * (self.eta_mpe() / 100)
if cmd_opts.distributed_debug: # if abs(correction) > 300:
print(f"worker '{self.uuid}'s last ETA was off by {correction}%") # print(f"correction {abs(correction)} exceeds 300%... .")
if cmd_opts.distributed_debug:
print(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%")
print(f"{self.uuid} eta before correction: ", eta)
# do regression
if correction > 0: if correction > 0:
eta += correction
else:
eta -= correction eta -= correction
else:
eta += correction
if cmd_opts.distributed_debug:
print(f"{self.uuid} eta after correction: ", eta)
return eta return eta
except Exception as e: except Exception as e:
@ -308,10 +325,18 @@ class Worker:
print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n") print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n")
print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n") print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
if self.eta_percent_error == 0: # if the variance is greater than 500% then we ignore it to prevent variation inflation
self.eta_percent_error[0] = variance if abs(variance) < 500:
# check if there are already 5 samples and if so, remove the oldest
# this should help adjust to the user changing tasks
if len(self.eta_percent_error) > 4:
self.eta_percent_error.pop(0)
if self.eta_percent_error == 0: # init
self.eta_percent_error[0] = variance
else: # normal case
self.eta_percent_error.append(variance)
else: else:
self.eta_percent_error.append(variance) print(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
except Exception as e: except Exception as e:
if payload['batch_size'] == 0: if payload['batch_size'] == 0:

View File

@ -51,6 +51,14 @@ class Job:
self.batch_size: int = batch_size self.batch_size: int = batch_size
self.complementary: bool = False self.complementary: bool = False
def __str__(self):
prefix = ''
suffix = f"Job: {self.batch_size} images. Owned by '{self.worker.uuid}'. Rate: {self.worker.avg_ipm}ipm"
if self.complementary:
prefix = "(complementary) "
return prefix + suffix
class World: class World:
""" """
@ -70,7 +78,7 @@ class World:
self.total_batch_size: int = 0 self.total_batch_size: int = 0
self.workers: List[Worker] = [master_worker] self.workers: List[Worker] = [master_worker]
self.jobs: List[Job] = [] self.jobs: List[Job] = []
self.job_timeout: int = 10 # seconds self.job_timeout: int = 0 # seconds
self.initialized: bool = False self.initialized: bool = False
self.verify_remotes = verify_remotes self.verify_remotes = verify_remotes
self.initial_payload = copy.copy(initial_payload) self.initial_payload = copy.copy(initial_payload)
@ -188,15 +196,19 @@ class World:
t = Thread(target=worker.refresh_checkpoints, args=()) t = Thread(target=worker.refresh_checkpoints, args=())
t.start() t.start()
def benchmark(self): def benchmark(self, rebenchmark: bool = False):
""" """
Attempts to benchmark all workers a part of the world. Attempts to benchmark all workers a part of the world.
""" """
global benchmark_payload
workers_info: dict = {} workers_info: dict = {}
saved: bool = os.path.exists(self.worker_info_path) saved: bool = os.path.exists(self.worker_info_path)
benchmark_payload_loaded: bool = False benchmark_payload_loaded: bool = False
if rebenchmark:
saved = False
if saved: if saved:
workers_info = json.load(open(self.worker_info_path, 'r')) workers_info = json.load(open(self.worker_info_path, 'r'))
@ -250,6 +262,14 @@ class World:
print(f"{i}. worker '{worker}' - {worker.avg_ipm} ipm") print(f"{i}. worker '{worker}' - {worker.avg_ipm} ipm")
i += 1 i += 1
def __str__(self):
# print status of all jobs
jobs_str = ""
for job in self.jobs:
jobs_str += job.__str__() + "\n"
return jobs_str
def realtime_jobs(self) -> List[Job]: def realtime_jobs(self) -> List[Job]:
""" """
Determines which jobs are considered real-time by checking which jobs are not(complementary). Determines which jobs are considered real-time by checking which jobs are not(complementary).
@ -263,6 +283,8 @@ class World:
if job.complementary is False: if job.complementary is False:
fast_jobs.append(job) fast_jobs.append(job)
print(f"fast jobs: {fast_jobs}")
return fast_jobs return fast_jobs
def slowest_realtime_job(self) -> Job: def slowest_realtime_job(self) -> Job:
@ -292,6 +314,10 @@ class World:
""" """
fastest_worker = self.fastest_realtime_job().worker fastest_worker = self.fastest_realtime_job().worker
# if the worker is the fastest, then there is no lag
if worker == fastest_worker:
return 0
lag = worker.batch_eta(payload=payload) - fastest_worker.batch_eta(payload=payload) lag = worker.batch_eta(payload=payload) - fastest_worker.batch_eta(payload=payload)
return lag return lag
@ -304,6 +330,7 @@ class World:
Returns: Returns:
float: Local worker speed in ipm float: Local worker speed in ipm
""" """
global benchmark_payload
master_bench_payload = copy.copy(self.initial_payload) master_bench_payload = copy.copy(self.initial_payload)
@ -357,14 +384,14 @@ class World:
deferred_images = 0 # the number of images that were not assigned to a worker due to the worker being too slow deferred_images = 0 # the number of images that were not assigned to a worker due to the worker being too slow
# the maximum amount of images that a "slow" worker can produce in the slack space where other nodes are working # the maximum amount of images that a "slow" worker can produce in the slack space where other nodes are working
max_compensation = 4 # max_compensation = 4 currently unused
images_per_job = None images_per_job = None
for job in self.jobs: for job in self.jobs:
lag = self.job_stall(job.worker, payload=payload) lag = self.job_stall(job.worker, payload=payload)
if lag < self.job_timeout: if lag < self.job_timeout or lag == 0:
job.batch_size = payload['batch_size'] job.batch_size = payload['batch_size']
continue continue
@ -396,6 +423,9 @@ class World:
slowest_active_worker = self.slowest_realtime_job().worker slowest_active_worker = self.slowest_realtime_job().worker
slack_time = slowest_active_worker.batch_eta(payload=payload) slack_time = slowest_active_worker.batch_eta(payload=payload)
if cmd_opts.distributed_debug:
print(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'")
# in the case that this worker is now taking on what others workers would have been (if they were real-time) # in the case that this worker is now taking on what others workers would have been (if they were real-time)
# this means that there will be more slack time for complementary nodes # this means that there will be more slack time for complementary nodes
slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job) slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job)
@ -412,6 +442,8 @@ class World:
# It might be better to just inject a black image. (if master is that slow) # It might be better to just inject a black image. (if master is that slow)
master_job = self.master_job() master_job = self.master_job()
if master_job.batch_size < 1: if master_job.batch_size < 1:
if cmd_opts.distributed_debug:
print("Master couldn't keep up... defaulting to 1 image")
master_job.batch_size = 1 master_job.batch_size = 1
print("After job optimization, job layout is the following:") print("After job optimization, job layout is the following:")