Fix redundantly init'ing World. Fix typo. Fix a couple vars not being in scope. Add some internal __str__ overrides for debugging. Improve regression by limiting the number of samples by 5. Filter some influential outliers by checking if the variance is astronomically high.
parent
6ffef54e36
commit
b4c24d75d8
|
|
@ -23,4 +23,3 @@ def preload(parser):
|
||||||
help="Enable debug information",
|
help="Enable debug information",
|
||||||
action="store_true"
|
action="store_true"
|
||||||
)
|
)
|
||||||
# args = parser.parse_args()
|
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from pathlib import Path
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
|
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
|
||||||
from scripts.spartan.Worker import Worker
|
from scripts.spartan.Worker import Worker, State
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
|
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
|
||||||
|
|
@ -60,13 +60,16 @@ class Script(scripts.Script):
|
||||||
|
|
||||||
with gradio.Tab('Status') as status_tab:
|
with gradio.Tab('Status') as status_tab:
|
||||||
status = gradio.Textbox(elem_id='status', show_label=False)
|
status = gradio.Textbox(elem_id='status', show_label=False)
|
||||||
status_tab.select(fn=Script.ui_connect_test, inputs=[], outputs=[status])
|
status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[status])
|
||||||
|
|
||||||
|
jobs = gradio.Textbox(elem_id='jobs', label='Jobs', show_label=True)
|
||||||
|
# status_tab.select(fn=Script.world.__str__, inputs=[], outputs=[jobs, status]),
|
||||||
|
|
||||||
refresh_status_btn = gradio.Button(value='Refresh')
|
refresh_status_btn = gradio.Button(value='Refresh')
|
||||||
refresh_status_btn.style(size='sm')
|
refresh_status_btn.style(size='sm')
|
||||||
refresh_status_btn.click(Script.ui_connect_test, inputs=[], outputs=[status])
|
refresh_status_btn.click(Script.ui_connect_status, inputs=[], outputs=[jobs, status])
|
||||||
|
|
||||||
with gradio.Tab('Remote Utils'):
|
with gradio.Tab('Utils'):
|
||||||
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
|
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
|
||||||
refresh_checkpoints_btn.style(full_width=False)
|
refresh_checkpoints_btn.style(full_width=False)
|
||||||
refresh_checkpoints_btn.click(Script.ui_connect_refresh_ckpts_btn, inputs=[], outputs=[])
|
refresh_checkpoints_btn.click(Script.ui_connect_refresh_ckpts_btn, inputs=[], outputs=[])
|
||||||
|
|
@ -79,8 +82,19 @@ class Script(scripts.Script):
|
||||||
interrupt_all_btn.style(full_width=False)
|
interrupt_all_btn.style(full_width=False)
|
||||||
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
|
interrupt_all_btn.click(Script.ui_connect_interrupt_btn, inputs=[], outputs=[])
|
||||||
|
|
||||||
|
# redo benchmarks button
|
||||||
|
redo_benchmarks_btn = gradio.Button(value='Redo benchmarks')
|
||||||
|
redo_benchmarks_btn.style(full_width=False)
|
||||||
|
redo_benchmarks_btn.click(Script.ui_connect_benchmark_button, inputs=[], outputs=[])
|
||||||
|
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ui_connect_benchmark_button():
|
||||||
|
print("Redoing benchmarks...")
|
||||||
|
Script.world.benchmark(rebenchmark=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def user_sync_script():
|
def user_sync_script():
|
||||||
user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user')
|
user_scripts = Path(os.path.abspath(__file__)).parent.joinpath('user')
|
||||||
|
|
@ -120,23 +134,29 @@ class Script(scripts.Script):
|
||||||
print("Distributed system not initialized")
|
print("Distributed system not initialized")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def ui_connect_test():
|
def ui_connect_status():
|
||||||
try:
|
try:
|
||||||
temp = ''
|
worker_status = ''
|
||||||
|
|
||||||
for worker in Script.world.workers:
|
for worker in Script.world.workers:
|
||||||
if worker.master:
|
if worker.master:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
temp += f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
|
worker_status += f"{worker.uuid} at {worker.address} is {worker.state.name}\n"
|
||||||
|
|
||||||
return temp
|
# TODO replace this with a single check to a state flag that we should make in the world class
|
||||||
|
for worker in Script.world.workers:
|
||||||
|
if worker.state == State.WORKING:
|
||||||
|
return Script.world.__str__(), worker_status
|
||||||
|
|
||||||
|
return 'No active jobs!', worker_status
|
||||||
|
|
||||||
# init system if it isn't already
|
# init system if it isn't already
|
||||||
except AttributeError:
|
except AttributeError as e:
|
||||||
# batch size will be clobbered later once an actual request is made anyway so I just pass 1
|
print(e)
|
||||||
|
# batch size will be clobbered later once an actual request is made anyway
|
||||||
Script.initialize(initial_payload=None)
|
Script.initialize(initial_payload=None)
|
||||||
Script.ui_connect_test()
|
return 'refresh!', 'refresh!'
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -222,33 +242,38 @@ class Script(scripts.Script):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def initialize(initial_payload):
|
def initialize(initial_payload):
|
||||||
if Script.verify_remotes is False:
|
# get default batch size
|
||||||
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
|
|
||||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
batch_size = initial_payload.batch_size
|
batch_size = initial_payload.batch_size
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
|
|
||||||
Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)
|
if Script.world is None:
|
||||||
|
if Script.verify_remotes is False:
|
||||||
|
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
|
||||||
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
# add workers to the world
|
# construct World
|
||||||
for worker in cmd_opts.distributed_remotes:
|
Script.world = World(initial_payload=initial_payload, verify_remotes=Script.verify_remotes)
|
||||||
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
|
|
||||||
|
|
||||||
try:
|
# add workers to the world
|
||||||
Script.world.initialize(batch_size)
|
for worker in cmd_opts.distributed_remotes:
|
||||||
print("World initialized!")
|
Script.world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
|
||||||
except WorldAlreadyInitialized:
|
|
||||||
Script.world.update_world(total_batch_size=batch_size)
|
else:
|
||||||
|
# update world or initialize and update if necessary
|
||||||
|
try:
|
||||||
|
Script.world.initialize(batch_size)
|
||||||
|
print("World initialized!")
|
||||||
|
except WorldAlreadyInitialized:
|
||||||
|
Script.world.update_world(total_batch_size=batch_size)
|
||||||
|
|
||||||
def run(self, p, *args):
|
def run(self, p, *args):
|
||||||
if cmd_opts.distributed_remotes is None:
|
if cmd_opts.distributed_remotes is None:
|
||||||
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
|
raise RuntimeError("Distributed - No remotes passed. (Try using `--distributed-remotes`?)")
|
||||||
|
|
||||||
# register gallery callback
|
# register gallery callback
|
||||||
script_callbacks.on_after_image_processed(Script.add_to_gallery)
|
script_callbacks.on_after_batch_processed(Script.add_to_gallery)
|
||||||
|
|
||||||
Script.initialize(initial_payload=p)
|
Script.initialize(initial_payload=p)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -121,7 +121,7 @@ class Worker:
|
||||||
These things are used to draw certain conclusions after the first session.
|
These things are used to draw certain conclusions after the first session.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
benchmark_payload (dict): The payload used the benchmark.
|
benchmark_payload (dict): The payload used in the benchmark.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Worker info, including how it was benchmarked.
|
dict: Worker info, including how it was benchmarked.
|
||||||
|
|
@ -137,9 +137,18 @@ class Worker:
|
||||||
d[self.uuid] = data
|
d[self.uuid] = data
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
def treat_mpe(self):
|
||||||
|
"""
|
||||||
|
In collecting percent errors to calculate the MPE, there may be influential outliers that skew the results.
|
||||||
|
Here we cull those outliers from them in order to get a more accurate end result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO implement this
|
||||||
|
pass
|
||||||
|
|
||||||
def eta_mpe(self):
|
def eta_mpe(self):
|
||||||
"""
|
"""
|
||||||
Returns the mean absolute percent error using all the currently stored eta percent errors.
|
Returns the mean percent error using all the currently stored eta percent errors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
mpe (float): The mean percent error of a worker's calculation estimates.
|
mpe (float): The mean percent error of a worker's calculation estimates.
|
||||||
|
|
@ -165,7 +174,7 @@ class Worker:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO check if using http or https
|
# TODO check if using http or https
|
||||||
return f"https://{self.__str__()}/sdapi/v1/{route}"
|
return f"http://{self.__str__()}/sdapi/v1/{route}"
|
||||||
|
|
||||||
def batch_eta_hr(self, payload: dict) -> float:
|
def batch_eta_hr(self, payload: dict) -> float:
|
||||||
"""
|
"""
|
||||||
|
|
@ -233,13 +242,21 @@ class Worker:
|
||||||
if len(self.eta_percent_error) > 0:
|
if len(self.eta_percent_error) > 0:
|
||||||
correction = eta * (self.eta_mpe() / 100)
|
correction = eta * (self.eta_mpe() / 100)
|
||||||
|
|
||||||
if cmd_opts.distributed_debug:
|
# if abs(correction) > 300:
|
||||||
print(f"worker '{self.uuid}'s last ETA was off by {correction}%")
|
# print(f"correction {abs(correction)} exceeds 300%... .")
|
||||||
|
|
||||||
|
if cmd_opts.distributed_debug:
|
||||||
|
print(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%")
|
||||||
|
print(f"{self.uuid} eta before correction: ", eta)
|
||||||
|
|
||||||
|
# do regression
|
||||||
if correction > 0:
|
if correction > 0:
|
||||||
eta += correction
|
|
||||||
else:
|
|
||||||
eta -= correction
|
eta -= correction
|
||||||
|
else:
|
||||||
|
eta += correction
|
||||||
|
|
||||||
|
if cmd_opts.distributed_debug:
|
||||||
|
print(f"{self.uuid} eta after correction: ", eta)
|
||||||
|
|
||||||
return eta
|
return eta
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -308,10 +325,18 @@ class Worker:
|
||||||
print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n")
|
print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n")
|
||||||
print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
|
print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
|
||||||
|
|
||||||
if self.eta_percent_error == 0:
|
# if the variance is greater than 500% then we ignore it to prevent variation inflation
|
||||||
self.eta_percent_error[0] = variance
|
if abs(variance) < 500:
|
||||||
|
# check if there are already 5 samples and if so, remove the oldest
|
||||||
|
# this should help adjust to the user changing tasks
|
||||||
|
if len(self.eta_percent_error) > 4:
|
||||||
|
self.eta_percent_error.pop(0)
|
||||||
|
if self.eta_percent_error == 0: # init
|
||||||
|
self.eta_percent_error[0] = variance
|
||||||
|
else: # normal case
|
||||||
|
self.eta_percent_error.append(variance)
|
||||||
else:
|
else:
|
||||||
self.eta_percent_error.append(variance)
|
print(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if payload['batch_size'] == 0:
|
if payload['batch_size'] == 0:
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,14 @@ class Job:
|
||||||
self.batch_size: int = batch_size
|
self.batch_size: int = batch_size
|
||||||
self.complementary: bool = False
|
self.complementary: bool = False
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
prefix = ''
|
||||||
|
suffix = f"Job: {self.batch_size} images. Owned by '{self.worker.uuid}'. Rate: {self.worker.avg_ipm}ipm"
|
||||||
|
if self.complementary:
|
||||||
|
prefix = "(complementary) "
|
||||||
|
|
||||||
|
return prefix + suffix
|
||||||
|
|
||||||
|
|
||||||
class World:
|
class World:
|
||||||
"""
|
"""
|
||||||
|
|
@ -70,7 +78,7 @@ class World:
|
||||||
self.total_batch_size: int = 0
|
self.total_batch_size: int = 0
|
||||||
self.workers: List[Worker] = [master_worker]
|
self.workers: List[Worker] = [master_worker]
|
||||||
self.jobs: List[Job] = []
|
self.jobs: List[Job] = []
|
||||||
self.job_timeout: int = 10 # seconds
|
self.job_timeout: int = 0 # seconds
|
||||||
self.initialized: bool = False
|
self.initialized: bool = False
|
||||||
self.verify_remotes = verify_remotes
|
self.verify_remotes = verify_remotes
|
||||||
self.initial_payload = copy.copy(initial_payload)
|
self.initial_payload = copy.copy(initial_payload)
|
||||||
|
|
@ -188,15 +196,19 @@ class World:
|
||||||
t = Thread(target=worker.refresh_checkpoints, args=())
|
t = Thread(target=worker.refresh_checkpoints, args=())
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
def benchmark(self):
|
def benchmark(self, rebenchmark: bool = False):
|
||||||
"""
|
"""
|
||||||
Attempts to benchmark all workers a part of the world.
|
Attempts to benchmark all workers a part of the world.
|
||||||
"""
|
"""
|
||||||
|
global benchmark_payload
|
||||||
|
|
||||||
workers_info: dict = {}
|
workers_info: dict = {}
|
||||||
saved: bool = os.path.exists(self.worker_info_path)
|
saved: bool = os.path.exists(self.worker_info_path)
|
||||||
benchmark_payload_loaded: bool = False
|
benchmark_payload_loaded: bool = False
|
||||||
|
|
||||||
|
if rebenchmark:
|
||||||
|
saved = False
|
||||||
|
|
||||||
if saved:
|
if saved:
|
||||||
workers_info = json.load(open(self.worker_info_path, 'r'))
|
workers_info = json.load(open(self.worker_info_path, 'r'))
|
||||||
|
|
||||||
|
|
@ -250,6 +262,14 @@ class World:
|
||||||
print(f"{i}. worker '{worker}' - {worker.avg_ipm} ipm")
|
print(f"{i}. worker '{worker}' - {worker.avg_ipm} ipm")
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
# print status of all jobs
|
||||||
|
jobs_str = ""
|
||||||
|
for job in self.jobs:
|
||||||
|
jobs_str += job.__str__() + "\n"
|
||||||
|
|
||||||
|
return jobs_str
|
||||||
|
|
||||||
def realtime_jobs(self) -> List[Job]:
|
def realtime_jobs(self) -> List[Job]:
|
||||||
"""
|
"""
|
||||||
Determines which jobs are considered real-time by checking which jobs are not(complementary).
|
Determines which jobs are considered real-time by checking which jobs are not(complementary).
|
||||||
|
|
@ -263,6 +283,8 @@ class World:
|
||||||
if job.complementary is False:
|
if job.complementary is False:
|
||||||
fast_jobs.append(job)
|
fast_jobs.append(job)
|
||||||
|
|
||||||
|
print(f"fast jobs: {fast_jobs}")
|
||||||
|
|
||||||
return fast_jobs
|
return fast_jobs
|
||||||
|
|
||||||
def slowest_realtime_job(self) -> Job:
|
def slowest_realtime_job(self) -> Job:
|
||||||
|
|
@ -292,6 +314,10 @@ class World:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
fastest_worker = self.fastest_realtime_job().worker
|
fastest_worker = self.fastest_realtime_job().worker
|
||||||
|
# if the worker is the fastest, then there is no lag
|
||||||
|
if worker == fastest_worker:
|
||||||
|
return 0
|
||||||
|
|
||||||
lag = worker.batch_eta(payload=payload) - fastest_worker.batch_eta(payload=payload)
|
lag = worker.batch_eta(payload=payload) - fastest_worker.batch_eta(payload=payload)
|
||||||
|
|
||||||
return lag
|
return lag
|
||||||
|
|
@ -304,6 +330,7 @@ class World:
|
||||||
Returns:
|
Returns:
|
||||||
float: Local worker speed in ipm
|
float: Local worker speed in ipm
|
||||||
"""
|
"""
|
||||||
|
global benchmark_payload
|
||||||
|
|
||||||
master_bench_payload = copy.copy(self.initial_payload)
|
master_bench_payload = copy.copy(self.initial_payload)
|
||||||
|
|
||||||
|
|
@ -357,14 +384,14 @@ class World:
|
||||||
|
|
||||||
deferred_images = 0 # the number of images that were not assigned to a worker due to the worker being too slow
|
deferred_images = 0 # the number of images that were not assigned to a worker due to the worker being too slow
|
||||||
# the maximum amount of images that a "slow" worker can produce in the slack space where other nodes are working
|
# the maximum amount of images that a "slow" worker can produce in the slack space where other nodes are working
|
||||||
max_compensation = 4
|
# max_compensation = 4 currently unused
|
||||||
images_per_job = None
|
images_per_job = None
|
||||||
|
|
||||||
for job in self.jobs:
|
for job in self.jobs:
|
||||||
|
|
||||||
lag = self.job_stall(job.worker, payload=payload)
|
lag = self.job_stall(job.worker, payload=payload)
|
||||||
|
|
||||||
if lag < self.job_timeout:
|
if lag < self.job_timeout or lag == 0:
|
||||||
job.batch_size = payload['batch_size']
|
job.batch_size = payload['batch_size']
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -396,6 +423,9 @@ class World:
|
||||||
|
|
||||||
slowest_active_worker = self.slowest_realtime_job().worker
|
slowest_active_worker = self.slowest_realtime_job().worker
|
||||||
slack_time = slowest_active_worker.batch_eta(payload=payload)
|
slack_time = slowest_active_worker.batch_eta(payload=payload)
|
||||||
|
if cmd_opts.distributed_debug:
|
||||||
|
print(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'")
|
||||||
|
|
||||||
# in the case that this worker is now taking on what others workers would have been (if they were real-time)
|
# in the case that this worker is now taking on what others workers would have been (if they were real-time)
|
||||||
# this means that there will be more slack time for complementary nodes
|
# this means that there will be more slack time for complementary nodes
|
||||||
slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job)
|
slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job)
|
||||||
|
|
@ -412,6 +442,8 @@ class World:
|
||||||
# It might be better to just inject a black image. (if master is that slow)
|
# It might be better to just inject a black image. (if master is that slow)
|
||||||
master_job = self.master_job()
|
master_job = self.master_job()
|
||||||
if master_job.batch_size < 1:
|
if master_job.batch_size < 1:
|
||||||
|
if cmd_opts.distributed_debug:
|
||||||
|
print("Master couldn't keep up... defaulting to 1 image")
|
||||||
master_job.batch_size = 1
|
master_job.batch_size = 1
|
||||||
|
|
||||||
print("After job optimization, job layout is the following:")
|
print("After job optimization, job layout is the following:")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue