diff --git a/scripts/extension.py b/scripts/extension.py index d85d6bc..a9d3ada 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -33,7 +33,6 @@ from modules.processing import fix_seed # noinspection PyMissingOrEmptyDocstring class Script(scripts.Script): - response_cache: json = None worker_threads: List[Thread] = [] # Whether to verify worker certificates. Can be useful if your remotes are self-signed. verify_remotes = False if cmd_opts.distributed_skip_verify_remotes else True @@ -231,7 +230,7 @@ class Script(scripts.Script): spoofed_iteration = p.n_iter for worker in Script.world.workers: - # if it fails here then that means that the response_cache global var is not being filled for some reason + expected_images = 1 for job in Script.world.jobs: if job.worker == worker: @@ -262,6 +261,7 @@ class Script(scripts.Script): injected_to_iteration = 0 else: injected_to_iteration += 1 + worker.response = None # generate and inject grid if opts.return_grid: @@ -269,14 +269,6 @@ class Script(scripts.Script): processed_inject_image(image=grid, info_index=0, save_path_override=p.outpath_grids, iteration=spoofed_iteration, grid=True) p.batch_size = len(processed.images) - """ - This ensures that we don't get redundant outputs in a certain case: - We have 3 workers and we get 3 responses back. - The user requests another 3, but due to job optimization one of the workers does not produce anything new. - If we don't empty the response, the user will get back the two images they requested, but also one from before. - """ - worker.response = None - return @staticmethod diff --git a/scripts/spartan/World.py b/scripts/spartan/World.py index e4e5f43..69ec917 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/World.py @@ -14,6 +14,7 @@ from threading import Thread from inspect import getsourcefile from os.path import abspath from pathlib import Path +import math from modules.processing import process_images, StableDiffusionProcessingTxt2Img # from modules.shared import cmd_opts import modules.shared as shared @@ -93,12 +94,7 @@ class World: total_batch_size (int): The total number of images requested by the local/master sdwui instance. """ - world_size = self.get_world_size() - if total_batch_size < world_size: - self.total_batch_size = world_size - logger.debug(f"Defaulting to a total batch size of '{world_size}' in order to accommodate all workers") - else: - self.total_batch_size = total_batch_size + self.total_batch_size = total_batch_size default_worker_batch_size = self.get_default_worker_batch_size() self.sync_master(batch_size=default_worker_batch_size) @@ -118,7 +114,12 @@ class World: """the amount of images/total images requested that a worker would compute if conditions were perfect and each worker generated at the same speed""" - return self.total_batch_size // self.get_world_size() + quotient, remainder = divmod(self.total_batch_size, self.get_world_size()) + chosen = quotient if quotient > remainder else remainder + per_worker_batch_size = math.ceil(chosen) + logger.debug(f"default per-node batch-size is {per_worker_batch_size}") + + return per_worker_batch_size def get_world_size(self) -> int: """ @@ -385,18 +386,22 @@ class World: # the maximum amount of images that a "slow" worker can produce in the slack space where other nodes are working # max_compensation = 4 currently unused images_per_job = None - + images_checked = 0 for job in self.jobs: lag = self.job_stall(job.worker, payload=payload) if lag < self.job_timeout or lag == 0: job.batch_size = payload['batch_size'] + images_checked += payload['batch_size'] continue logger.debug(f"worker '{job.worker.uuid}' would stall the image gallery by ~{lag:.2f}s\n") job.complementary = True - deferred_images = deferred_images + payload['batch_size'] + if deferred_images + images_checked + payload['batch_size'] > self.total_batch_size: + logger.debug(f"would go over actual requested size") + else: + deferred_images += payload['batch_size'] job.batch_size = 0 #################################################### @@ -426,7 +431,8 @@ class World: # in the case that this worker is now taking on what others workers would have been (if they were real-time) # this means that there will be more slack time for complementary nodes - slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job) + if images_per_job is not None: + slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job) # see how long it would take to produce only 1 image on this complementary worker fake_payload = copy.copy(payload) @@ -443,6 +449,24 @@ class World: logger.warning("Master couldn't keep up... defaulting to 1 image") master_job.batch_size = 1 + # if the total number of requested images is not cleanly divisible by the world size then we tack that on here + # *if that hasn't already been filled by complementary fill or the requirement that master's batch size be >= 1 + remainder_images = self.total_batch_size - self.get_current_output_size() + logger.debug(f"{remainder_images} = {self.total_batch_size} - {self.get_current_output_size()}") + if remainder_images >= 1: + logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) and complementary jobs did not provide this missing image.") + + # Gets the fastest job that has been assigned the least amount of images + laziest_realtime_job = None + for job in self.realtime_jobs(): + if laziest_realtime_job is None: + laziest_realtime_job = job + elif laziest_realtime_job.batch_size > job.batch_size: + laziest_realtime_job = job + + laziest_realtime_job.batch_size += remainder_images + logger.debug(f"dispatched remainder image to worker '{laziest_realtime_job.worker.uuid}'") + logger.info("Job distribution:") for job in self.jobs: logger.info(f"worker '{job.worker.uuid}' - {job.batch_size} images")