move rem handling before complementary fill calcs

pull/15/head
unknown 2023-06-05 06:43:46 -05:00
parent da2e32242e
commit b71eafe2c8
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
1 changed files with 19 additions and 14 deletions

View File

@ -409,6 +409,25 @@ class World:
for job in realtime_jobs:
job.batch_size = job.batch_size + images_per_job
#######################
# remainder handling #
#######################
# when total number of requested images was not cleanly divisible by world size then we tack the remainder on
remainder_images = self.total_batch_size - self.get_current_output_size()
if remainder_images >= 1:
logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
realtime_jobs = self.realtime_jobs()
realtime_jobs.sort(key=lambda x: x.batch_size)
# round-robin distribute the remaining images
while remainder_images >= 1:
for job in realtime_jobs:
if remainder_images < 1:
break
job.batch_size += 1
remainder_images -= 1
#####################################
# complementary worker distribution #
#####################################
@ -444,21 +463,7 @@ class World:
logger.warning("Master couldn't keep up... defaulting to 1 image")
master_job.batch_size = 1
# if the total number of requested images is not cleanly divisible by the world size then we tack that on here
# *if that hasn't already been filled by complementary fill or the requirement that master's batch size be >= 1
remainder_images = self.total_batch_size - self.get_current_output_size()
if remainder_images >= 1:
logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
realtime_jobs = self.realtime_jobs()
realtime_jobs.sort(key=lambda x: x.batch_size)
# round-robin distribute the remaining images
while remainder_images >= 1:
for job in realtime_jobs:
if remainder_images < 1:
break
job.batch_size += 1
remainder_images -= 1
logger.info("Job distribution:")
iterations = payload['n_iter']