support batch sizes that do not divide cleanly

pull/15/head
unknown 2023-06-03 10:58:51 -05:00
parent 81e4f9f9e9
commit 64df0c6f1c
2 changed files with 36 additions and 20 deletions

View File

@@ -33,7 +33,6 @@ from modules.processing import fix_seed
# noinspection PyMissingOrEmptyDocstring
class Script(scripts.Script):
response_cache: json = None
worker_threads: List[Thread] = []
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
verify_remotes = False if cmd_opts.distributed_skip_verify_remotes else True
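For context, this is how such a flag typically reaches the HTTP layer; a minimal sketch assuming the remote calls go through requests (query_worker and the endpoint wiring are illustrative, not part of this commit):

import requests

verify_remotes = False  # mirrors cmd_opts.distributed_skip_verify_remotes being set

def query_worker(url: str) -> dict:
    # verify=False skips TLS certificate validation, which is what permits
    # self-signed certificates on remote workers (at the cost of security)
    response = requests.get(url, verify=verify_remotes)
    response.raise_for_status()
    return response.json()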
@@ -231,7 +230,7 @@ class Script(scripts.Script):
spoofed_iteration = p.n_iter
for worker in Script.world.workers:
# if it fails here, it means the response_cache global var is not being filled for some reason
expected_images = 1
for job in Script.world.jobs:
if job.worker == worker:
@@ -262,6 +261,7 @@ class Script(scripts.Script):
injected_to_iteration = 0
else:
injected_to_iteration += 1
worker.response = None
# generate and inject grid
if opts.return_grid:
@@ -269,14 +269,6 @@ class Script(scripts.Script):
processed_inject_image(image=grid, info_index=0, save_path_override=p.outpath_grids, iteration=spoofed_iteration, grid=True)
p.batch_size = len(processed.images)
"""
This ensures that we don't get redundant outputs in a certain case:
we have 3 workers and get 3 responses back.
The user then requests another batch, but due to job optimization one of the workers produces nothing new.
If we don't empty that worker's response, the user gets the new images they requested plus a stale one from the previous batch.
"""
worker.response = None
return
@staticmethod
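The stale-response problem the reset above avoids, as a standalone sketch (Worker and collect_images are simplified stand-ins for the extension's types, not its real API):

class Worker:
    def __init__(self, uuid: str):
        self.uuid = uuid
        self.response = None  # images from the worker's last completed job

def collect_images(workers: list) -> list:
    # a worker that was skipped this round but never cleared still
    # contributes last round's images
    images = []
    for worker in workers:
        if worker.response is not None:
            images += worker.response
    return images

a, b, c = Worker("a"), Worker("b"), Worker("c")
a.response, b.response, c.response = ["img1"], ["img2"], ["img3"]
assert len(collect_images([a, b, c])) == 3  # 3 requested, 3 returned
# next round: job optimization gives worker c nothing new to do
a.response, b.response = ["img4"], ["img5"]
assert len(collect_images([a, b, c])) == 3  # only 2 requested; stale "img3" leaked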

View File

@@ -14,6 +14,7 @@ from threading import Thread
from inspect import getsourcefile
from os.path import abspath
from pathlib import Path
import math
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
# from modules.shared import cmd_opts
import modules.shared as shared
@@ -93,12 +94,7 @@ class World:
total_batch_size (int): The total number of images requested by the local/master sdwui instance.
"""
world_size = self.get_world_size()
if total_batch_size < world_size:
self.total_batch_size = world_size
logger.debug(f"Defaulting to a total batch size of '{world_size}' in order to accommodate all workers")
else:
self.total_batch_size = total_batch_size
self.total_batch_size = total_batch_size
default_worker_batch_size = self.get_default_worker_batch_size()
self.sync_master(batch_size=default_worker_batch_size)
@@ -118,7 +114,12 @@ class World:
"""the amount of images/total images requested that a worker would compute if conditions were perfect and
each worker generated at the same speed"""
return self.total_batch_size // self.get_world_size()
quotient, remainder = divmod(self.total_batch_size, self.get_world_size())
chosen = quotient if quotient > remainder else remainder
per_worker_batch_size = math.ceil(chosen)
logger.debug(f"default per-node batch-size is {per_worker_batch_size}")
return per_worker_batch_size
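Worked through with concrete numbers (the function mirrors the new divmod logic above; the inputs are illustrative):

import math

def default_worker_batch_size(total_batch_size: int, world_size: int) -> int:
    quotient, remainder = divmod(total_batch_size, world_size)
    # keep whichever of quotient/remainder is larger, instead of the old
    # behaviour of inflating small requests up to the world size
    chosen = quotient if quotient > remainder else remainder
    return math.ceil(chosen)

print(default_worker_batch_size(9, 3))   # divmod -> (3, 0): 3 per worker, 9 total
print(default_worker_batch_size(10, 3))  # divmod -> (3, 1): 3 per worker; the leftover image is dispatched later
print(default_worker_batch_size(2, 3))   # divmod -> (0, 2): 2, rather than forcing a batch of 3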
def get_world_size(self) -> int:
"""
@@ -385,18 +386,22 @@ class World:
# the maximum amount of images that a "slow" worker can produce in the slack space where other nodes are working
# max_compensation = 4 currently unused
images_per_job = None
images_checked = 0
for job in self.jobs:
lag = self.job_stall(job.worker, payload=payload)
if lag < self.job_timeout or lag == 0:
job.batch_size = payload['batch_size']
images_checked += payload['batch_size']
continue
logger.debug(f"worker '{job.worker.uuid}' would stall the image gallery by ~{lag:.2f}s\n")
job.complementary = True
deferred_images = deferred_images + payload['batch_size']
if deferred_images + images_checked + payload['batch_size'] > self.total_batch_size:
logger.debug(f"would go over actual requested size")
else:
deferred_images += payload['batch_size']
job.batch_size = 0
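The deferral bookkeeping above, isolated into a runnable sketch (jobs are reduced to (name, lag) pairs and the stall check is faked; both are stand-ins for the extension's Job and job_stall):

def assign(jobs, batch_size, total_batch_size, job_timeout):
    deferred_images = 0
    images_checked = 0
    for name, lag in jobs:
        if lag < job_timeout or lag == 0:
            images_checked += batch_size  # realtime worker keeps its share
            continue
        # stalled worker: defer its images only if that stays within what
        # the user actually requested, then zero its batch either way
        if deferred_images + images_checked + batch_size > total_batch_size:
            print(f"skipping deferral for {name}: would go over requested size")
        else:
            deferred_images += batch_size
    return deferred_images, images_checked

# 3 workers at 2 images each but only 4 requested: deferring the slow
# worker's 2 images would overshoot, so nothing is deferred
print(assign([("a", 0.0), ("b", 1.0), ("c", 99.0)], 2, 4, 10.0))  # (0, 4)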
####################################################
@@ -426,7 +431,8 @@ class World:
# in the case that this worker is now taking on what others workers would have been (if they were real-time)
# this means that there will be more slack time for complementary nodes
slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job)
if images_per_job is not None:
slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job)
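The new None check matters because images_per_job is initialised to None and may never be set; a trivial demonstration of what the guard prevents:

slack_time = 12.0
batch_size = 2
images_per_job = None  # no realtime job ever set it

if images_per_job is not None:
    slack_time = slack_time + ((slack_time / batch_size) * images_per_job)
# without the guard the multiplication raises:
# TypeError: unsupported operand type(s) for *: 'float' and 'NoneType'
print(slack_time)  # 12.0, unchanged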
# see how long it would take to produce only 1 image on this complementary worker
fake_payload = copy.copy(payload)
@@ -443,6 +449,24 @@ class World:
logger.warning("Master couldn't keep up... defaulting to 1 image")
master_job.batch_size = 1
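The single-image probe above, sketched standalone (estimate_time is a hypothetical stand-in for however the extension predicts a worker's job duration):

import copy

def estimate_time(payload: dict, images_per_second: float = 0.5) -> float:
    # hypothetical predictor: duration scales linearly with batch size
    return payload["batch_size"] / images_per_second

payload = {"batch_size": 4, "prompt": "a photo of a cat"}
# shallow-copy the real payload and request a single image, so the
# estimate reflects per-image cost on this complementary worker
fake_payload = copy.copy(payload)
fake_payload["batch_size"] = 1
print(estimate_time(fake_payload))  # 2.0 seconds for one image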
# if the total number of requested images is not cleanly divisible by the world size then we tack the remainder on here,
# *if it hasn't already been covered by complementary fill or by the requirement that master's batch size be >= 1
remainder_images = self.total_batch_size - self.get_current_output_size()
logger.debug(f"{remainder_images} = {self.total_batch_size} - {self.get_current_output_size()}")
if remainder_images >= 1:
logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) and complementary jobs did not provide this missing image.")
# among the realtime (fast) jobs, pick the one assigned the fewest images
laziest_realtime_job = None
for job in self.realtime_jobs():
if laziest_realtime_job is None:
laziest_realtime_job = job
elif laziest_realtime_job.batch_size > job.batch_size:
laziest_realtime_job = job
laziest_realtime_job.batch_size += remainder_images
logger.debug(f"dispatched remainder image to worker '{laziest_realtime_job.worker.uuid}'")
logger.info("Job distribution:")
for job in self.jobs:
logger.info(f"worker '{job.worker.uuid}' - {job.batch_size} images")