From 34bf2893e4727275ce89149c2e8949d41493e8c7 Mon Sep 17 00:00:00 2001 From: papuSpartan <30642826+papuSpartan@users.noreply.github.com> Date: Thu, 18 May 2023 13:13:38 -0500 Subject: [PATCH] improve logging by using rich --- install.py | 4 +++ scripts/extension.py | 21 ++++++-------- scripts/spartan/Worker.py | 61 ++++++++++++++++++++++++--------------- scripts/spartan/World.py | 27 ++++++++--------- scripts/spartan/shared.py | 8 +++++ 5 files changed, 71 insertions(+), 50 deletions(-) create mode 100644 install.py diff --git a/install.py b/install.py new file mode 100644 index 0000000..a797b45 --- /dev/null +++ b/install.py @@ -0,0 +1,4 @@ +import launch + +if not launch.is_installed("rich"): + launch.run_pip("install rich", "requirements for distributed") diff --git a/scripts/extension.py b/scripts/extension.py index e5fac23..6443c37 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -23,6 +23,7 @@ import subprocess from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized from scripts.spartan.Worker import Worker, State from modules.shared import opts +from scripts.spartan.shared import logger # TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers? # TODO see if the current api has some sort of UUID generation functionality. 
@@ -66,7 +67,7 @@ class Script(scripts.Script): refresh_status_btn.style(size='sm') refresh_status_btn.click(Script.ui_connect_status, inputs=[], outputs=[jobs, status]) - status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[jobs, status]) + # status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[jobs, status]) with gradio.Tab('Utils'): refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints') @@ -92,7 +93,7 @@ class Script(scripts.Script): @staticmethod def ui_connect_benchmark_button(): - print("Redoing benchmarks...") + logger.info("Redoing benchmarks...") Script.world.benchmark(rebenchmark=True) @staticmethod @@ -124,14 +125,14 @@ class Script(scripts.Script): try: Script.world.interrupt_remotes() except AttributeError: - print("Nothing to interrupt, Distributed system not initialized") + logger.debug("Nothing to interrupt, Distributed system not initialized") @staticmethod def ui_connect_refresh_ckpts_btn(): try: Script.world.refresh_checkpoints() except AttributeError: - print("Distributed system not initialized") + logger.debug("Distributed system not initialized") @staticmethod def ui_connect_status(): @@ -153,7 +154,6 @@ class Script(scripts.Script): # init system if it isn't already except AttributeError as e: - print(e) # batch size will be clobbered later once an actual request is made anyway Script.initialize(initial_payload=None) return 'refresh!', 'refresh!' 
@@ -165,7 +165,7 @@ class Script(scripts.Script): # get master ipm by estimating based on worker speed master_elapsed = time.time() - Script.master_start - print(f"Took master {master_elapsed}s") + logger.debug(f"Took master {master_elapsed}s") # wait for response from all workers for thread in Script.worker_threads: @@ -177,7 +177,7 @@ class Script(scripts.Script): images: json = worker.response["images"] except TypeError: if worker.master is False: - print(f"Worker '{worker.uuid}' had nothing") + logger.debug(f"Worker '{worker.uuid}' had nothing") continue image_params: json = worker.response["parameters"] @@ -250,7 +250,7 @@ class Script(scripts.Script): if Script.world is None: if Script.verify_remotes is False: - print(f"WARNING: you have chosen to forego the verification of worker TLS certificates") + logger.warning("you have chosen to forego the verification of worker TLS certificates") urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) # construct World @@ -264,7 +264,7 @@ class Script(scripts.Script): # update world or initialize and update if necessary try: Script.world.initialize(batch_size) - print("World initialized!") + logger.debug("World initialized!") except WorldAlreadyInitialized: Script.world.update_world(total_batch_size=batch_size) @@ -282,8 +282,6 @@ class Script(scripts.Script): payload = p.__dict__ payload['batch_size'] = Script.world.get_default_worker_batch_size() payload['scripts'] = None - # print(payload) - # print(opts.dumpjson()) # TODO api for some reason returns 200 even if something failed to be set. # for now we may have to make redundant GET requests to check if actually successful... 
@@ -310,7 +308,6 @@ class Script(scripts.Script): job.worker.loaded_model = name job.worker.loaded_vae = vae - # print(f"requesting {new_payload['batch_size']} images from worker '{job.worker.uuid}'\n") t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,)) t.start() diff --git a/scripts/spartan/Worker.py b/scripts/spartan/Worker.py index faa2aa9..d313f29 100644 --- a/scripts/spartan/Worker.py +++ b/scripts/spartan/Worker.py @@ -8,7 +8,7 @@ from threading import Thread from webui import server_name from modules.shared import cmd_opts import gradio as gr -from scripts.spartan.shared import benchmark_payload +from scripts.spartan.shared import benchmark_payload, logger from enum import Enum @@ -224,8 +224,8 @@ class Worker: else: eta += (eta * abs((percent_difference / 100))) except KeyError: - print(f"Sampler '{payload['sampler_name']}' efficiency is not recorded.\n") - print(f"Sampler efficiency will be treated as equivalent to Euler A.") + logger.debug(f"Sampler '{payload['sampler_name']}' efficiency is not recorded.\n") + logger.debug(f"Sampler efficiency will be treated as equivalent to Euler A.") # TODO save and load each workers MPE before the end of session to workers.json. 
# That way initial estimations are more accurate from the second sdwui session onward @@ -233,9 +233,8 @@ class Worker: if len(self.eta_percent_error) > 0: correction = eta * (self.eta_mpe() / 100) - if cmd_opts.distributed_debug: - print(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%") - print(f"{self.uuid} eta before correction: ", eta) + logger.debug(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%") + logger.debug(f"{self.uuid} eta before correction: {eta}") # do regression if correction > 0: @@ -243,8 +242,7 @@ else: eta += correction - if cmd_opts.distributed_debug: - print(f"{self.uuid} eta after correction: ", eta) + logger.debug(f"{self.uuid} eta after correction: {eta}") return eta except Exception as e: @@ -277,7 +275,7 @@ free_vram = int(memory_response['free']) / (1024 * 1024 * 1024) total_vram = int(memory_response['total']) / (1024 * 1024 * 1024) - print(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n") + logger.debug(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n") self.free_vram = bytes(memory_response['free']) if sync_options is True: @@ -292,10 +290,29 @@ if self.benchmarked: eta = self.batch_eta(payload=payload) - print(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image(" + logger.info(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image(" f"s) at a speed of {self.avg_ipm} ipm\n") try: + # import json + # def find_bad_keys(json_data): + # parsed_data = json.loads(json_data) + # bad_keys = [] + + # for key, value in parsed_data.items(): + # if isinstance(value, float): + # if value < -1e308 or value > 1e308: + # bad_keys.append(key) + + # return bad_keys + + # for key in find_bad_keys(json.dumps(payload)): + # logger.info(f"Bad key '{key}' found in payload with value '{payload[key]}'") + + # s_tmax can be float('inf') which is not 
serializable so we convert it to the max float value + if payload['s_tmax'] == float('inf'): + payload['s_tmax'] = 1e308 + start = time.time() response = requests.post( self.full_url("txt2img"), @@ -309,9 +326,8 @@ class Worker: self.response_time = time.time() - start variance = ((eta - self.response_time) / self.response_time) * 100 - if cmd_opts.distributed_debug: - print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n") - print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n") + logger.debug(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n") + logger.debug(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n") # if the variance is greater than 500% then we ignore it to prevent variation inflation if abs(variance) < 500: @@ -324,7 +340,7 @@ class Worker: else: # normal case self.eta_percent_error.append(variance) else: - print(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n") + logger.debug(f"Variance of {variance:.2f}% exceeds threshold of 500%. 
Ignoring...\n") except Exception as e: if payload['batch_size'] == 0: @@ -333,7 +349,7 @@ class Worker: raise InvalidWorkerResponse(e) except requests.exceptions.ConnectTimeout: - print(f"\nTimed out waiting for worker '{self.uuid}' at {self}") + logger.info(f"\nTimed out waiting for worker '{self.uuid}' at {self}") self.state = State.IDLE return @@ -348,7 +364,7 @@ class Worker: samples = 2 # number of times to benchmark the remote / accuracy warmup_samples = 2 # number of samples to do before recording as a valid sample in order to "warm-up" - print(f"Benchmarking worker '{self.uuid}':\n") + logger.info(f"Benchmarking worker '{self.uuid}':\n") def ipm(seconds: float) -> float: """ @@ -376,15 +392,16 @@ class Worker: elapsed = time.time() - start sample_ipm = ipm(elapsed) except InvalidWorkerResponse as e: + # TODO print(e) raise gr.Error(e.__str__()) if i >= warmup_samples: - print(f"Sample {i - warmup_samples + 1}: Worker '{self.uuid}'({self}) - {sample_ipm:.2f} image(s) per " + logger.info(f"Sample {i - warmup_samples + 1}: Worker '{self.uuid}'({self}) - {sample_ipm:.2f} image(s) per " f"minute\n") results.append(sample_ipm) elif i == warmup_samples - 1: - print(f"{self.uuid} warming up\n") + logger.info(f"{self.uuid} warming up\n") # average the sample results for accuracy ipm_sum = 0 @@ -392,7 +409,7 @@ class Worker: ipm_sum += ipm avg_ipm = math.floor(ipm_sum / samples) - print(f"Worker '{self.uuid}' average ipm: {avg_ipm}") + logger.info(f"Worker '{self.uuid}' average ipm: {avg_ipm}") self.avg_ipm = avg_ipm # noinspection PyTypeChecker self.response = None @@ -410,8 +427,7 @@ class Worker: if response.status_code == 200: self.state = State.INTERRUPTED - if cmd_opts.distributed_debug: - print(f"successfully refreshed checkpoints for worker '{self.uuid}'") + logger.debug(f"successfully refreshed checkpoints for worker '{self.uuid}'") def interrupt(self): response = requests.post( @@ -422,5 +438,4 @@ class Worker: if response.status_code == 200: self.state = 
State.INTERRUPTED - if cmd_opts.distributed_debug: - print(f"successfully interrupted worker {self.uuid}") + logger.debug(f"successfully interrupted worker {self.uuid}") diff --git a/scripts/spartan/World.py b/scripts/spartan/World.py index 7d52710..b875618 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/World.py @@ -15,9 +15,9 @@ from inspect import getsourcefile from os.path import abspath from pathlib import Path from modules.processing import process_images -from modules.shared import cmd_opts +# from modules.shared import cmd_opts from scripts.spartan.Worker import Worker -from scripts.spartan.shared import benchmark_payload +from scripts.spartan.shared import benchmark_payload, logger # from modules.errors import display import gradio as gr @@ -95,8 +95,7 @@ class World: world_size = self.get_world_size() if total_batch_size < world_size: self.total_batch_size = world_size - print(f"Total batch size should not be less than the number of workers.\n") - print(f"Defaulting to a total batch size of '{world_size}' in order to accommodate all workers") + logger.warning(f"Total batch size should not be less than the number of workers; defaulting to a total batch size of '{world_size}' in order to accommodate all workers") else: self.total_batch_size = total_batch_size @@ -202,6 +201,8 @@ class World: """ global benchmark_payload + logger.info("Benchmarking workers...") + workers_info: dict = {} saved: bool = os.path.exists(self.worker_info_path) benchmark_payload_loaded: bool = False @@ -226,9 +227,7 @@ class World: benchmark_payload = workers_info[worker.uuid]['benchmark_payload'] benchmark_payload_loaded = True - if cmd_opts.distributed_debug: - print("loaded saved worker configuration:") - print(workers_info) + logger.debug(f"loaded saved worker configuration: \n{workers_info}") worker.avg_ipm = workers_info[worker.uuid]['avg_ipm'] worker.benchmarked = True @@ -352,7 +351,7 @@ class World: ipm = benchmark_payload['batch_size'] / (elapsed / 60) - print(f"Master benchmark took {elapsed}: {ipm} ipm") + 
logger.debug(f"Master benchmark took {elapsed}: {ipm} ipm") self.master().benchmarked = True return ipm @@ -393,7 +392,7 @@ class World: job.batch_size = payload['batch_size'] continue - print(f"worker '{job.worker.uuid}' would stall the image gallery by ~{lag:.2f}s\n") + logger.debug(f"worker '{job.worker.uuid}' would stall the image gallery by ~{lag:.2f}s\n") job.complementary = True deferred_images = deferred_images + payload['batch_size'] job.batch_size = 0 @@ -421,8 +420,7 @@ class World: slowest_active_worker = self.slowest_realtime_job().worker slack_time = slowest_active_worker.batch_eta(payload=payload) - if cmd_opts.distributed_debug: - print(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'") + logger.debug(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'") # in the case that this worker is now taking on what others workers would have been (if they were real-time) # this means that there will be more slack time for complementary nodes @@ -440,11 +438,10 @@ class World: # It might be better to just inject a black image. (if master is that slow) master_job = self.master_job() if master_job.batch_size < 1: - if cmd_opts.distributed_debug: - print("Master couldn't keep up... defaulting to 1 image") + logger.debug("Master couldn't keep up... 
defaulting to 1 image") master_job.batch_size = 1 - print("After job optimization, job layout is the following:") + logger.info("After job optimization, job layout is the following:") for job in self.jobs: - print(f"worker '{job.worker.uuid}' - {job.batch_size} images") + logger.info(f"worker '{job.worker.uuid}' - {job.batch_size} images") print() diff --git a/scripts/spartan/shared.py b/scripts/spartan/shared.py index 50ef1a9..f5d9c59 100644 --- a/scripts/spartan/shared.py +++ b/scripts/spartan/shared.py @@ -1,3 +1,11 @@ +import logging +from rich.logging import RichHandler +from modules.shared import cmd_opts + +logger = logging.getLogger("distributed") +logger.setLevel('DEBUG' if cmd_opts.distributed_debug else 'INFO') +logger.addHandler(RichHandler()) + benchmark_payload: dict = { "prompt": "A herd of cows grazing at the bottom of a sunny valley", "negative_prompt": "",