Fix redundantly init'ing World. Fix typo. Fix a couple vars not being in scope. Add some internal __str__ overrides for debugging. Improve regression by limiting the number of samples by 5. Filter some influential outliers by checking if the variance is astronomically high.
parent
6ffef54e36
commit
b4c24d75d8
|
|
@ -23,4 +23,3 @@ def preload(parser):
|
|||
help="Enable debug information",
|
||||
action="store_true"
|
||||
)
|
||||
# args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from pathlib import Path
|
|||
import os
|
||||
import subprocess
|
||||
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
|
||||
from scripts.spartan.Worker import Worker
|
||||
from scripts.spartan.Worker import Worker, State
|
||||
from modules.shared import opts
|
||||
|
||||
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
|
||||
|
|
@ -60,13 +60,16 @@ class Script(scripts.Script):
|
|||
|
||||
with gradio.Tab('Status') as status_tab:
|
||||
status = gradio.Textbox(elem_id='status', show_label=False)
|
||||
status_tab.select(fn=Script.ui_connect_test, inputs=[], outputs=[status])
|
||||
status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[status])
|
||||
|
||||
jobs = gradio.Textbox(elem_id='jobs', label='Jobs', show_label=True)
|
||||
# status_tab.select(fn=Script.world.__str__, inputs=[], outputs=[jobs, status]),
|
||||
|
||||
refresh_status_btn = gradio.Button(value='Refresh')
|
||||
refresh_status_btn.style(size='sm')
|
||||
refresh_status_btn.click(Script.ui_connect_test, inputs=[], outputs=[status])
|
||||
refresh_status_btn.click(Script.ui_connect_status, inputs=[], outputs=[jobs, status])
|
||||
|
||||
with gradio.Tab('Remote Utils'):
|
||||
with gradio.Tab('Utils'):
|
||||
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
|
||||
refresh_checkpoints_btn.style(full_width=False)
|
||||
refresh_checkpoints_btn.click(Script.ui_connect_refresh_ckpts_btn, inputs=[], outputs=[])
|
||||
|
|
@ -79,8 +82,19 @@ class Script(scripts.Script):
|
|||
interrupt_all_btn.style(full_width=False)
|
||||
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
|
||||
|
||||
# redo benchmarks button
|
||||
redo_benchmarks_btn = gradio.Button(value='Redo benchmarks')
|
||||
redo_benchmarks_btn.style(full_width=False)
|
||||
redo_benchmarks_btn.click(Script.ui_connect_benchmark_button, inputs=[], outputs=[])
|
||||
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def ui_connect_benchmark_button():
|
||||
print("Redoing benchmarks...")
|
||||
Script.world.benchmark(rebenchmark=True)
|
||||
|
||||
@staticmethod
|
||||
def user_sync_script():
|
||||
user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user')
|
||||
|
|
@ -120,23 +134,29 @@ class Script(scripts.Script):
|
|||
print("Distributed system not initialized")
|
||||
|
||||
@staticmethod
|
||||
def ui_connect_test():
|
||||
def ui_connect_status():
|
||||
try:
|
||||
temp = ''
|
||||
worker_status = ''
|
||||
|
||||
for worker in Script.world.workers:
|
||||
if worker.master:
|
||||
continue
|
||||
|
||||
temp += f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
|
||||
worker_status += f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
|
||||
|
||||
return temp
|
||||
# TODO replace this with a single check to a state flag that we should make in the world class
|
||||
for worker in Script.world.workers:
|
||||
if worker.state == State.WORKING:
|
||||
return Script.world.__str__(), worker_status
|
||||
|
||||
return 'No active jobs!', worker_status
|
||||
|
||||
# init system if it isn't already
|
||||
except AttributeError:
|
||||
# batch size will be clobbered later once an actual request is made anyway so I just pass 1
|
||||
except AttributeError as e:
|
||||
print(e)
|
||||
# batch size will be clobbered later once an actual request is made anyway
|
||||
Script.initialize(initial_payload=None)
|
||||
Script.ui_connect_test()
|
||||
return 'refresh!', 'refresh!'
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -222,33 +242,38 @@ class Script(scripts.Script):
|
|||
|
||||
@staticmethod
|
||||
def initialize(initial_payload):
|
||||
if Script.verify_remotes is False:
|
||||
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
# get default batch size
|
||||
try:
|
||||
batch_size = initial_payload.batch_size
|
||||
except AttributeError:
|
||||
batch_size = 1
|
||||
|
||||
Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)
|
||||
if Script.world is None:
|
||||
if Script.verify_remotes is False:
|
||||
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
# add workers to the world
|
||||
for worker in cmd_opts.distributed_remotes:
|
||||
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
|
||||
# construct World
|
||||
Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)
|
||||
|
||||
try:
|
||||
Script.world.initialize(batch_size)
|
||||
print("World initialized!")
|
||||
except WorldAlreadyInitialized:
|
||||
Script.world.update_world(total_batch_size=batch_size)
|
||||
# add workers to the world
|
||||
for worker in cmd_opts.distributed_remotes:
|
||||
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
|
||||
|
||||
else:
|
||||
# update world or initialize and update if necessary
|
||||
try:
|
||||
Script.world.initialize(batch_size)
|
||||
print("World initialized!")
|
||||
except WorldAlreadyInitialized:
|
||||
Script.world.update_world(total_batch_size=batch_size)
|
||||
|
||||
def run(self, p, *args):
|
||||
if cmd_opts.distributed_remotes is None:
|
||||
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
|
||||
|
||||
# register gallery callback
|
||||
script_callbacks.on_after_image_processed(Script.add_to_gallery)
|
||||
script_callbacks.on_after_batch_processed(Script.add_to_gallery)
|
||||
|
||||
Script.initialize(initial_payload=p)
|
||||
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ class Worker:
|
|||
These things are used to draw certain conclusions after the first session.
|
||||
|
||||
Args:
|
||||
benchmark_payload (dict): The payload used the benchmark.
|
||||
benchmark_payload (dict): The payload used in the benchmark.
|
||||
|
||||
Returns:
|
||||
dict: Worker info, including how it was benchmarked.
|
||||
|
|
@ -137,9 +137,18 @@ class Worker:
|
|||
d[self.uuid] = data
|
||||
return d
|
||||
|
||||
def treat_mpe(self):
|
||||
"""
|
||||
In collecting percent errors to calculate the MPE, there may be influential outliers that skew the results.
|
||||
Here we cull those outliers from them in order to get a more accurate end result.
|
||||
"""
|
||||
|
||||
# TODO implement this
|
||||
pass
|
||||
|
||||
def eta_mpe(self):
|
||||
"""
|
||||
Returns the mean absolute percent error using all the currently stored eta percent errors.
|
||||
Returns the mean percent error using all the currently stored eta percent errors.
|
||||
|
||||
Returns:
|
||||
mpe (float): The mean percent error of a worker's calculation estimates.
|
||||
|
|
@ -165,7 +174,7 @@ class Worker:
|
|||
"""
|
||||
|
||||
# TODO check if using http or https
|
||||
return f"https://{self.__str__()}/sdapi/v1/{route}"
|
||||
return f"http://{self.__str__()}/sdapi/v1/{route}"
|
||||
|
||||
def batch_eta_hr(self, payload: dict) -> float:
|
||||
"""
|
||||
|
|
@ -233,13 +242,21 @@ class Worker:
|
|||
if len(self.eta_percent_error) > 0:
|
||||
correction = eta * (self.eta_mpe() / 100)
|
||||
|
||||
if cmd_opts.distributed_debug:
|
||||
print(f"worker '{self.uuid}'s last ETA was off by {correction}%")
|
||||
# if abs(correction) > 300:
|
||||
# print(f"correction {abs(correction)} exceeds 300%... .")
|
||||
|
||||
if cmd_opts.distributed_debug:
|
||||
print(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%")
|
||||
print(f"{self.uuid} eta before correction: ", eta)
|
||||
|
||||
# do regression
|
||||
if correction > 0:
|
||||
eta += correction
|
||||
else:
|
||||
eta -= correction
|
||||
else:
|
||||
eta += correction
|
||||
|
||||
if cmd_opts.distributed_debug:
|
||||
print(f"{self.uuid} eta after correction: ", eta)
|
||||
|
||||
return eta
|
||||
except Exception as e:
|
||||
|
|
@ -308,10 +325,18 @@ class Worker:
|
|||
print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n")
|
||||
print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
|
||||
|
||||
if self.eta_percent_error == 0:
|
||||
self.eta_percent_error[0] = variance
|
||||
# if the variance is greater than 500% then we ignore it to prevent variation inflation
|
||||
if abs(variance) < 500:
|
||||
# check if there are already 5 samples and if so, remove the oldest
|
||||
# this should help adjust to the user changing tasks
|
||||
if len(self.eta_percent_error) > 4:
|
||||
self.eta_percent_error.pop(0)
|
||||
if self.eta_percent_error == 0: # init
|
||||
self.eta_percent_error[0] = variance
|
||||
else: # normal case
|
||||
self.eta_percent_error.append(variance)
|
||||
else:
|
||||
self.eta_percent_error.append(variance)
|
||||
print(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
|
||||
|
||||
except Exception as e:
|
||||
if payload['batch_size'] == 0:
|
||||
|
|
|
|||
|
|
@ -51,6 +51,14 @@ class Job:
|
|||
self.batch_size: int = batch_size
|
||||
self.complementary: bool = False
|
||||
|
||||
def __str__(self):
|
||||
prefix = ''
|
||||
suffix = f"Job: {self.batch_size} images. Owned by '{self.worker.uuid}'. Rate: {self.worker.avg_ipm}ipm"
|
||||
if self.complementary:
|
||||
prefix = "(complementary) "
|
||||
|
||||
return prefix + suffix
|
||||
|
||||
|
||||
class World:
|
||||
"""
|
||||
|
|
@ -70,7 +78,7 @@ class World:
|
|||
self.total_batch_size: int = 0
|
||||
self.workers: List[Worker] = [master_worker]
|
||||
self.jobs: List[Job] = []
|
||||
self.job_timeout: int = 10 # seconds
|
||||
self.job_timeout: int = 0 # seconds
|
||||
self.initialized: bool = False
|
||||
self.verify_remotes = verify_remotes
|
||||
self.initial_payload = copy.copy(initial_payload)
|
||||
|
|
@ -188,15 +196,19 @@ class World:
|
|||
t = Thread(target=worker.refresh_checkpoints, args=())
|
||||
t.start()
|
||||
|
||||
def benchmark(self):
|
||||
def benchmark(self, rebenchmark: bool = False):
|
||||
"""
|
||||
Attempts to benchmark all workers a part of the world.
|
||||
"""
|
||||
global benchmark_payload
|
||||
|
||||
workers_info: dict = {}
|
||||
saved: bool = os.path.exists(self.worker_info_path)
|
||||
benchmark_payload_loaded: bool = False
|
||||
|
||||
if rebenchmark:
|
||||
saved = False
|
||||
|
||||
if saved:
|
||||
workers_info = json.load(open(self.worker_info_path, 'r'))
|
||||
|
||||
|
|
@ -250,6 +262,14 @@ class World:
|
|||
print(f"{i}. worker '{worker}' - {worker.avg_ipm} ipm")
|
||||
i += 1
|
||||
|
||||
def __str__(self):
|
||||
# print status of all jobs
|
||||
jobs_str = ""
|
||||
for job in self.jobs:
|
||||
jobs_str += job.__str__() + "\n"
|
||||
|
||||
return jobs_str
|
||||
|
||||
def realtime_jobs(self) -> List[Job]:
|
||||
"""
|
||||
Determines which jobs are considered real-time by checking which jobs are not(complementary).
|
||||
|
|
@ -263,6 +283,8 @@ class World:
|
|||
if job.complementary is False:
|
||||
fast_jobs.append(job)
|
||||
|
||||
print(f"fast jobs: {fast_jobs}")
|
||||
|
||||
return fast_jobs
|
||||
|
||||
def slowest_realtime_job(self) -> Job:
|
||||
|
|
@ -292,6 +314,10 @@ class World:
|
|||
"""
|
||||
|
||||
fastest_worker = self.fastest_realtime_job().worker
|
||||
# if the worker is the fastest, then there is no lag
|
||||
if worker == fastest_worker:
|
||||
return 0
|
||||
|
||||
lag = worker.batch_eta(payload=payload) - fastest_worker.batch_eta(payload=payload)
|
||||
|
||||
return lag
|
||||
|
|
@ -304,6 +330,7 @@ class World:
|
|||
Returns:
|
||||
float: Local worker speed in ipm
|
||||
"""
|
||||
global benchmark_payload
|
||||
|
||||
master_bench_payload = copy.copy(self.initial_payload)
|
||||
|
||||
|
|
@ -357,14 +384,14 @@ class World:
|
|||
|
||||
deferred_images = 0 # the number of images that were not assigned to a worker due to the worker being too slow
|
||||
# the maximum amount of images that a "slow" worker can produce in the slack space where other nodes are working
|
||||
max_compensation = 4
|
||||
# max_compensation = 4 currently unused
|
||||
images_per_job = None
|
||||
|
||||
for job in self.jobs:
|
||||
|
||||
lag = self.job_stall(job.worker, payload=payload)
|
||||
|
||||
if lag < self.job_timeout:
|
||||
if lag < self.job_timeout or lag == 0:
|
||||
job.batch_size = payload['batch_size']
|
||||
continue
|
||||
|
||||
|
|
@ -396,6 +423,9 @@ class World:
|
|||
|
||||
slowest_active_worker = self.slowest_realtime_job().worker
|
||||
slack_time = slowest_active_worker.batch_eta(payload=payload)
|
||||
if cmd_opts.distributed_debug:
|
||||
print(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'")
|
||||
|
||||
# in the case that this worker is now taking on what others workers would have been (if they were real-time)
|
||||
# this means that there will be more slack time for complementary nodes
|
||||
slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job)
|
||||
|
|
@ -412,6 +442,8 @@ class World:
|
|||
# It might be better to just inject a black image. (if master is that slow)
|
||||
master_job = self.master_job()
|
||||
if master_job.batch_size < 1:
|
||||
if cmd_opts.distributed_debug:
|
||||
print("Master couldn't keep up... defaulting to 1 image")
|
||||
master_job.batch_size = 1
|
||||
|
||||
print("After job optimization, job layout is the following:")
|
||||
|
|
|
|||
Loading…
Reference in New Issue