improve logging by using rich

pull/2/head
papuSpartan 2023-05-18 13:13:38 -05:00
parent bdf758f1c4
commit 34bf2893e4
5 changed files with 71 additions and 50 deletions

4
install.py Normal file
View File

@ -0,0 +1,4 @@
# Dependency check for the extension: make sure the `rich` package is present.
import launch

# Only invoke pip when `rich` is not already installed.
if not launch.is_installed("rich"):
    launch.run_pip("install rich", "requirements for distributed")

View File

@ -23,6 +23,7 @@ import subprocess
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
from scripts.spartan.Worker import Worker, State from scripts.spartan.Worker import Worker, State
from modules.shared import opts from modules.shared import opts
from scripts.spartan.shared import logger
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers? # TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
# TODO see if the current api has some sort of UUID generation functionality. # TODO see if the current api has some sort of UUID generation functionality.
@ -66,7 +67,7 @@ class Script(scripts.Script):
refresh_status_btn.style(size='sm') refresh_status_btn.style(size='sm')
refresh_status_btn.click(Script.ui_connect_status, inputs=[], outputs=[jobs, status]) refresh_status_btn.click(Script.ui_connect_status, inputs=[], outputs=[jobs, status])
status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[jobs, status]) # status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[jobs, status])
with gradio.Tab('Utils'): with gradio.Tab('Utils'):
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints') refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
@ -92,7 +93,7 @@ class Script(scripts.Script):
@staticmethod @staticmethod
def ui_connect_benchmark_button(): def ui_connect_benchmark_button():
print("Redoing benchmarks...") logger.info("Redoing benchmarks...")
Script.world.benchmark(rebenchmark=True) Script.world.benchmark(rebenchmark=True)
@staticmethod @staticmethod
@ -124,14 +125,14 @@ class Script(scripts.Script):
try: try:
Script.world.interrupt_remotes() Script.world.interrupt_remotes()
except AttributeError: except AttributeError:
print("Nothing to interrupt, Distributed system not initialized") logger.debug("Nothing to interrupt, Distributed system not initialized")
@staticmethod @staticmethod
def ui_connect_refresh_ckpts_btn(): def ui_connect_refresh_ckpts_btn():
try: try:
Script.world.refresh_checkpoints() Script.world.refresh_checkpoints()
except AttributeError: except AttributeError:
print("Distributed system not initialized") logger.debug("Distributed system not initialized")
@staticmethod @staticmethod
def ui_connect_status(): def ui_connect_status():
@ -153,7 +154,6 @@ class Script(scripts.Script):
# init system if it isn't already # init system if it isn't already
except AttributeError as e: except AttributeError as e:
print(e)
# batch size will be clobbered later once an actual request is made anyway # batch size will be clobbered later once an actual request is made anyway
Script.initialize(initial_payload=None) Script.initialize(initial_payload=None)
return 'refresh!', 'refresh!' return 'refresh!', 'refresh!'
@ -165,7 +165,7 @@ class Script(scripts.Script):
# get master ipm by estimating based on worker speed # get master ipm by estimating based on worker speed
master_elapsed = time.time() - Script.master_start master_elapsed = time.time() - Script.master_start
print(f"Took master {master_elapsed}s") logger.debug(f"Took master {master_elapsed}s")
# wait for response from all workers # wait for response from all workers
for thread in Script.worker_threads: for thread in Script.worker_threads:
@ -177,7 +177,7 @@ class Script(scripts.Script):
images: json = worker.response["images"] images: json = worker.response["images"]
except TypeError: except TypeError:
if worker.master is False: if worker.master is False:
print(f"Worker '{worker.uuid}' had nothing") logger.debug(f"Worker '{worker.uuid}' had nothing")
continue continue
image_params: json = worker.response["parameters"] image_params: json = worker.response["parameters"]
@ -250,7 +250,7 @@ class Script(scripts.Script):
if Script.world is None: if Script.world is None:
if Script.verify_remotes is False: if Script.verify_remotes is False:
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates") logger.info(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# construct World # construct World
@ -264,7 +264,7 @@ class Script(scripts.Script):
# update world or initialize and update if necessary # update world or initialize and update if necessary
try: try:
Script.world.initialize(batch_size) Script.world.initialize(batch_size)
print("World initialized!") logger.debug(f"World initialized!")
except WorldAlreadyInitialized: except WorldAlreadyInitialized:
Script.world.update_world(total_batch_size=batch_size) Script.world.update_world(total_batch_size=batch_size)
@ -282,8 +282,6 @@ class Script(scripts.Script):
payload = p.__dict__ payload = p.__dict__
payload['batch_size'] = Script.world.get_default_worker_batch_size() payload['batch_size'] = Script.world.get_default_worker_batch_size()
payload['scripts'] = None payload['scripts'] = None
# print(payload)
# print(opts.dumpjson())
# TODO api for some reason returns 200 even if something failed to be set. # TODO api for some reason returns 200 even if something failed to be set.
# for now we may have to make redundant GET requests to check if actually successful... # for now we may have to make redundant GET requests to check if actually successful...
@ -310,7 +308,6 @@ class Script(scripts.Script):
job.worker.loaded_model = name job.worker.loaded_model = name
job.worker.loaded_vae = vae job.worker.loaded_vae = vae
# print(f"requesting {new_payload['batch_size']} images from worker '{job.worker.uuid}'\n")
t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,)) t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,))
t.start() t.start()

View File

@ -8,7 +8,7 @@ from threading import Thread
from webui import server_name from webui import server_name
from modules.shared import cmd_opts from modules.shared import cmd_opts
import gradio as gr import gradio as gr
from scripts.spartan.shared import benchmark_payload from scripts.spartan.shared import benchmark_payload, logger
from enum import Enum from enum import Enum
@ -224,8 +224,8 @@ class Worker:
else: else:
eta += (eta * abs((percent_difference / 100))) eta += (eta * abs((percent_difference / 100)))
except KeyError: except KeyError:
print(f"Sampler '{payload['sampler_name']}' efficiency is not recorded.\n") logger.debug(f"Sampler '{payload['sampler_name']}' efficiency is not recorded.\n")
print(f"Sampler efficiency will be treated as equivalent to Euler A.") logger.debug(f"Sampler efficiency will be treated as equivalent to Euler A.")
# TODO save and load each workers MPE before the end of session to workers.json. # TODO save and load each workers MPE before the end of session to workers.json.
# That way initial estimations are more accurate from the second sdwui session onward # That way initial estimations are more accurate from the second sdwui session onward
@ -233,9 +233,8 @@ class Worker:
if len(self.eta_percent_error) > 0: if len(self.eta_percent_error) > 0:
correction = eta * (self.eta_mpe() / 100) correction = eta * (self.eta_mpe() / 100)
if cmd_opts.distributed_debug: logger.debug(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%")
print(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%") logger.debug(f"{self.uuid} eta before correction: ", eta)
print(f"{self.uuid} eta before correction: ", eta)
# do regression # do regression
if correction > 0: if correction > 0:
@ -243,8 +242,7 @@ class Worker:
else: else:
eta += correction eta += correction
if cmd_opts.distributed_debug: logger.debug(f"{self.uuid} eta after correction: ", eta)
print(f"{self.uuid} eta after correction: ", eta)
return eta return eta
except Exception as e: except Exception as e:
@ -277,7 +275,7 @@ class Worker:
free_vram = int(memory_response['free']) / (1024 * 1024 * 1024) free_vram = int(memory_response['free']) / (1024 * 1024 * 1024)
total_vram = int(memory_response['total']) / (1024 * 1024 * 1024) total_vram = int(memory_response['total']) / (1024 * 1024 * 1024)
print(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n") logger.debug(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
self.free_vram = bytes(memory_response['free']) self.free_vram = bytes(memory_response['free'])
if sync_options is True: if sync_options is True:
@ -292,10 +290,29 @@ class Worker:
if self.benchmarked: if self.benchmarked:
eta = self.batch_eta(payload=payload) eta = self.batch_eta(payload=payload)
print(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image(" logger.info(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image("
f"s) at a speed of {self.avg_ipm} ipm\n") f"s) at a speed of {self.avg_ipm} ipm\n")
try: try:
# import json
# def find_bad_keys(json_data):
# parsed_data = json.loads(json_data)
# bad_keys = []
# for key, value in parsed_data.items():
# if isinstance(value, float):
# if value < -1e308 or value > 1e308:
# bad_keys.append(key)
# return bad_keys
# for key in find_bad_keys(json.dumps(payload)):
# logger.info(f"Bad key '{key}' found in payload with value '{payload[key]}'")
# s_tmax can be float('inf') which is not serializable so we convert it to the max float value
if payload['s_tmax'] == float('inf'):
payload['s_tmax'] = 1e308
start = time.time() start = time.time()
response = requests.post( response = requests.post(
self.full_url("txt2img"), self.full_url("txt2img"),
@ -309,9 +326,8 @@ class Worker:
self.response_time = time.time() - start self.response_time = time.time() - start
variance = ((eta - self.response_time) / self.response_time) * 100 variance = ((eta - self.response_time) / self.response_time) * 100
if cmd_opts.distributed_debug: logger.debug(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n")
print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n") logger.debug(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
# if the variance is greater than 500% then we ignore it to prevent variation inflation # if the variance is greater than 500% then we ignore it to prevent variation inflation
if abs(variance) < 500: if abs(variance) < 500:
@ -324,7 +340,7 @@ class Worker:
else: # normal case else: # normal case
self.eta_percent_error.append(variance) self.eta_percent_error.append(variance)
else: else:
print(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n") logger.debug(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
except Exception as e: except Exception as e:
if payload['batch_size'] == 0: if payload['batch_size'] == 0:
@ -333,7 +349,7 @@ class Worker:
raise InvalidWorkerResponse(e) raise InvalidWorkerResponse(e)
except requests.exceptions.ConnectTimeout: except requests.exceptions.ConnectTimeout:
print(f"\nTimed out waiting for worker '{self.uuid}' at {self}") logger.info(f"\nTimed out waiting for worker '{self.uuid}' at {self}")
self.state = State.IDLE self.state = State.IDLE
return return
@ -348,7 +364,7 @@ class Worker:
samples = 2 # number of times to benchmark the remote / accuracy samples = 2 # number of times to benchmark the remote / accuracy
warmup_samples = 2 # number of samples to do before recording as a valid sample in order to "warm-up" warmup_samples = 2 # number of samples to do before recording as a valid sample in order to "warm-up"
print(f"Benchmarking worker '{self.uuid}':\n") logger.info(f"Benchmarking worker '{self.uuid}':\n")
def ipm(seconds: float) -> float: def ipm(seconds: float) -> float:
""" """
@ -376,15 +392,16 @@ class Worker:
elapsed = time.time() - start elapsed = time.time() - start
sample_ipm = ipm(elapsed) sample_ipm = ipm(elapsed)
except InvalidWorkerResponse as e: except InvalidWorkerResponse as e:
# TODO
print(e) print(e)
raise gr.Error(e.__str__()) raise gr.Error(e.__str__())
if i >= warmup_samples: if i >= warmup_samples:
print(f"Sample {i - warmup_samples + 1}: Worker '{self.uuid}'({self}) - {sample_ipm:.2f} image(s) per " logger.info(f"Sample {i - warmup_samples + 1}: Worker '{self.uuid}'({self}) - {sample_ipm:.2f} image(s) per "
f"minute\n") f"minute\n")
results.append(sample_ipm) results.append(sample_ipm)
elif i == warmup_samples - 1: elif i == warmup_samples - 1:
print(f"{self.uuid} warming up\n") logger.info(f"{self.uuid} warming up\n")
# average the sample results for accuracy # average the sample results for accuracy
ipm_sum = 0 ipm_sum = 0
@ -392,7 +409,7 @@ class Worker:
ipm_sum += ipm ipm_sum += ipm
avg_ipm = math.floor(ipm_sum / samples) avg_ipm = math.floor(ipm_sum / samples)
print(f"Worker '{self.uuid}' average ipm: {avg_ipm}") logger.info(f"Worker '{self.uuid}' average ipm: {avg_ipm}")
self.avg_ipm = avg_ipm self.avg_ipm = avg_ipm
# noinspection PyTypeChecker # noinspection PyTypeChecker
self.response = None self.response = None
@ -410,8 +427,7 @@ class Worker:
if response.status_code == 200: if response.status_code == 200:
self.state = State.INTERRUPTED self.state = State.INTERRUPTED
if cmd_opts.distributed_debug: logger.debug(f"successfully refreshed checkpoints for worker '{self.uuid}'")
print(f"successfully refreshed checkpoints for worker '{self.uuid}'")
def interrupt(self): def interrupt(self):
response = requests.post( response = requests.post(
@ -422,5 +438,4 @@ class Worker:
if response.status_code == 200: if response.status_code == 200:
self.state = State.INTERRUPTED self.state = State.INTERRUPTED
if cmd_opts.distributed_debug: logger.debug(f"successfully interrupted worker {self.uuid}")
print(f"successfully interrupted worker {self.uuid}")

View File

@ -15,9 +15,9 @@ from inspect import getsourcefile
from os.path import abspath from os.path import abspath
from pathlib import Path from pathlib import Path
from modules.processing import process_images from modules.processing import process_images
from modules.shared import cmd_opts # from modules.shared import cmd_opts
from scripts.spartan.Worker import Worker from scripts.spartan.Worker import Worker
from scripts.spartan.shared import benchmark_payload from scripts.spartan.shared import benchmark_payload, logger
# from modules.errors import display # from modules.errors import display
import gradio as gr import gradio as gr
@ -95,8 +95,7 @@ class World:
world_size = self.get_world_size() world_size = self.get_world_size()
if total_batch_size < world_size: if total_batch_size < world_size:
self.total_batch_size = world_size self.total_batch_size = world_size
print(f"Total batch size should not be less than the number of workers.\n") logger.debug(f"Defaulting to a total batch size of '{world_size}' in order to accommodate all workers")
print(f"Defaulting to a total batch size of '{world_size}' in order to accommodate all workers")
else: else:
self.total_batch_size = total_batch_size self.total_batch_size = total_batch_size
@ -202,6 +201,8 @@ class World:
""" """
global benchmark_payload global benchmark_payload
logger.info("Benchmarking workers...")
workers_info: dict = {} workers_info: dict = {}
saved: bool = os.path.exists(self.worker_info_path) saved: bool = os.path.exists(self.worker_info_path)
benchmark_payload_loaded: bool = False benchmark_payload_loaded: bool = False
@ -226,9 +227,7 @@ class World:
benchmark_payload = workers_info[worker.uuid]['benchmark_payload'] benchmark_payload = workers_info[worker.uuid]['benchmark_payload']
benchmark_payload_loaded = True benchmark_payload_loaded = True
if cmd_opts.distributed_debug: logger.debug(f"loaded saved worker configuration: \n{workers_info}")
print("loaded saved worker configuration:")
print(workers_info)
worker.avg_ipm = workers_info[worker.uuid]['avg_ipm'] worker.avg_ipm = workers_info[worker.uuid]['avg_ipm']
worker.benchmarked = True worker.benchmarked = True
@ -352,7 +351,7 @@ class World:
ipm = benchmark_payload['batch_size'] / (elapsed / 60) ipm = benchmark_payload['batch_size'] / (elapsed / 60)
print(f"Master benchmark took {elapsed}: {ipm} ipm") logger.debug(f"Master benchmark took {elapsed}: {ipm} ipm")
self.master().benchmarked = True self.master().benchmarked = True
return ipm return ipm
@ -393,7 +392,7 @@ class World:
job.batch_size = payload['batch_size'] job.batch_size = payload['batch_size']
continue continue
print(f"worker '{job.worker.uuid}' would stall the image gallery by ~{lag:.2f}s\n") logger.debug(f"worker '{job.worker.uuid}' would stall the image gallery by ~{lag:.2f}s\n")
job.complementary = True job.complementary = True
deferred_images = deferred_images + payload['batch_size'] deferred_images = deferred_images + payload['batch_size']
job.batch_size = 0 job.batch_size = 0
@ -421,8 +420,7 @@ class World:
slowest_active_worker = self.slowest_realtime_job().worker slowest_active_worker = self.slowest_realtime_job().worker
slack_time = slowest_active_worker.batch_eta(payload=payload) slack_time = slowest_active_worker.batch_eta(payload=payload)
if cmd_opts.distributed_debug: logger.debug(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'")
print(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'")
# in the case that this worker is now taking on what others workers would have been (if they were real-time) # in the case that this worker is now taking on what others workers would have been (if they were real-time)
# this means that there will be more slack time for complementary nodes # this means that there will be more slack time for complementary nodes
@ -440,11 +438,10 @@ class World:
# It might be better to just inject a black image. (if master is that slow) # It might be better to just inject a black image. (if master is that slow)
master_job = self.master_job() master_job = self.master_job()
if master_job.batch_size < 1: if master_job.batch_size < 1:
if cmd_opts.distributed_debug: logger.debug("Master couldn't keep up... defaulting to 1 image")
print("Master couldn't keep up... defaulting to 1 image")
master_job.batch_size = 1 master_job.batch_size = 1
print("After job optimization, job layout is the following:") logger.info("After job optimization, job layout is the following:")
for job in self.jobs: for job in self.jobs:
print(f"worker '{job.worker.uuid}' - {job.batch_size} images") logger.info(f"worker '{job.worker.uuid}' - {job.batch_size} images")
print() print()

View File

@ -1,3 +1,11 @@
import logging
from rich.logging import RichHandler
from modules.shared import cmd_opts
# Log verbosity: DEBUG when the `distributed_debug` command-line option is set, otherwise INFO.
log_level = 'DEBUG' if cmd_opts.distributed_debug else 'INFO'
# Configure logging via basicConfig with a single RichHandler; the format is the bare
# message, leaving any extra columns to the handler.
# NOTE(review): basicConfig acts on the root logger, so this presumably affects other
# loggers in the host process too, and is a no-op if the root logger already has
# handlers — confirm this is intended.
logging.basicConfig(level=log_level, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()])
# Shared logger object (named "rich") for other modules to import from here.
logger = logging.getLogger("rich")
benchmark_payload: dict = { benchmark_payload: dict = {
"prompt": "A herd of cows grazing at the bottom of a sunny valley", "prompt": "A herd of cows grazing at the bottom of a sunny valley",
"negative_prompt": "", "negative_prompt": "",