support batch sizes that do not divide cleanly
parent
81e4f9f9e9
commit
64df0c6f1c
|
|
@ -33,7 +33,6 @@ from modules.processing import fix_seed
|
|||
|
||||
# noinspection PyMissingOrEmptyDocstring
|
||||
class Script(scripts.Script):
|
||||
response_cache: json = None
|
||||
worker_threads: List[Thread] = []
|
||||
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
||||
verify_remotes = False if cmd_opts.distributed_skip_verify_remotes else True
|
||||
|
|
@ -231,7 +230,7 @@ class Script(scripts.Script):
|
|||
|
||||
spoofed_iteration = p.n_iter
|
||||
for worker in Script.world.workers:
|
||||
# if it fails here then that means that the response_cache global var is not being filled for some reason
|
||||
|
||||
expected_images = 1
|
||||
for job in Script.world.jobs:
|
||||
if job.worker == worker:
|
||||
|
|
@ -262,6 +261,7 @@ class Script(scripts.Script):
|
|||
injected_to_iteration = 0
|
||||
else:
|
||||
injected_to_iteration += 1
|
||||
worker.response = None
|
||||
|
||||
# generate and inject grid
|
||||
if opts.return_grid:
|
||||
|
|
@ -269,14 +269,6 @@ class Script(scripts.Script):
|
|||
processed_inject_image(image=grid, info_index=0, save_path_override=p.outpath_grids, iteration=spoofed_iteration, grid=True)
|
||||
|
||||
p.batch_size = len(processed.images)
|
||||
"""
|
||||
This ensures that we don't get redundant outputs in a certain case:
|
||||
We have 3 workers and we get 3 responses back.
|
||||
The user requests another 3, but due to job optimization one of the workers does not produce anything new.
|
||||
If we don't empty the response, the user will get back the two images they requested, but also one from before.
|
||||
"""
|
||||
worker.response = None
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from threading import Thread
|
|||
from inspect import getsourcefile
|
||||
from os.path import abspath
|
||||
from pathlib import Path
|
||||
import math
|
||||
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
|
||||
# from modules.shared import cmd_opts
|
||||
import modules.shared as shared
|
||||
|
|
@ -93,12 +94,7 @@ class World:
|
|||
total_batch_size (int): The total number of images requested by the local/master sdwui instance.
|
||||
"""
|
||||
|
||||
world_size = self.get_world_size()
|
||||
if total_batch_size < world_size:
|
||||
self.total_batch_size = world_size
|
||||
logger.debug(f"Defaulting to a total batch size of '{world_size}' in order to accommodate all workers")
|
||||
else:
|
||||
self.total_batch_size = total_batch_size
|
||||
self.total_batch_size = total_batch_size
|
||||
|
||||
default_worker_batch_size = self.get_default_worker_batch_size()
|
||||
self.sync_master(batch_size=default_worker_batch_size)
|
||||
|
|
@ -118,7 +114,12 @@ class World:
|
|||
"""the amount of images/total images requested that a worker would compute if conditions were perfect and
|
||||
each worker generated at the same speed"""
|
||||
|
||||
return self.total_batch_size // self.get_world_size()
|
||||
quotient, remainder = divmod(self.total_batch_size, self.get_world_size())
|
||||
chosen = quotient if quotient > remainder else remainder
|
||||
per_worker_batch_size = math.ceil(chosen)
|
||||
logger.debug(f"default per-node batch-size is {per_worker_batch_size}")
|
||||
|
||||
return per_worker_batch_size
|
||||
|
||||
def get_world_size(self) -> int:
|
||||
"""
|
||||
|
|
@ -385,18 +386,22 @@ class World:
|
|||
# the maximum amount of images that a "slow" worker can produce in the slack space where other nodes are working
|
||||
# max_compensation = 4 currently unused
|
||||
images_per_job = None
|
||||
|
||||
images_checked = 0
|
||||
for job in self.jobs:
|
||||
|
||||
lag = self.job_stall(job.worker, payload=payload)
|
||||
|
||||
if lag < self.job_timeout or lag == 0:
|
||||
job.batch_size = payload['batch_size']
|
||||
images_checked += payload['batch_size']
|
||||
continue
|
||||
|
||||
logger.debug(f"worker '{job.worker.uuid}' would stall the image gallery by ~{lag:.2f}s\n")
|
||||
job.complementary = True
|
||||
deferred_images = deferred_images + payload['batch_size']
|
||||
if deferred_images + images_checked + payload['batch_size'] > self.total_batch_size:
|
||||
logger.debug(f"would go over actual requested size")
|
||||
else:
|
||||
deferred_images += payload['batch_size']
|
||||
job.batch_size = 0
|
||||
|
||||
####################################################
|
||||
|
|
@ -426,7 +431,8 @@ class World:
|
|||
|
||||
# in the case that this worker is now taking on what others workers would have been (if they were real-time)
|
||||
# this means that there will be more slack time for complementary nodes
|
||||
slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job)
|
||||
if images_per_job is not None:
|
||||
slack_time = slack_time + ((slack_time / payload['batch_size']) * images_per_job)
|
||||
|
||||
# see how long it would take to produce only 1 image on this complementary worker
|
||||
fake_payload = copy.copy(payload)
|
||||
|
|
@ -443,6 +449,24 @@ class World:
|
|||
logger.warning("Master couldn't keep up... defaulting to 1 image")
|
||||
master_job.batch_size = 1
|
||||
|
||||
# if the total number of requested images is not cleanly divisible by the world size then we tack that on here
|
||||
# *if that hasn't already been filled by complementary fill or the requirement that master's batch size be >= 1
|
||||
remainder_images = self.total_batch_size - self.get_current_output_size()
|
||||
logger.debug(f"{remainder_images} = {self.total_batch_size} - {self.get_current_output_size()}")
|
||||
if remainder_images >= 1:
|
||||
logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) and complementary jobs did not provide this missing image.")
|
||||
|
||||
# Gets the fastest job that has been assigned the least amount of images
|
||||
laziest_realtime_job = None
|
||||
for job in self.realtime_jobs():
|
||||
if laziest_realtime_job is None:
|
||||
laziest_realtime_job = job
|
||||
elif laziest_realtime_job.batch_size > job.batch_size:
|
||||
laziest_realtime_job = job
|
||||
|
||||
laziest_realtime_job.batch_size += remainder_images
|
||||
logger.debug(f"dispatched remainder image to worker '{laziest_realtime_job.worker.uuid}'")
|
||||
|
||||
logger.info("Job distribution:")
|
||||
for job in self.jobs:
|
||||
logger.info(f"worker '{job.worker.uuid}' - {job.batch_size} images")
|
||||
|
|
|
|||
Loading…
Reference in New Issue