exception handling and logging

dev
papuSpartan 2024-11-21 19:50:15 -06:00
parent 5dceb74137
commit f819f55743
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
4 changed files with 94 additions and 87 deletions

View File

@ -153,7 +153,7 @@ class DistributedScript(scripts.Script):
received_images = False
for job in self.world.jobs:
if job.worker.response is None or job.batch_size < 1 or job.worker.master:
if not isinstance(job.worker.response, dict) or job.batch_size < 1 or job.worker.master:
continue
try:
@ -198,6 +198,8 @@ class DistributedScript(scripts.Script):
active_adapters = []
if p.all_prompts is None:
p.all_prompts = []
if p.all_negative_prompts is None:
p.all_negative_prompts = []
is_img2img = getattr(p, 'init_images', False)
if is_img2img and self.world.enabled_i2i is False:

View File

@ -8,12 +8,13 @@ class Adapter(object):
self.script = None
def early(self, p, world, script, *args) -> bool:
    """Make changes before any worker request objects are created.

    Stores ``script`` on the adapter so later hooks can reach it.

    Returns:
        True to cede control back to webui. This base implementation
        always returns False.
    """
    self.script = script
    return False
def late(self, p, world, payload, *args):
    """Hook for changes after the worker request object has been created.

    Runs once the worker request object exists and workloads have been
    manipulated. ``payload['alwayson_scripts']`` is guaranteed to exist, but
    may not be populated. The base implementation does nothing.
    """
    return None

View File

@ -32,8 +32,8 @@ def pack_control_net(cn_units) -> dict:
for i in range(0, len(cn_units)):
if cn_units[i].enabled:
cn_args.append(copy.deepcopy(cn_units[i].__dict__))
else:
logger.debug(f"controlnet unit {i} is not enabled (ignoring)")
# else:
# logger.debug(f"controlnet unit {i} is not enabled (ignoring)")
for i in range(0, len(cn_args)):
unit = cn_args[i]

View File

@ -26,12 +26,27 @@ except ImportError: # webui 95821f0132f5437ef30b0dbcac7c51e55818c18f and newer
from .pmodels import Worker_Model
class InvalidWorkerResponse(Exception):
    """Raised when a worker request returns an invalid or unexpected response."""
class WorkerException(Exception):
    """Error for a failed or unexpected interaction with a worker.

    Constructing the exception has side effects: it logs the message together
    with the worker's repr and a summary of ``worker.response`` (image data is
    replaced by a count), and, when a nested ``exception`` is supplied, it
    updates the worker state and may re-raise the nested exception from inside
    the constructor.

    Args:
        message: Human-readable description of the failure.
        worker: The worker involved, or None when no worker object is
            available yet (e.g. while validating constructor arguments).
        exception: Optional nested exception that triggered this one.

    Raises:
        The nested ``exception`` itself when it is a ``WorkerException`` or
        any exception other than ``requests.RequestException``.

    NOTE(review): the source this was reconstructed from had lost its
    indentation; the final re-raise is read as belonging to the non-request
    branch only, so a ``requests.RequestException`` marks the worker
    unavailable and lets this ``WorkerException`` propagate - confirm.
    """

    def __init__(self, message, worker=None, exception=None):
        super().__init__(message)
        self.worker = worker
        self.exception = exception

        error_msg = message
        if worker is not None:
            error_msg += f"\n{repr(worker)}"
            response = getattr(worker, 'response', None)
            if isinstance(response, dict):
                # Shallow copy is enough: only the top-level 'images' entry is
                # swapped out, and the worker's own response is left untouched.
                summary = dict(response)
                if 'images' in summary:
                    images = summary['images']
                    count = len(images) if hasattr(images, '__len__') else 0
                    summary['images'] = f"{count} image(s) (truncated)"
                error_msg += f"\n\nworker.response\n{summary}"
        logger.error(error_msg)

        if exception is None:
            return

        # An already-reported worker error is passed through untouched.
        if isinstance(exception, WorkerException):
            raise exception

        # Connection-level failures make the worker unavailable.
        if isinstance(exception, requests.RequestException):
            if worker is not None:
                worker.set_state(State.UNAVAILABLE)
            return

        # Anything else: release the worker and surface the original error.
        if worker is not None:
            worker.set_state(State.IDLE)
        raise exception
class State(Enum):
IDLE = 1
@ -146,7 +161,7 @@ class Worker:
if address.endswith('/'):
address = address[:-1]
else:
raise InvalidWorkerResponse("Worker address cannot be None")
raise WorkerException("Worker address cannot be None")
# auth
self.user = str(user) # casting these "prevents future issues with requests"
@ -162,7 +177,7 @@ class Worker:
return f"{self.address}:{self.port}"
def __repr__(self):
    """Return a one-line summary: label, address:port, speed in ipm, state.

    The speed is shown with two decimals. A speed of None is printed as
    "None" instead of raising, because this repr is used while building
    error messages and must not fail itself.
    """
    speed = "None" if self.avg_ipm is None else f"{self.avg_ipm:.2f}"
    return f"'{self.label}'@{self.address}:{self.port}, speed: {speed} ipm, state: {self.state}"
def __eq__(self, other):
if isinstance(other, Worker) and other.label == self.label:
@ -297,7 +312,6 @@ class Worker:
eta = None
try:
if self.jobs_requested != 0: # prevent potential hang at startup
# if state is already WORKING then weights may be loading on worker
# prevents issue where model override loads a large model and consecutive requests timeout
@ -348,60 +362,60 @@ class Worker:
f"{payload['batch_size'] * payload['n_iter']} image(s) "
f"at a speed of {self.avg_ipm:.2f} ipm\n")
try:
# remove anything that is not serializable
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
s_tmax = payload.get('s_tmax', 0.0)
if s_tmax is not None and s_tmax > 1e308:
payload['s_tmax'] = 1e308
# remove unserializable caches
payload.pop('cached_uc', None)
payload.pop('cached_c', None)
payload.pop('uc', None)
payload.pop('c', None)
payload.pop('cached_hr_c', None)
payload.pop('cached_hr_uc', None)
# if img2img then we need to b64 encode the init images
init_images = payload.get('init_images', None)
mode = 'txt2img'
if init_images is not None:
mode = 'img2img' # for use in checking script compat
images = []
for image in init_images:
images.append(pil_to_64(image))
payload['init_images'] = images
# remove anything that is not serializable
# s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
s_tmax = payload.get('s_tmax', 0.0)
if s_tmax is not None and s_tmax > 1e308:
payload['s_tmax'] = 1e308
# remove unserializable caches
payload.pop('cached_uc', None)
payload.pop('cached_c', None)
payload.pop('uc', None)
payload.pop('c', None)
payload.pop('cached_hr_c', None)
payload.pop('cached_hr_uc', None)
alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example
if alwayson_scripts is not None:
if len(self.supported_scripts) <= 0:
payload['alwayson_scripts'] = {}
else:
matching_scripts = {}
missing_scripts = []
remote_scripts = self.supported_scripts[mode]
for local_script in alwayson_scripts:
match = False
for remote_script in remote_scripts:
if str.lower(local_script) == str.lower(remote_script):
matching_scripts[local_script] = alwayson_scripts[local_script]
match = True
if not match and str.lower(local_script) != 'distribute':
missing_scripts.append(local_script)
# if img2img then we need to b64 encode the init images
init_images = payload.get('init_images', None)
mode = 'txt2img'
if init_images is not None:
mode = 'img2img' # for use in checking script compat
images = []
for image in init_images:
images.append(pil_to_64(image))
payload['init_images'] = images
if len(missing_scripts) > 0: # warn about node to node script/extension mismatching
message = "local script(s): "
for script in range(0, len(missing_scripts)):
message += f"\[{missing_scripts[script]}]"
if script < len(missing_scripts) - 1:
message += ', '
message += f" seem to be unsupported by worker '{self.label}'\n"
if LOG_LEVEL == 'DEBUG': # only warn once per session unless at debug log level
logger.debug(message)
elif self.jobs_requested < 1:
logger.warning(message)
alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example
if alwayson_scripts is not None:
if len(self.supported_scripts) <= 0:
payload['alwayson_scripts'] = {}
else:
matching_scripts = {}
missing_scripts = []
remote_scripts = self.supported_scripts[mode]
for local_script in alwayson_scripts:
match = False
for remote_script in remote_scripts:
if str.lower(local_script) == str.lower(remote_script):
matching_scripts[local_script] = alwayson_scripts[local_script]
match = True
if not match and str.lower(local_script) != 'distribute':
missing_scripts.append(local_script)
payload['alwayson_scripts'] = matching_scripts
if len(missing_scripts) > 0: # warn about node to node script/extension mismatching
message = "local script(s): "
for script in range(0, len(missing_scripts)):
message += f"\[{missing_scripts[script]}]"
if script < len(missing_scripts) - 1:
message += ', '
message += f" seem to be unsupported by worker '{self.label}'\n"
if LOG_LEVEL == 'DEBUG': # only warn once per session unless at debug log level
logger.debug(message)
elif self.jobs_requested < 1:
logger.warning(message)
payload['alwayson_scripts'] = matching_scripts
# if an image mask is present
image_mask = payload.get('image_mask', None)
@ -467,11 +481,8 @@ class Worker:
second_attempt.join()
return
logger.error(
f"'{self.label}' response: Code <{response.status_code}> "
f"{str(response.content, 'utf-8')}")
self.response = None
raise InvalidWorkerResponse()
raise WorkerException(f"bad response: Code <{response.status_code}> ", worker=self)
# update list of ETA accuracy if state is valid
if self.benchmarked and not self.state == State.INTERRUPTED:
@ -492,13 +503,8 @@ class Worker:
else:
logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")
except Exception as e:
self.set_state(State.IDLE)
raise InvalidWorkerResponse(e)
except requests.RequestException:
self.set_state(State.UNAVAILABLE)
return
except Exception as e:
raise WorkerException('', worker=self, exception=e)
self.set_state(State.IDLE)
self.jobs_requested += 1
@ -540,22 +546,20 @@ class Worker:
if self.state == State.UNAVAILABLE:
return 0
try: # if the worker is unreachable/offline then handle that here
elapsed = None
# if the worker is unreachable/offline then handle that here
elapsed = None
if not callable(sample_function):
start = time.time()
t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,),
name=f"{self.label}_benchmark_request")
t.start()
t.join()
elapsed = time.time() - start
else:
elapsed = sample_function()
if not callable(sample_function):
start = time.time()
t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,),
name=f"{self.label}_benchmark_request")
t.start()
t.join()
elapsed = time.time() - start
else:
elapsed = sample_function()
sample_ipm = ipm(elapsed)
except InvalidWorkerResponse as e:
raise e
sample_ipm = ipm(elapsed)
if i >= warmup_samples:
logger.info(f"Sample {i - warmup_samples + 1}: Worker '{self.label}'({self}) "