Fix redundant re-initialization of World. Fix a typo. Fix a couple of variables that were not in scope. Add some internal __str__ overrides for debugging. Improve the regression by limiting the number of stored samples to 5. Filter out influential outliers by ignoring samples whose variance is astronomically high.

pull/2/head
papuSpartan 2023-05-17 06:50:53 -05:00
parent 6ffef54e36
commit b4c24d75d8
4 changed files with 121 additions and 40 deletions

View File

@ -23,4 +23,3 @@ def preload(parser):
help="Enable debug information",
action="store_true"
)
# args = parser.parse_args()

View File

@ -21,7 +21,7 @@ from pathlib import Path
import os
import subprocess
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
from scripts.spartan.Worker import Worker
from scripts.spartan.Worker import Worker, State
from modules.shared import opts
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
@ -60,13 +60,16 @@ class Script(scripts.Script):
with gradio.Tab('Status') as status_tab:
status = gradio.Textbox(elem_id='status', show_label=False)
status_tab.select(fn=Script.ui_connect_test, inputs=[], outputs=[status])
status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[status])
jobs = gradio.Textbox(elem_id='jobs', label='Jobs', show_label=True)
# status_tab.select(fn=Script.world.__str__, inputs=[], outputs=[jobs, status]),
refresh_status_btn = gradio.Button(value='Refresh')
refresh_status_btn.style(size='sm')
refresh_status_btn.click(Script.ui_connect_test, inputs=[], outputs=[status])
refresh_status_btn.click(Script.ui_connect_status, inputs=[], outputs=[jobs, status])
with gradio.Tab('Remote Utils'):
with gradio.Tab('Utils'):
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
refresh_checkpoints_btn.style(full_width=False)
refresh_checkpoints_btn.click(Script.ui_connect_refresh_ckpts_btn, inputs=[], outputs=[])
@ -79,8 +82,19 @@ class Script(scripts.Script):
interrupt_all_btn.style(full_width=False)
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
# redo benchmarks button
redo_benchmarks_btn = gradio.Button(value='Redo benchmarks')
redo_benchmarks_btn.style(full_width=False)
redo_benchmarks_btn.click(Script.ui_connect_benchmark_button, inputs=[], outputs=[])
return
@staticmethod
def ui_connect_benchmark_button():
    """UI handler for the 'Redo benchmarks' button: forces a fresh benchmark of all workers."""
    print("Redoing benchmarks...")
    # rebenchmark=True bypasses any previously saved benchmark results
    Script.world.benchmark(rebenchmark=True)
@staticmethod
def user_sync_script():
user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user')
@ -120,23 +134,29 @@ class Script(scripts.Script):
print("Distributed system not initialized")
@staticmethod
def ui_connect_test():
def ui_connect_status():
try:
temp = ''
worker_status = ''
for worker in Script.world.workers:
if worker.master:
continue
temp += f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
worker_status += f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
return temp
# TODO replace this with a single check to a state flag that we should make in the world class
for worker in Script.world.workers:
if worker.state == State.WORKING:
return Script.world.__str__(), worker_status
return 'No active jobs!', worker_status
# init system if it isn't already
except AttributeError:
# batch size will be clobbered later once an actual request is made anyway so I just pass 1
except AttributeError as e:
print(e)
# batch size will be clobbered later once an actual request is made anyway
Script.initialize(initial_payload=None)
Script.ui_connect_test()
return 'refresh!', 'refresh!'
@staticmethod
@ -222,33 +242,38 @@ class Script(scripts.Script):
@staticmethod
def initialize(initial_payload):
if Script.verify_remotes is False:
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# get default batch size
try:
batch_size = initial_payload.batch_size
except AttributeError:
batch_size = 1
Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)
if Script.world is None:
if Script.verify_remotes is False:
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# add workers to the world
for worker in cmd_opts.distributed_remotes:
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
# construct World
Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)
try:
Script.world.initialize(batch_size)
print("World initialized!")
except WorldAlreadyInitialized:
Script.world.update_world(total_batch_size=batch_size)
# add workers to the world
for worker in cmd_opts.distributed_remotes:
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
else:
# update world or initialize and update if necessary
try:
Script.world.initialize(batch_size)
print("World initialized!")
except WorldAlreadyInitialized:
Script.world.update_world(total_batch_size=batch_size)
def run(self, p, *args):
if cmd_opts.distributed_remotes is None:
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
# register gallery callback
script_callbacks.on_after_image_processed(Script.add_to_gallery)
script_callbacks.on_after_batch_processed(Script.add_to_gallery)
Script.initialize(initial_payload=p)

View File

@ -121,7 +121,7 @@ class Worker:
These things are used to draw certain conclusions after the first session.
Args:
benchmark_payload (dict): The payload used the benchmark.
benchmark_payload (dict): The payload used in the benchmark.
Returns:
dict: Worker info, including how it was benchmarked.
@ -137,9 +137,18 @@ class Worker:
d[self.uuid] = data
return d
def treat_mpe(self):
    """
    The percent-error samples used to compute the MPE may contain influential outliers
    that skew the result. This hook is intended to cull those outliers so the final
    mean percent error is more accurate.
    """
    # TODO implement this
    return None
def eta_mpe(self):
"""
Returns the mean absolute percent error using all the currently stored eta percent errors.
Returns the mean percent error using all the currently stored eta percent errors.
Returns:
mpe (float): The mean percent error of a worker's calculation estimates.
@ -165,7 +174,7 @@ class Worker:
"""
# TODO check if using http or https
return f"https://{self.__str__()}/sdapi/v1/{route}"
return f"http://{self.__str__()}/sdapi/v1/{route}"
def batch_eta_hr(self, payload: dict) -> float:
"""
@ -233,13 +242,21 @@ class Worker:
if len(self.eta_percent_error) > 0:
correction = eta * (self.eta_mpe() / 100)
if cmd_opts.distributed_debug:
print(f"worker '{self.uuid}'s last ETA was off by {correction}%")
# if abs(correction) > 300:
# print(f"correction {abs(correction)} exceeds 300%... .")
if cmd_opts.distributed_debug:
print(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%")
print(f"{self.uuid} eta before correction: ", eta)
# do regression
if correction > 0:
eta += correction
else:
eta -= correction
else:
eta += correction
if cmd_opts.distributed_debug:
print(f"{self.uuid} eta after correction: ", eta)
return eta
except Exception as e:
@ -308,10 +325,18 @@ class Worker:
print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n")
print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
if self.eta_percent_error == 0:
self.eta_percent_error[0] = variance
# if the variance is greater than 500% then we ignore it to prevent variation inflation
if abs(variance) < 500:
# check if there are already 5 samples and if so, remove the oldest
# this should help adjust to the user changing tasks
if len(self.eta_percent_error) > 4:
self.eta_percent_error.pop(0)
if self.eta_percent_error == 0: # init
self.eta_percent_error[0] = variance
else: # normal case
self.eta_percent_error.append(variance)
else:
self.eta_percent_error.append(variance)
print(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
except Exception as e:
if payload['batch_size'] == 0:

View File

@ -51,6 +51,14 @@ class Job:
self.batch_size: int = batch_size
self.complementary: bool = False
def __str__(self):
    """Human-readable summary of this job: batch size, owning worker, and that worker's rate."""
    label = f"Job: {self.batch_size} images. Owned by '{self.worker.uuid}'. Rate: {self.worker.avg_ipm}ipm"
    if self.complementary:
        return "(complementary) " + label
    return label
class World:
"""
@ -70,7 +78,7 @@ class World:
self.total_batch_size: int = 0
self.workers: List[Worker] = [master_worker]
self.jobs: List[Job] = []
self.job_timeout: int = 10 # seconds
self.job_timeout: int = 0 # seconds
self.initialized: bool = False
self.verify_remotes = verify_remotes
self.initial_payload = copy.copy(initial_payload)
@ -188,15 +196,19 @@ class World:
t = Thread(target=worker.refresh_checkpoints, args=())
t.start()
def benchmark(self):
def benchmark(self, rebenchmark: bool = False):
"""
Attempts to benchmark all workers that are part of the world.
"""
global benchmark_payload
workers_info: dict = {}
saved: bool = os.path.exists(self.worker_info_path)
benchmark_payload_loaded: bool = False
if rebenchmark:
saved = False
if saved:
workers_info = json.load(open(self.worker_info_path, 'r'))
@ -250,6 +262,14 @@ class World:
print(f"{i}. worker '{worker}' - {worker.avg_ipm} ipm")
i += 1
def __str__(self):
    """Render the status of every job in the world, one job per line (with a trailing newline)."""
    return "".join(str(job) + "\n" for job in self.jobs)
def realtime_jobs(self) -> List[Job]:
"""
Determines which jobs are considered real-time by checking which jobs are not(complementary).
@ -263,6 +283,8 @@ class World:
if job.complementary is False:
fast_jobs.append(job)
print(f"fast jobs: {fast_jobs}")
return fast_jobs
def slowest_realtime_job(self) -> Job:
@ -292,6 +314,10 @@ class World:
"""
fastest_worker = self.fastest_realtime_job().worker
# if the worker is the fastest, then there is no lag
if worker == fastest_worker:
return 0
lag = worker.batch_eta(payload=payload) - fastest_worker.batch_eta(payload=payload)
return lag
@ -304,6 +330,7 @@ class World:
Returns:
float: Local worker speed in ipm
"""
global benchmark_payload
master_bench_payload = copy.copy(self.initial_payload)
@ -357,14 +384,14 @@ class World:
deferred_images = 0 # the number of images that were not assigned to a worker due to the worker being too slow
# the maximum amount of images that a "slow" worker can produce in the slack space where other nodes are working
max_compensation = 4
# max_compensation = 4 currently unused
images_per_job = None
for job in self.jobs:
lag = self.job_stall(job.worker, payload=payload)
if lag < self.job_timeout:
if lag < self.job_timeout or lag == 0:
job.batch_size = payload['batch_size']
continue
@ -396,6 +423,9 @@ class World:
slowest_active_worker = self.slowest_realtime_job().worker
slack_time = slowest_active_worker.batch_eta(payload=payload)
if cmd_opts.distributed_debug:
print(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'")
# in the case that this worker is now taking on what others workers would have been (if they were real-time)
# this means that there will be more slack time for complementary nodes
slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job)
@ -412,6 +442,8 @@ class World:
# It might be better to just inject a black image. (if master is that slow)
master_job = self.master_job()
if master_job.batch_size < 1:
if cmd_opts.distributed_debug:
print("Master couldn't keep up... defaulting to 1 image")
master_job.batch_size = 1
print("After job optimization, job layout is the following:")