add 'thin-client' mode
parent
6a28587796
commit
887e58608b
|
|
@ -21,7 +21,7 @@ from scripts.spartan.UI import UI
|
|||
from modules.shared import opts
|
||||
from scripts.spartan.shared import logger
|
||||
from scripts.spartan.control_net import pack_control_net
|
||||
from modules.processing import fix_seed
|
||||
from modules.processing import fix_seed, Processed
|
||||
|
||||
|
||||
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
|
||||
|
|
@ -67,32 +67,47 @@ class Script(scripts.Script):
|
|||
def processed_inject_image(image, info_index, iteration: int, save_path_override=None, grid=False, response=None):
|
||||
image_params: json = response["parameters"]
|
||||
image_info_post: json = json.loads(response["info"]) # image info known after processing
|
||||
num_response_images = image_params["batch_size"] * image_params["n_iter"]
|
||||
|
||||
seed = None
|
||||
subseed = None
|
||||
negative_prompt = None
|
||||
|
||||
|
||||
try:
|
||||
# some metadata
|
||||
processed.all_seeds.append(image_info_post["all_seeds"][info_index])
|
||||
processed.all_subseeds.append(image_info_post["all_subseeds"][info_index])
|
||||
processed.all_negative_prompts.append(image_info_post["all_negative_prompts"][info_index])
|
||||
if num_response_images > 1:
|
||||
seed = image_info_post['all_seeds'][info_index]
|
||||
subseed = image_info_post['all_subseeds'][info_index]
|
||||
negative_prompt = image_info_post['all_negative_prompts'][info_index]
|
||||
else:
|
||||
seed = image_info_post['seed']
|
||||
subseed = image_info_post['subseed']
|
||||
negative_prompt = image_info_post['negative_prompt']
|
||||
except Exception:
|
||||
# like with controlnet masks, there isn't always full post-gen info, so we fall back to the first image's info
|
||||
logger.debug(f"Image at index {i} for '{worker.uuid}' was missing some post-generation data")
|
||||
processed_inject_image(image=image, info_index=0, iteration=iteration)
|
||||
return
|
||||
|
||||
processed.all_seeds.append(seed)
|
||||
processed.all_subseeds.append(subseed)
|
||||
processed.all_negative_prompts.append(negative_prompt)
|
||||
processed.all_prompts.append(image_params["prompt"])
|
||||
processed.images.append(image) # actual received image
|
||||
|
||||
# generate info-text string
|
||||
# modules.ui_common -> update_generation_info renders to html below gallery
|
||||
images_per_batch = p.n_iter * p.batch_size
|
||||
# zero-indexed position of image in total batch (so including master results)
|
||||
true_image_pos = len(processed.images) - 1
|
||||
num_remote_images = images_per_batch * p.batch_size
|
||||
if p.n_iter > 1: # if splitting by batch count
|
||||
num_remote_images *= p.n_iter - 1
|
||||
info_text_used_seed_index = info_index + p.n_iter * p.batch_size if not grid else 0
|
||||
|
||||
if iteration != 0:
|
||||
logger.debug(f"iteration {iteration}/{p.n_iter}, image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, info-index: {info_index}, used seed index {info_text_used_seed_index}")
|
||||
logger.debug(f"iteration {iteration}/{p.n_iter}, image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, info-index: {info_index}")
|
||||
|
||||
if Script.world.thin_client_mode:
|
||||
p.all_negative_prompts = processed.all_negative_prompts
|
||||
|
||||
info_text = processing.create_infotext(
|
||||
p=p,
|
||||
|
|
@ -239,7 +254,7 @@ class Script(scripts.Script):
|
|||
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works
|
||||
# see test/basic_features/txt2img_test.py for an example
|
||||
payload = copy.copy(p.__dict__)
|
||||
payload['batch_size'] = Script.world.get_default_worker_batch_size()
|
||||
payload['batch_size'] = Script.world.default_batch_size()
|
||||
payload['scripts'] = None
|
||||
del payload['script_args']
|
||||
|
||||
|
|
@ -295,12 +310,23 @@ class Script(scripts.Script):
|
|||
started_jobs.append(job)
|
||||
|
||||
# if master batch size was changed again due to optimization change it to the updated value
|
||||
p.batch_size = Script.world.get_master_batch_size()
|
||||
if not self.world.thin_client_mode:
|
||||
p.batch_size = Script.world.master_job().batch_size
|
||||
Script.master_start = time.time()
|
||||
|
||||
# generate images assigned to local machine
|
||||
p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
|
||||
processed = processing.process_images(p, *args)
|
||||
Script.add_to_gallery(processed, p)
|
||||
if Script.world.thin_client_mode:
|
||||
p.batch_size = 0
|
||||
processed = Processed(p=p, images_list=[])
|
||||
processed.all_prompts = []
|
||||
processed.all_seeds = []
|
||||
processed.all_subseeds = []
|
||||
processed.all_negative_prompts = []
|
||||
processed.infotexts = []
|
||||
processed.prompt = None
|
||||
else:
|
||||
processed = processing.process_images(p, *args)
|
||||
|
||||
Script.add_to_gallery(processed, p)
|
||||
return processed
|
||||
|
|
|
|||
|
|
@ -63,6 +63,10 @@ class UI:
|
|||
|
||||
return 'No active jobs!', worker_status
|
||||
|
||||
def save_btn(self, thin_client_mode):
    """
    Click handler wired to the 'Save' button of the Settings tab.

    Stores the checkbox value on the world object and logs the new setting.

    Args:
        thin_client_mode: value of the thin-client checkbox, stored as-is.
    """
    world = self.world
    world.thin_client_mode = thin_client_mode
    logger.debug(f"thin client mode is now {thin_client_mode}")
|
||||
|
||||
# end handlers
|
||||
|
||||
def create_root(self):
|
||||
|
|
@ -97,4 +101,14 @@ class UI:
|
|||
redo_benchmarks_btn.style(full_width=False)
|
||||
redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[])
|
||||
|
||||
with gradio.Tab('Settings'):
|
||||
thin_client_cbx = gradio.Checkbox(
|
||||
label='Thin-client mode (experimental)',
|
||||
info="Only generate images using remote workers. There will be no previews when enabled.",
|
||||
value=self.world.thin_client_mode
|
||||
)
|
||||
|
||||
save_btn = gradio.Button(value='Save')
|
||||
save_btn.click(fn=self.save_btn, inputs=[thin_client_cbx])
|
||||
|
||||
return root
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class Worker:
|
|||
|
||||
address: str = None
|
||||
port: int = None
|
||||
avg_ipm: int = None
|
||||
avg_ipm: float = None
|
||||
uuid: str = None
|
||||
queried: bool = False # whether this worker has been connected to yet
|
||||
free_vram: bytes = 0
|
||||
|
|
@ -98,7 +98,7 @@ class Worker:
|
|||
self.master = master
|
||||
self.uuid = 'master'
|
||||
# set to a sentinel value to avoid issues with speed comparisons
|
||||
self.avg_ipm = 0
|
||||
# self.avg_ipm = 0
|
||||
|
||||
# right now this is really only for clarity while debugging:
|
||||
self.address = server_name
|
||||
|
|
@ -206,54 +206,51 @@ class Worker:
|
|||
num_images = payload['batch_size']
|
||||
|
||||
# if worker has not yet been benchmarked then
|
||||
try:
|
||||
eta = (num_images / self.avg_ipm) * 60
|
||||
# show effect of increased step size
|
||||
real_steps_to_benched = steps / benchmark_payload['steps']
|
||||
eta = eta * real_steps_to_benched
|
||||
eta = (num_images / self.avg_ipm) * 60
|
||||
# show effect of increased step size
|
||||
real_steps_to_benched = steps / benchmark_payload['steps']
|
||||
eta = eta * real_steps_to_benched
|
||||
|
||||
# show effect of high-res fix
|
||||
hr = payload.get('enable_hr', False)
|
||||
if hr:
|
||||
eta += self.batch_eta_hr(payload=payload)
|
||||
# show effect of high-res fix
|
||||
hr = payload.get('enable_hr', False)
|
||||
if hr:
|
||||
eta += self.batch_eta_hr(payload=payload)
|
||||
|
||||
# show effect of image size
|
||||
real_pix_to_benched = (payload['width'] * payload['height'])\
|
||||
/ (benchmark_payload['width'] * benchmark_payload['height'])
|
||||
eta = eta * real_pix_to_benched
|
||||
# show effect of image size
|
||||
real_pix_to_benched = (payload['width'] * payload['height'])\
|
||||
/ (benchmark_payload['width'] * benchmark_payload['height'])
|
||||
eta = eta * real_pix_to_benched
|
||||
|
||||
# show effect of using a sampler other than euler a
|
||||
sampler = payload.get('sampler_name', 'Euler a')
|
||||
if sampler != 'Euler a':
|
||||
try:
|
||||
percent_difference = self.other_to_euler_a[payload['sampler_name']]
|
||||
if percent_difference > 0:
|
||||
eta -= (eta * abs((percent_difference / 100)))
|
||||
else:
|
||||
eta += (eta * abs((percent_difference / 100)))
|
||||
except KeyError:
|
||||
logger.warning(f"Sampler '{payload['sampler_name']}' efficiency is not recorded.\n")
|
||||
# in this case the sampler will be treated as having the same efficiency as Euler a
|
||||
# show effect of using a sampler other than euler a
|
||||
sampler = payload.get('sampler_name', 'Euler a')
|
||||
if sampler != 'Euler a':
|
||||
try:
|
||||
percent_difference = self.other_to_euler_a[payload['sampler_name']]
|
||||
if percent_difference > 0:
|
||||
eta -= (eta * abs((percent_difference / 100)))
|
||||
else:
|
||||
eta += (eta * abs((percent_difference / 100)))
|
||||
except KeyError:
|
||||
logger.warning(f"Sampler '{payload['sampler_name']}' efficiency is not recorded.\n")
|
||||
# in this case the sampler will be treated as having the same efficiency as Euler a
|
||||
|
||||
# TODO save and load each worker's MPE before the end of session to workers.json.
|
||||
# That way initial estimations are more accurate from the second sdwui session onward
|
||||
# adjust for a known inaccuracy in our estimation of this worker using average percent error
|
||||
if len(self.eta_percent_error) > 0:
|
||||
correction = eta * (self.eta_mpe() / 100)
|
||||
# TODO save and load each worker's MPE before the end of session to workers.json.
|
||||
# That way initial estimations are more accurate from the second sdwui session onward
|
||||
# adjust for a known inaccuracy in our estimation of this worker using average percent error
|
||||
if len(self.eta_percent_error) > 0:
|
||||
correction = eta * (self.eta_mpe() / 100)
|
||||
|
||||
if not quiet:
|
||||
logger.debug(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%")
|
||||
correction_summary = f"correcting '{self.uuid}'s ETA: {eta:.2f}s -> "
|
||||
# do regression
|
||||
eta -= correction
|
||||
if not quiet:
|
||||
logger.debug(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%")
|
||||
correction_summary = f"correcting '{self.uuid}'s ETA: {eta:.2f}s -> "
|
||||
# do regression
|
||||
eta -= correction
|
||||
|
||||
if not quiet:
|
||||
correction_summary += f"{eta:.2f}s"
|
||||
logger.debug(correction_summary)
|
||||
if not quiet:
|
||||
correction_summary += f"{eta:.2f}s"
|
||||
logger.debug(correction_summary)
|
||||
|
||||
return eta
|
||||
except Exception as e:
|
||||
raise e
|
||||
return eta
|
||||
|
||||
def request(self, payload: dict, option_payload: dict, sync_options: bool):
|
||||
"""
|
||||
|
|
@ -295,8 +292,8 @@ class Worker:
|
|||
|
||||
if self.benchmarked:
|
||||
eta = self.batch_eta(payload=payload) * payload['n_iter']
|
||||
logger.debug(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image("
|
||||
f"s) at a speed of {self.avg_ipm} ipm\n")
|
||||
logger.debug(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size'] * payload['n_iter']} image("
|
||||
f"s) at a speed of {self.avg_ipm:.2f} ipm\n")
|
||||
|
||||
try:
|
||||
# remove anything that is not serializable
|
||||
|
|
@ -453,7 +450,7 @@ class Worker:
|
|||
ipm_sum = 0
|
||||
for ipm in results:
|
||||
ipm_sum += ipm
|
||||
avg_ipm = math.floor(ipm_sum / samples)
|
||||
avg_ipm = ipm_sum / samples
|
||||
|
||||
logger.debug(f"Worker '{self.uuid}' average ipm: {avg_ipm}")
|
||||
self.avg_ipm = avg_ipm
|
||||
|
|
|
|||
|
|
@ -72,14 +72,15 @@ class World:
|
|||
worker_info_path = this_extension_path.joinpath('workers.json')
|
||||
|
||||
def __init__(self, initial_payload, verify_remotes: bool = True):
|
||||
master_worker = Worker(master=True)
|
||||
self.master_worker = Worker(master=True)
|
||||
self.total_batch_size: int = 0
|
||||
self.__workers: List[Worker] = [master_worker]
|
||||
self.__workers: List[Worker] = [self.master_worker]
|
||||
self.jobs: List[Job] = []
|
||||
self.job_timeout: int = 6 # seconds
|
||||
self.initialized: bool = False
|
||||
self.verify_remotes = verify_remotes
|
||||
self.initial_payload = copy.copy(initial_payload)
|
||||
self.thin_client_mode = False
|
||||
|
||||
def update_world(self, total_batch_size):
|
||||
"""
|
||||
|
|
@ -91,10 +92,7 @@ class World:
|
|||
"""
|
||||
|
||||
self.total_batch_size = total_batch_size
|
||||
|
||||
default_worker_batch_size = self.get_default_worker_batch_size()
|
||||
self.sync_master(batch_size=default_worker_batch_size)
|
||||
self.update_worker_jobs()
|
||||
self.update_jobs()
|
||||
|
||||
def initialize(self, total_batch_size):
|
||||
"""should be called before a world instance is used for anything"""
|
||||
|
|
@ -105,37 +103,19 @@ class World:
|
|||
self.update_world(total_batch_size=total_batch_size)
|
||||
self.initialized = True
|
||||
|
||||
def get_default_worker_batch_size(self) -> int:
|
||||
def default_batch_size(self) -> int:
|
||||
"""the amount of images/total images requested that a worker would compute if conditions were perfect and
|
||||
each worker generated at the same speed"""
|
||||
each worker generated at the same speed. assumes one batch only"""
|
||||
|
||||
return self.total_batch_size // self.world_size()
|
||||
return self.total_batch_size // self.size()
|
||||
|
||||
def world_size(self) -> int:
|
||||
def size(self) -> int:
    """
    Count the nodes currently registered in the world.

    Returns:
        int: The length of the list produced by get_workers().
    """
    workers = self.get_workers()
    return len(workers)
|
||||
|
||||
def sync_master(self, batch_size: int):
    """
    Record how many images the master node's pseudo-job will process.

    When the job list already has entries, the job returned by master_job()
    is updated in place. Otherwise a new Job bound to the master worker is
    created and appended to the job list.

    Args:
        batch_size: number of images assigned to the master.
    """
    if self.jobs:
        # a job list already exists: only adjust the master's share
        self.master_job().batch_size = batch_size
        return

    # no jobs yet: create the master's pseudo-job
    self.jobs.append(Job(worker=self.master(), batch_size=batch_size))
|
||||
|
||||
def get_master_batch_size(self) -> int:
    """
    Report the master's current share of the work.

    Returns:
        int: The batch_size of the job returned by master_job().
    """
    job = self.master_job()
    return job.batch_size
|
||||
|
||||
def master(self) -> Worker:
|
||||
"""
|
||||
May perform additional checks in the future
|
||||
|
|
@ -143,12 +123,7 @@ class World:
|
|||
Worker: The local/master worker object.
|
||||
"""
|
||||
|
||||
workers = self.get_workers()
|
||||
master = workers[0]
|
||||
if master.master is False:
|
||||
raise RuntimeError("Master should be the first worker in the list")
|
||||
|
||||
return master
|
||||
return self.master_worker
|
||||
|
||||
def master_job(self) -> Job:
|
||||
"""
|
||||
|
|
@ -157,7 +132,9 @@ class World:
|
|||
Job: The local/master worker job object.
|
||||
"""
|
||||
|
||||
return self.jobs[0]
|
||||
for job in self.jobs:
|
||||
if job.worker.master:
|
||||
return job
|
||||
|
||||
def add_worker(self, uuid: str, address: str, port: int):
|
||||
"""
|
||||
|
|
@ -384,36 +361,34 @@ class World:
|
|||
self.master().benchmarked = True
|
||||
return ipm
|
||||
|
||||
def update_worker_jobs(self):
|
||||
def update_jobs(self):
|
||||
"""creates initial jobs (before optimization) """
|
||||
default_job_size = self.get_default_worker_batch_size()
|
||||
|
||||
# clear jobs if this is not the first time running
|
||||
if self.initialized:
|
||||
master_job = self.jobs[0]
|
||||
self.jobs = [master_job]
|
||||
self.jobs = []
|
||||
|
||||
batch_size = self.default_batch_size()
|
||||
for worker in self.get_workers():
|
||||
if worker.master:
|
||||
self.master_job().batch_size = default_job_size
|
||||
continue
|
||||
|
||||
batch_size = default_job_size
|
||||
self.jobs.append(Job(worker=worker, batch_size=batch_size))
|
||||
|
||||
def get_workers(self):
    """
    Collect the workers that are eligible to be handed work right now.

    A worker is left out of the result when any of the following holds:
      - its recorded speed (avg_ipm) is set but not positive; the speed is
        reset to 1 ipm and a warning is logged,
      - it is the master and thin-client mode is enabled,
      - its state is State.UNAVAILABLE.

    Returns:
        list: The eligible workers, in the order of the internal worker list.
    """
    filtered = []
    for worker in self.__workers:
        if worker.avg_ipm is not None and worker.avg_ipm <= 0:
            # use warning() rather than the deprecated warn() alias, matching the other call sites
            logger.warning(f"config reports invalid speed (0 ipm) for worker '{worker.uuid}', setting default of 1 ipm.\nplease re-benchmark")
            worker.avg_ipm = 1
            # NOTE(review): the worker is still skipped for this call even though it
            # now has a usable default speed; confirm that this is intended
            continue

        # in thin-client mode the local machine does not generate images
        if worker.master and self.thin_client_mode:
            continue

        if worker.state != State.UNAVAILABLE:
            filtered.append(worker)

    return filtered
|
||||
|
||||
|
||||
def optimize_jobs(self, payload: json):
|
||||
"""
|
||||
The payload batch_size should be set to whatever the default worker batch_size would be.
|
||||
get_default_worker_batch_size() should return the proper value if the world is initialized
|
||||
default_batch_size() should return the proper value if the world is initialized
|
||||
Ex. 3 workers(including master): payload['batch_size'] should evaluate to 1
|
||||
"""
|
||||
|
||||
|
|
@ -496,15 +471,6 @@ class World:
|
|||
|
||||
job.batch_size = num_images_compensate
|
||||
|
||||
# TODO master batch_size cannot be < 1 or it will crash the entire generation.
|
||||
# It might be better to just inject a black image. (if master is that slow)
|
||||
master_job = self.master_job()
|
||||
if master_job.batch_size < 1:
|
||||
logger.warning("Master couldn't keep up... defaulting to 1 image")
|
||||
master_job.batch_size = 1
|
||||
|
||||
|
||||
|
||||
logger.info("Job distribution:")
|
||||
iterations = payload['n_iter']
|
||||
logger.info(f"{self.total_batch_size} * {iterations} iteration(s): {self.total_batch_size * iterations} images")
|
||||
|
|
|
|||
Loading…
Reference in New Issue