diff --git a/scripts/extension.py b/scripts/extension.py index d93e2be..fb4a1f2 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -159,7 +159,7 @@ class Script(scripts.Script): try: images: json = job.worker.response["images"] # if we for some reason get more than we asked for - if job.batch_size < len(images): + if (job.batch_size * p.n_iter) < len(images): logger.debug(f"Requested {job.batch_size} images from '{job.worker.uuid}', got {len(images)}") if donor_worker is None: diff --git a/scripts/spartan/Worker.py b/scripts/spartan/Worker.py index 56431fd..aecc0ae 100644 --- a/scripts/spartan/Worker.py +++ b/scripts/spartan/Worker.py @@ -10,13 +10,14 @@ from threading import Thread from webui import server_name from modules.shared import cmd_opts import gradio as gr -from scripts.spartan.shared import benchmark_payload, logger, warmup_samples +from scripts.spartan.shared import logger, warmup_samples from enum import Enum import json import base64 import queue from modules.shared import state as master_state from modules.api.api import encode_pil_to_base64 +import scripts.spartan.shared as sh class InvalidWorkerResponse(Exception): @@ -128,9 +129,6 @@ class Worker: Stores the payload used to benchmark the world and certain attributes of the worker. These things are used to draw certain conclusions after the first session. - Args: - benchmark_payload (dict): The payload used in the benchmark. - Returns: dict: Worker info, including how it was benchmarked. """ @@ -209,7 +207,7 @@ class Worker: # if worker has not yet been benchmarked then eta = (num_images / self.avg_ipm) * 60 # show effect of increased step size - real_steps_to_benched = steps / benchmark_payload['steps'] + real_steps_to_benched = steps / sh.benchmark_payload['steps'] eta = eta * real_steps_to_benched # show effect of high-res fix @@ -219,7 +217,7 @@ class Worker: # show effect of image size real_pix_to_benched = (payload['width'] * payload['height'])\ - / (benchmark_payload['width'] * benchmark_payload['height']) + / (sh.benchmark_payload['width'] * sh.benchmark_payload['height']) eta = eta * real_pix_to_benched # show effect of using a sampler other than euler a @@ -431,7 +429,7 @@ class Worker: float: Images per minute """ - return benchmark_payload['batch_size'] / (seconds / 60) + return sh.benchmark_payload['batch_size'] / (seconds / 60) results: List[float] = [] # it's seems to be lower for the first couple of generations @@ -442,7 +440,7 @@ class Worker: self.response = None return 0 - t = Thread(target=self.request, args=(benchmark_payload, None, False,), name=f"{self.uuid}_benchmark_request") + t = Thread(target=self.request, args=(sh.benchmark_payload, None, False,), name=f"{self.uuid}_benchmark_request") try: # if the worker is unreachable/offline then handle that here t.start() start = time.time() diff --git a/scripts/spartan/World.py b/scripts/spartan/World.py index fdc3b32..2465ab0 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/World.py @@ -17,7 +17,8 @@ from pathlib import Path from modules.processing import process_images, StableDiffusionProcessingTxt2Img import modules.shared as shared from scripts.spartan.Worker import Worker, State -from scripts.spartan.shared import logger, warmup_samples, benchmark_payload +from scripts.spartan.shared import logger, warmup_samples +import scripts.spartan.shared as sh class NotBenchmarked(Exception): @@ -170,7 +171,6 @@ class World: """ Attempts to benchmark all workers a part of the world. """ - from scripts.spartan.shared import benchmark_payload workers_info: dict = {} saved: bool = os.path.exists(self.worker_info_path) @@ -186,8 +186,8 @@ class World: if saved: with open(self.worker_info_path, 'r') as worker_info_file: workers_info = json.load(worker_info_file) - benchmark_payload = workers_info['benchmark_payload'] - logger.info(f"Using saved benchmark config:\n{benchmark_payload}") + sh.benchmark_payload = workers_info['benchmark_payload'] + logger.info(f"Using saved benchmark config:\n{sh.benchmark_payload}") saved = False workers = self.get_workers() @@ -200,8 +200,8 @@ class World: with open(self.worker_info_path, 'r') as worker_info_file: try: workers_info = json.load(worker_info_file) - benchmark_payload = workers_info['benchmark_payload'] - logger.info(f"Using saved benchmark config:\n{benchmark_payload}") + sh.benchmark_payload = workers_info['benchmark_payload'] + logger.info(f"Using saved benchmark config:\n{sh.benchmark_payload}") except json.JSONDecodeError: logger.error(f"workers.json is not valid JSON, regenerating") rebenchmark = True @@ -238,7 +238,7 @@ class World: for worker in unbenched_workers: workers_info.update(worker.info()) - workers_info.update({'benchmark_payload': benchmark_payload}) + workers_info.update({'benchmark_payload': sh.benchmark_payload}) # save benchmark results to workers.json json.dump(workers_info, worker_info_file, indent=3) @@ -348,8 +348,8 @@ class World: # wrap our benchmark payload master_bench_payload = StableDiffusionProcessingTxt2Img() - for key in benchmark_payload: - setattr(master_bench_payload, key, benchmark_payload[key]) + for key in sh.benchmark_payload: + setattr(master_bench_payload, key, sh.benchmark_payload[key]) # Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to. master_bench_payload.do_not_save_samples = True @@ -362,7 +362,7 @@ class World: process_images(master_bench_payload) elapsed = time.time() - start - ipm = benchmark_payload['batch_size'] / (elapsed / 60) + ipm = sh.benchmark_payload['batch_size'] / (elapsed / 60) logger.debug(f"Master benchmark took {elapsed:.2f}: {ipm:.2f} ipm") self.master().benchmarked = True