diff --git a/scripts/spartan/World.py b/scripts/spartan/World.py index 2bf3ad9..8be854a 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/World.py @@ -409,6 +409,25 @@ class World: for job in realtime_jobs: job.batch_size = job.batch_size + images_per_job + ####################### + # remainder handling # + ####################### + + # when total number of requested images was not cleanly divisible by world size then we tack the remainder on + remainder_images = self.total_batch_size - self.get_current_output_size() + if remainder_images >= 1: + logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed") + + realtime_jobs = self.realtime_jobs() + realtime_jobs.sort(key=lambda x: x.batch_size) + # round-robin distribute the remaining images + while remainder_images >= 1: + for job in realtime_jobs: + if remainder_images < 1: + break + job.batch_size += 1 + remainder_images -= 1 + ##################################### # complementary worker distribution # ##################################### @@ -444,21 +463,7 @@ class World: logger.warning("Master couldn't keep up... defaulting to 1 image") master_job.batch_size = 1 - # if the total number of requested images is not cleanly divisible by the world size then we tack that on here - # *if that hasn't already been filled by complementary fill or the requirement that master's batch size be >= 1 - remainder_images = self.total_batch_size - self.get_current_output_size() - if remainder_images >= 1: - logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed") - realtime_jobs = self.realtime_jobs() - realtime_jobs.sort(key=lambda x: x.batch_size) - # round-robin distribute the remaining images - while remainder_images >= 1: - for job in realtime_jobs: - if remainder_images < 1: - break - job.batch_size += 1 - remainder_images -= 1 logger.info("Job distribution:") iterations = payload['n_iter']