improve logging by using rich
parent
bdf758f1c4
commit
34bf2893e4
|
|
@ -0,0 +1,4 @@
|
|||
import launch
|
||||
|
||||
if not launch.is_installed("rich"):
|
||||
launch.run_pip("install rich", "requirements for distributed")
|
||||
|
|
@ -23,6 +23,7 @@ import subprocess
|
|||
from scripts.spartan.World import World, NotBenchmarked, WorldAlreadyInitialized
|
||||
from scripts.spartan.Worker import Worker, State
|
||||
from modules.shared import opts
|
||||
from scripts.spartan.shared import logger
|
||||
|
||||
# TODO implement SSDP advertisement of some sort in sdwui api to allow extension to automatically discover workers?
|
||||
# TODO see if the current api has some sort of UUID generation functionality.
|
||||
|
|
@ -66,7 +67,7 @@ class Script(scripts.Script):
|
|||
refresh_status_btn.style(size='sm')
|
||||
refresh_status_btn.click(Script.ui_connect_status, inputs=[], outputs=[jobs, status])
|
||||
|
||||
status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[jobs, status])
|
||||
# status_tab.select(fn=Script.ui_connect_status, inputs=[], outputs=[jobs, status])
|
||||
|
||||
with gradio.Tab('Utils'):
|
||||
refresh_checkpoints_btn = gradio.Button(value='Refresh checkpoints')
|
||||
|
|
@ -92,7 +93,7 @@ class Script(scripts.Script):
|
|||
|
||||
@staticmethod
|
||||
def ui_connect_benchmark_button():
|
||||
print("Redoing benchmarks...")
|
||||
logger.info("Redoing benchmarks...")
|
||||
Script.world.benchmark(rebenchmark=True)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -124,14 +125,14 @@ class Script(scripts.Script):
|
|||
try:
|
||||
Script.world.interrupt_remotes()
|
||||
except AttributeError:
|
||||
print("Nothing to interrupt, Distributed system not initialized")
|
||||
logger.debug("Nothing to interrupt, Distributed system not initialized")
|
||||
|
||||
@staticmethod
|
||||
def ui_connect_refresh_ckpts_btn():
|
||||
try:
|
||||
Script.world.refresh_checkpoints()
|
||||
except AttributeError:
|
||||
print("Distributed system not initialized")
|
||||
logger.debug("Distributed system not initialized")
|
||||
|
||||
@staticmethod
|
||||
def ui_connect_status():
|
||||
|
|
@ -153,7 +154,6 @@ class Script(scripts.Script):
|
|||
|
||||
# init system if it isn't already
|
||||
except AttributeError as e:
|
||||
print(e)
|
||||
# batch size will be clobbered later once an actual request is made anyway
|
||||
Script.initialize(initial_payload=None)
|
||||
return 'refresh!', 'refresh!'
|
||||
|
|
@ -165,7 +165,7 @@ class Script(scripts.Script):
|
|||
|
||||
# get master ipm by estimating based on worker speed
|
||||
master_elapsed = time.time() - Script.master_start
|
||||
print(f"Took master {master_elapsed}s")
|
||||
logger.debug(f"Took master {master_elapsed}s")
|
||||
|
||||
# wait for response from all workers
|
||||
for thread in Script.worker_threads:
|
||||
|
|
@ -177,7 +177,7 @@ class Script(scripts.Script):
|
|||
images: json = worker.response["images"]
|
||||
except TypeError:
|
||||
if worker.master is False:
|
||||
print(f"Worker '{worker.uuid}' had nothing")
|
||||
logger.debug(f"Worker '{worker.uuid}' had nothing")
|
||||
continue
|
||||
|
||||
image_params: json = worker.response["parameters"]
|
||||
|
|
@ -250,7 +250,7 @@ class Script(scripts.Script):
|
|||
|
||||
if Script.world is None:
|
||||
if Script.verify_remotes is False:
|
||||
print(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
|
||||
logger.info(f"WARNING: you have chosen to forego the verification of worker TLS certificates")
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
# construct World
|
||||
|
|
@ -264,7 +264,7 @@ class Script(scripts.Script):
|
|||
# update world or initialize and update if necessary
|
||||
try:
|
||||
Script.world.initialize(batch_size)
|
||||
print("World initialized!")
|
||||
logger.debug(f"World initialized!")
|
||||
except WorldAlreadyInitialized:
|
||||
Script.world.update_world(total_batch_size=batch_size)
|
||||
|
||||
|
|
@ -282,8 +282,6 @@ class Script(scripts.Script):
|
|||
payload = p.__dict__
|
||||
payload['batch_size'] = Script.world.get_default_worker_batch_size()
|
||||
payload['scripts'] = None
|
||||
# print(payload)
|
||||
# print(opts.dumpjson())
|
||||
|
||||
# TODO api for some reason returns 200 even if something failed to be set.
|
||||
# for now we may have to make redundant GET requests to check if actually successful...
|
||||
|
|
@ -310,7 +308,6 @@ class Script(scripts.Script):
|
|||
job.worker.loaded_model = name
|
||||
job.worker.loaded_vae = vae
|
||||
|
||||
# print(f"requesting {new_payload['batch_size']} images from worker '{job.worker.uuid}'\n")
|
||||
t = Thread(target=job.worker.request, args=(new_payload, option_payload, sync,))
|
||||
|
||||
t.start()
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from threading import Thread
|
|||
from webui import server_name
|
||||
from modules.shared import cmd_opts
|
||||
import gradio as gr
|
||||
from scripts.spartan.shared import benchmark_payload
|
||||
from scripts.spartan.shared import benchmark_payload, logger
|
||||
from enum import Enum
|
||||
|
||||
|
||||
|
|
@ -224,8 +224,8 @@ class Worker:
|
|||
else:
|
||||
eta += (eta * abs((percent_difference / 100)))
|
||||
except KeyError:
|
||||
print(f"Sampler '{payload['sampler_name']}' efficiency is not recorded.\n")
|
||||
print(f"Sampler efficiency will be treated as equivalent to Euler A.")
|
||||
logger.debug(f"Sampler '{payload['sampler_name']}' efficiency is not recorded.\n")
|
||||
logger.debug(f"Sampler efficiency will be treated as equivalent to Euler A.")
|
||||
|
||||
# TODO save and load each workers MPE before the end of session to workers.json.
|
||||
# That way initial estimations are more accurate from the second sdwui session onward
|
||||
|
|
@ -233,9 +233,8 @@ class Worker:
|
|||
if len(self.eta_percent_error) > 0:
|
||||
correction = eta * (self.eta_mpe() / 100)
|
||||
|
||||
if cmd_opts.distributed_debug:
|
||||
print(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%")
|
||||
print(f"{self.uuid} eta before correction: ", eta)
|
||||
logger.debug(f"worker '{self.uuid}'s last ETA was off by {correction:.2f}%")
|
||||
logger.debug(f"{self.uuid} eta before correction: ", eta)
|
||||
|
||||
# do regression
|
||||
if correction > 0:
|
||||
|
|
@ -243,8 +242,7 @@ class Worker:
|
|||
else:
|
||||
eta += correction
|
||||
|
||||
if cmd_opts.distributed_debug:
|
||||
print(f"{self.uuid} eta after correction: ", eta)
|
||||
logger.debug(f"{self.uuid} eta after correction: ", eta)
|
||||
|
||||
return eta
|
||||
except Exception as e:
|
||||
|
|
@ -277,7 +275,7 @@ class Worker:
|
|||
|
||||
free_vram = int(memory_response['free']) / (1024 * 1024 * 1024)
|
||||
total_vram = int(memory_response['total']) / (1024 * 1024 * 1024)
|
||||
print(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
|
||||
logger.debug(f"Worker '{self.uuid}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
|
||||
self.free_vram = bytes(memory_response['free'])
|
||||
|
||||
if sync_options is True:
|
||||
|
|
@ -292,10 +290,29 @@ class Worker:
|
|||
|
||||
if self.benchmarked:
|
||||
eta = self.batch_eta(payload=payload)
|
||||
print(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image("
|
||||
logger.info(f"worker '{self.uuid}' predicts it will take {eta:.3f}s to generate {payload['batch_size']} image("
|
||||
f"s) at a speed of {self.avg_ipm} ipm\n")
|
||||
|
||||
try:
|
||||
# import json
|
||||
# def find_bad_keys(json_data):
|
||||
# parsed_data = json.loads(json_data)
|
||||
# bad_keys = []
|
||||
|
||||
# for key, value in parsed_data.items():
|
||||
# if isinstance(value, float):
|
||||
# if value < -1e308 or value > 1e308:
|
||||
# bad_keys.append(key)
|
||||
|
||||
# return bad_keys
|
||||
|
||||
# for key in find_bad_keys(json.dumps(payload)):
|
||||
# logger.info(f"Bad key '{key}' found in payload with value '{payload[key]}'")
|
||||
|
||||
# s_tmax can be float('inf') which is not serializable so we convert it to the max float value
|
||||
if payload['s_tmax'] == float('inf'):
|
||||
payload['s_tmax'] = 1e308
|
||||
|
||||
start = time.time()
|
||||
response = requests.post(
|
||||
self.full_url("txt2img"),
|
||||
|
|
@ -309,9 +326,8 @@ class Worker:
|
|||
self.response_time = time.time() - start
|
||||
variance = ((eta - self.response_time) / self.response_time) * 100
|
||||
|
||||
if cmd_opts.distributed_debug:
|
||||
print(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n")
|
||||
print(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
|
||||
logger.debug(f"\nWorker '{self.uuid}'s ETA was off by {variance:.2f}%.\n")
|
||||
logger.debug(f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")
|
||||
|
||||
# if the variance is greater than 500% then we ignore it to prevent variation inflation
|
||||
if abs(variance) < 500:
|
||||
|
|
@ -324,7 +340,7 @@ class Worker:
|
|||
else: # normal case
|
||||
self.eta_percent_error.append(variance)
|
||||
else:
|
||||
print(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
|
||||
logger.debug(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
|
||||
|
||||
except Exception as e:
|
||||
if payload['batch_size'] == 0:
|
||||
|
|
@ -333,7 +349,7 @@ class Worker:
|
|||
raise InvalidWorkerResponse(e)
|
||||
|
||||
except requests.exceptions.ConnectTimeout:
|
||||
print(f"\nTimed out waiting for worker '{self.uuid}' at {self}")
|
||||
logger.info(f"\nTimed out waiting for worker '{self.uuid}' at {self}")
|
||||
|
||||
self.state = State.IDLE
|
||||
return
|
||||
|
|
@ -348,7 +364,7 @@ class Worker:
|
|||
samples = 2 # number of times to benchmark the remote / accuracy
|
||||
warmup_samples = 2 # number of samples to do before recording as a valid sample in order to "warm-up"
|
||||
|
||||
print(f"Benchmarking worker '{self.uuid}':\n")
|
||||
logger.info(f"Benchmarking worker '{self.uuid}':\n")
|
||||
|
||||
def ipm(seconds: float) -> float:
|
||||
"""
|
||||
|
|
@ -376,15 +392,16 @@ class Worker:
|
|||
elapsed = time.time() - start
|
||||
sample_ipm = ipm(elapsed)
|
||||
except InvalidWorkerResponse as e:
|
||||
# TODO
|
||||
print(e)
|
||||
raise gr.Error(e.__str__())
|
||||
|
||||
if i >= warmup_samples:
|
||||
print(f"Sample {i - warmup_samples + 1}: Worker '{self.uuid}'({self}) - {sample_ipm:.2f} image(s) per "
|
||||
logger.info(f"Sample {i - warmup_samples + 1}: Worker '{self.uuid}'({self}) - {sample_ipm:.2f} image(s) per "
|
||||
f"minute\n")
|
||||
results.append(sample_ipm)
|
||||
elif i == warmup_samples - 1:
|
||||
print(f"{self.uuid} warming up\n")
|
||||
logger.info(f"{self.uuid} warming up\n")
|
||||
|
||||
# average the sample results for accuracy
|
||||
ipm_sum = 0
|
||||
|
|
@ -392,7 +409,7 @@ class Worker:
|
|||
ipm_sum += ipm
|
||||
avg_ipm = math.floor(ipm_sum / samples)
|
||||
|
||||
print(f"Worker '{self.uuid}' average ipm: {avg_ipm}")
|
||||
logger.info(f"Worker '{self.uuid}' average ipm: {avg_ipm}")
|
||||
self.avg_ipm = avg_ipm
|
||||
# noinspection PyTypeChecker
|
||||
self.response = None
|
||||
|
|
@ -410,8 +427,7 @@ class Worker:
|
|||
|
||||
if response.status_code == 200:
|
||||
self.state = State.INTERRUPTED
|
||||
if cmd_opts.distributed_debug:
|
||||
print(f"successfully refreshed checkpoints for worker '{self.uuid}'")
|
||||
logger.debug(f"successfully refreshed checkpoints for worker '{self.uuid}'")
|
||||
|
||||
def interrupt(self):
|
||||
response = requests.post(
|
||||
|
|
@ -422,5 +438,4 @@ class Worker:
|
|||
|
||||
if response.status_code == 200:
|
||||
self.state = State.INTERRUPTED
|
||||
if cmd_opts.distributed_debug:
|
||||
print(f"successfully interrupted worker {self.uuid}")
|
||||
logger.debug(f"successfully interrupted worker {self.uuid}")
|
||||
|
|
|
|||
|
|
@ -15,9 +15,9 @@ from inspect import getsourcefile
|
|||
from os.path import abspath
|
||||
from pathlib import Path
|
||||
from modules.processing import process_images
|
||||
from modules.shared import cmd_opts
|
||||
# from modules.shared import cmd_opts
|
||||
from scripts.spartan.Worker import Worker
|
||||
from scripts.spartan.shared import benchmark_payload
|
||||
from scripts.spartan.shared import benchmark_payload, logger
|
||||
# from modules.errors import display
|
||||
import gradio as gr
|
||||
|
||||
|
|
@ -95,8 +95,7 @@ class World:
|
|||
world_size = self.get_world_size()
|
||||
if total_batch_size < world_size:
|
||||
self.total_batch_size = world_size
|
||||
print(f"Total batch size should not be less than the number of workers.\n")
|
||||
print(f"Defaulting to a total batch size of '{world_size}' in order to accommodate all workers")
|
||||
logger.debug(f"Defaulting to a total batch size of '{world_size}' in order to accommodate all workers")
|
||||
else:
|
||||
self.total_batch_size = total_batch_size
|
||||
|
||||
|
|
@ -202,6 +201,8 @@ class World:
|
|||
"""
|
||||
global benchmark_payload
|
||||
|
||||
logger.info("Benchmarking workers...")
|
||||
|
||||
workers_info: dict = {}
|
||||
saved: bool = os.path.exists(self.worker_info_path)
|
||||
benchmark_payload_loaded: bool = False
|
||||
|
|
@ -226,9 +227,7 @@ class World:
|
|||
benchmark_payload = workers_info[worker.uuid]['benchmark_payload']
|
||||
benchmark_payload_loaded = True
|
||||
|
||||
if cmd_opts.distributed_debug:
|
||||
print("loaded saved worker configuration:")
|
||||
print(workers_info)
|
||||
logger.debug(f"loaded saved worker configuration: \n{workers_info}")
|
||||
worker.avg_ipm = workers_info[worker.uuid]['avg_ipm']
|
||||
worker.benchmarked = True
|
||||
|
||||
|
|
@ -352,7 +351,7 @@ class World:
|
|||
|
||||
ipm = benchmark_payload['batch_size'] / (elapsed / 60)
|
||||
|
||||
print(f"Master benchmark took {elapsed}: {ipm} ipm")
|
||||
logger.debug(f"Master benchmark took {elapsed}: {ipm} ipm")
|
||||
self.master().benchmarked = True
|
||||
return ipm
|
||||
|
||||
|
|
@ -393,7 +392,7 @@ class World:
|
|||
job.batch_size = payload['batch_size']
|
||||
continue
|
||||
|
||||
print(f"worker '{job.worker.uuid}' would stall the image gallery by ~{lag:.2f}s\n")
|
||||
logger.debug(f"worker '{job.worker.uuid}' would stall the image gallery by ~{lag:.2f}s\n")
|
||||
job.complementary = True
|
||||
deferred_images = deferred_images + payload['batch_size']
|
||||
job.batch_size = 0
|
||||
|
|
@ -421,8 +420,7 @@ class World:
|
|||
|
||||
slowest_active_worker = self.slowest_realtime_job().worker
|
||||
slack_time = slowest_active_worker.batch_eta(payload=payload)
|
||||
if cmd_opts.distributed_debug:
|
||||
print(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'")
|
||||
logger.debug(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.uuid}'")
|
||||
|
||||
# in the case that this worker is now taking on what others workers would have been (if they were real-time)
|
||||
# this means that there will be more slack time for complementary nodes
|
||||
|
|
@ -440,11 +438,10 @@ class World:
|
|||
# It might be better to just inject a black image. (if master is that slow)
|
||||
master_job = self.master_job()
|
||||
if master_job.batch_size < 1:
|
||||
if cmd_opts.distributed_debug:
|
||||
print("Master couldn't keep up... defaulting to 1 image")
|
||||
logger.debug("Master couldn't keep up... defaulting to 1 image")
|
||||
master_job.batch_size = 1
|
||||
|
||||
print("After job optimization, job layout is the following:")
|
||||
logger.info("After job optimization, job layout is the following:")
|
||||
for job in self.jobs:
|
||||
print(f"worker '{job.worker.uuid}' - {job.batch_size} images")
|
||||
logger.info(f"worker '{job.worker.uuid}' - {job.batch_size} images")
|
||||
print()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,11 @@
|
|||
import logging
|
||||
from rich.logging import RichHandler
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
log_level = 'DEBUG' if cmd_opts.distributed_debug else 'INFO'
|
||||
logging.basicConfig(level=log_level, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()])
|
||||
logger = logging.getLogger("rich")
|
||||
|
||||
benchmark_payload: dict = {
|
||||
"prompt": "A herd of cows grazing at the bottom of a sunny valley",
|
||||
"negative_prompt": "",
|
||||
|
|
|
|||
Loading…
Reference in New Issue