improve logging by using rich

pull/2/head
papuSpartan 2023-05-18 13:13:38 -05:00
parent bdf758f1c4
commit 34bf2893e4
5 changed files with 71 additions and 50 deletions

4
install.py Normal file
View File

@ -0,0 +1,4 @@
import launch
if not launch.is_installed("rich"):
launch.run_pip("install rich", "requirements for distributed")

View File

@ -23,6 +23,7 @@ import subprocess
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
from scripts.spartan.Worker import Worker, State
from modules.shared import opts
from scripts.spartan.shared import logger
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
# TODO see if the current api has some sort of UUID generation functionality.
@ -66,7 +67,7 @@ class Script(scripts.Script):
refresh_status_btn.style(size='sm')
refresh_status_btn.click(Script.ui_connect_status, inputs=[], outputs=[jobs, status])
status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[jobs, status])
# status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[jobs, status])
with gradio.Tab('Utils'):
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
@ -92,7 +93,7 @@ class Script(scripts.Script):
@staticmethod
def ui_connect_benchmark_button():
print("Redoing benchmarks...")
logger.info("Redoing benchmarks...")
Script.world.benchmark(rebenchmark=True)
@staticmethod
@ -124,14 +125,14 @@ class Script(scripts.Script):
try:
Script.world.interrupt_remotes()
except AttributeError:
print("Nothing to interrupt, Distributed system not initialized")
logger.debug("Nothing to interrupt, Distributed system not initialized")
@staticmethod
def ui_connect_refresh_ckpts_btn():
try:
Script.world.refresh_checkpoints()
except AttributeError:
print("Distributed system not initialized")
logger.debug("Distributed system not initialized")
@staticmethod
def ui_connect_status():
@ -153,7 +154,6 @@ class Script(scripts.Script):
# init system if it isn't already
except AttributeError as e:
print(e)
# batch size will be clobbered later once an actual request is made anyway
Script.initialize(initial_payload=None)
return 'refresh!', 'refresh!'
@ -165,7 +165,7 @@ class Script(scripts.Script):
# get master ipm by estimating based on worker speed
master_elapsed = time.time() - Script.master_start
print(f"Took master {master_elapsed}s")
logger.debug(f"Took master {master_elapsed}s")
# wait for response from all workers
for thread in Script.worker_threads:
@ -177,7 +177,7 @@ class Script(scripts.Script):
images: json = worker.response["images"]
except TypeError:
if worker.master is False:
print(f"Worker '{worker.uuid}' had nothing")
logger.debug(f"Worker '{worker.uuid}' had nothing")
continue
image_params: json = worker.response["parameters"]
@ -250,7 +250,7 @@ class Script(scripts.Script):
if Script.world is None:
if Script.verify_remotes is False:
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
logger.info(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# construct World
@ -264,7 +264,7 @@ class Script(scripts.Script):
# update world or initialize and update if necessary
try:
Script.world.initialize(batch_size)
print("World initialized!")
logger.debug(f"World initialized!")
except WorldAlreadyInitialized:
Script.world.update_world(total_batch_size=batch_size)
@ -282,8 +282,6 @@ class Script(scripts.Script):
payload = p.__dict__
payload['batch_size'] = Script.world.get_default_worker_batch_size()
payload['scripts'] = None
# print(payload)
# print(opts.dumpjson())
# TODO api for some reason returns 200 even if something failed to be set.
# for now we may have to make redundant GET requests to check if actually successful...
@ -310,7 +308,6 @@ class Script(scripts.Script):
job.worker.loaded_model = name
job.worker.loaded_vae = vae
# print(f"requesting {new_payload['batch_size']} images from worker '{job.worker.uuid}'\n")
t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,))
t.start()

View File

@ -8,7 +8,7 @@ from threading import Thread
from webui import server_name
from modules.shared import cmd_opts
import gradio as gr
from scripts.spartan.shared import benchmark_payload
from scripts.spartan.shared import benchmark_payload, logger
from enum import Enum
@ -224,8 +224,8 @@ class Worker:
else:
eta += (eta * abs((percent_difference / 100)))
except KeyError:
print(f"Sampler '{payload['sampler_name']}' efficiency is not recorded.\n")
print(f"Sampler efficiency will be treated as equivalent to Euler A.")
logger.debug(f"Sampler '{payload['sampler_name']}' efficiency is not recorded.\n")
logger.debug(f"Sampler efficiency will be treated as equivalent to Euler A.")
# TODO save and load each workers MPE before the end of session to workers.json.
# That way initial estimations are more accurate from the second sdwui session onward
@ -233,9 +233,8 @@ class Worker:
if len(self.eta_percent_error) > 0:
correction = eta * (self.eta_mpe() / 100)
if cmd_opts.distributed_debug:
print(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%")
print(f"{self.uuid} eta before correction: ", eta)
logger.debug(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%")
logger.debug(f"{self.uuid} eta before correction: ", eta)
# do regression
if correction > 0:
@ -243,8 +242,7 @@ class Worker:
else:
eta += correction
if cmd_opts.distributed_debug:
print(f"{self.uuid} eta after correction: ", eta)
logger.debug(f"{self.uuid} eta after correction: ", eta)
return eta
except Exception as e:
@ -277,7 +275,7 @@ class Worker:
free_vram = int(memory_response['free']) / (1024 * 1024 * 1024)
total_vram = int(memory_response['total']) / (1024 * 1024 * 1024)
print(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
logger.debug(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
self.free_vram = bytes(memory_response['free'])
if sync_options is True:
@ -292,10 +290,29 @@ class Worker:
if self.benchmarked:
eta = self.batch_eta(payload=payload)
print(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image("
logger.info(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image("
f"s) at a speed of {self.avg_ipm} ipm\n")
try:
# import json
# def find_bad_keys(json_data):
# parsed_data = json.loads(json_data)
# bad_keys = []
# for key, value in parsed_data.items():
# if isinstance(value, float):
# if value < -1e308 or value > 1e308:
# bad_keys.append(key)
# return bad_keys
# for key in find_bad_keys(json.dumps(payload)):
# logger.info(f"Bad key '{key}' found in payload with value '{payload[key]}'")
# s_tmax can be float('inf') which is not serializable so we convert it to the max float value
if payload['s_tmax'] == float('inf'):
payload['s_tmax'] = 1e308
start = time.time()
response = requests.post(
self.full_url("txt2img"),
@ -309,9 +326,8 @@ class Worker:
self.response_time = time.time() - start
variance = ((eta - self.response_time) / self.response_time) * 100
if cmd_opts.distributed_debug:
print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n")
print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
logger.debug(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n")
logger.debug(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
# if the variance is greater than 500% then we ignore it to prevent variation inflation
if abs(variance) < 500:
@ -324,7 +340,7 @@ class Worker:
else: # normal case
self.eta_percent_error.append(variance)
else:
print(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
logger.debug(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
except Exception as e:
if payload['batch_size'] == 0:
@ -333,7 +349,7 @@ class Worker:
raise InvalidWorkerResponse(e)
except requests.exceptions.ConnectTimeout:
print(f"\nTimed out waiting for worker '{self.uuid}' at {self}")
logger.info(f"\nTimed out waiting for worker '{self.uuid}' at {self}")
self.state = State.IDLE
return
@ -348,7 +364,7 @@ class Worker:
samples = 2 # number of times to benchmark the remote / accuracy
warmup_samples = 2 # number of samples to do before recording as a valid sample in order to "warm-up"
print(f"Benchmarking worker '{self.uuid}':\n")
logger.info(f"Benchmarking worker '{self.uuid}':\n")
def ipm(seconds: float) -> float:
"""
@ -376,15 +392,16 @@ class Worker:
elapsed = time.time() - start
sample_ipm = ipm(elapsed)
except InvalidWorkerResponse as e:
# TODO
print(e)
raise gr.Error(e.__str__())
if i >= warmup_samples:
print(f"Sample {i - warmup_samples + 1}: Worker '{self.uuid}'({self}) - {sample_ipm:.2f} image(s) per "
logger.info(f"Sample {i - warmup_samples + 1}: Worker '{self.uuid}'({self}) - {sample_ipm:.2f} image(s) per "
f"minute\n")
results.append(sample_ipm)
elif i == warmup_samples - 1:
print(f"{self.uuid} warming up\n")
logger.info(f"{self.uuid} warming up\n")
# average the sample results for accuracy
ipm_sum = 0
@ -392,7 +409,7 @@ class Worker:
ipm_sum += ipm
avg_ipm = math.floor(ipm_sum / samples)
print(f"Worker '{self.uuid}' average ipm: {avg_ipm}")
logger.info(f"Worker '{self.uuid}' average ipm: {avg_ipm}")
self.avg_ipm = avg_ipm
# noinspection PyTypeChecker
self.response = None
@ -410,8 +427,7 @@ class Worker:
if response.status_code == 200:
self.state = State.INTERRUPTED
if cmd_opts.distributed_debug:
print(f"successfully refreshed checkpoints for worker '{self.uuid}'")
logger.debug(f"successfully refreshed checkpoints for worker '{self.uuid}'")
def interrupt(self):
response = requests.post(
@ -422,5 +438,4 @@ class Worker:
if response.status_code == 200:
self.state = State.INTERRUPTED
if cmd_opts.distributed_debug:
print(f"successfully interrupted worker {self.uuid}")
logger.debug(f"successfully interrupted worker {self.uuid}")

View File

@ -15,9 +15,9 @@ from inspect import getsourcefile
from os.path import abspath
from pathlib import Path
from modules.processing import process_images
from modules.shared import cmd_opts
# from modules.shared import cmd_opts
from scripts.spartan.Worker import Worker
from scripts.spartan.shared import benchmark_payload
from scripts.spartan.shared import benchmark_payload, logger
# from modules.errors import display
import gradio as gr
@ -95,8 +95,7 @@ class World:
world_size = self.get_world_size()
if total_batch_size < world_size:
self.total_batch_size = world_size
print(f"Total batch size should not be less than the number of workers.\n")
print(f"Defaulting to a total batch size of '{world_size}' in order to accommodate all workers")
logger.debug(f"Defaulting to a total batch size of '{world_size}' in order to accommodate all workers")
else:
self.total_batch_size = total_batch_size
@ -202,6 +201,8 @@ class World:
"""
global benchmark_payload
logger.info("Benchmarking workers...")
workers_info: dict = {}
saved: bool = os.path.exists(self.worker_info_path)
benchmark_payload_loaded: bool = False
@ -226,9 +227,7 @@ class World:
benchmark_payload = workers_info[worker.uuid]['benchmark_payload']
benchmark_payload_loaded = True
if cmd_opts.distributed_debug:
print("loaded saved worker configuration:")
print(workers_info)
logger.debug(f"loaded saved worker configuration: \n{workers_info}")
worker.avg_ipm = workers_info[worker.uuid]['avg_ipm']
worker.benchmarked = True
@ -352,7 +351,7 @@ class World:
ipm = benchmark_payload['batch_size'] / (elapsed / 60)
print(f"Master benchmark took {elapsed}: {ipm} ipm")
logger.debug(f"Master benchmark took {elapsed}: {ipm} ipm")
self.master().benchmarked = True
return ipm
@ -393,7 +392,7 @@ class World:
job.batch_size = payload['batch_size']
continue
print(f"worker '{job.worker.uuid}' would stall the image gallery by ~{lag:.2f}s\n")
logger.debug(f"worker '{job.worker.uuid}' would stall the image gallery by ~{lag:.2f}s\n")
job.complementary = True
deferred_images = deferred_images + payload['batch_size']
job.batch_size = 0
@ -421,8 +420,7 @@ class World:
slowest_active_worker = self.slowest_realtime_job().worker
slack_time = slowest_active_worker.batch_eta(payload=payload)
if cmd_opts.distributed_debug:
print(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'")
logger.debug(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'")
# in the case that this worker is now taking on what others workers would have been (if they were real-time)
# this means that there will be more slack time for complementary nodes
@ -440,11 +438,10 @@ class World:
# It might be better to just inject a black image. (if master is that slow)
master_job = self.master_job()
if master_job.batch_size < 1:
if cmd_opts.distributed_debug:
print("Master couldn't keep up... defaulting to 1 image")
logger.debug("Master couldn't keep up... defaulting to 1 image")
master_job.batch_size = 1
print("After job optimization, job layout is the following:")
logger.info("After job optimization, job layout is the following:")
for job in self.jobs:
print(f"worker '{job.worker.uuid}' - {job.batch_size} images")
logger.info(f"worker '{job.worker.uuid}' - {job.batch_size} images")
print()

View File

@ -1,3 +1,11 @@
import logging
from rich.logging import RichHandler
from modules.shared import cmd_opts
log_level = 'DEBUG' if cmd_opts.distributed_debug else 'INFO'
logging.basicConfig(level=log_level, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()])
logger = logging.getLogger("rich")
benchmark_payload: dict = {
"prompt": "A herd of cows grazing at the bottom of a sunny valley",
"negative_prompt": "",