diff --git a/scripts/distributed.py b/scripts/distributed.py index 4564fad..d8dc1e7 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -153,7 +153,7 @@ class DistributedScript(scripts.Script): received_images = False for job in self.world.jobs: - if job.worker.response is None or job.batch_size < 1 or job.worker.master: + if not isinstance(job.worker.response, dict) or job.batch_size < 1 or job.worker.master: continue try: @@ -198,6 +198,8 @@ class DistributedScript(scripts.Script): active_adapters = [] if p.all_prompts is None: p.all_prompts = [] + if p.all_negative_prompts is None: + p.all_negative_prompts = [] is_img2img = getattr(p, 'init_images', False) if is_img2img and self.world.enabled_i2i is False: diff --git a/scripts/spartan/adapters.py b/scripts/spartan/adapters.py index 10c744b..8871968 100644 --- a/scripts/spartan/adapters.py +++ b/scripts/spartan/adapters.py @@ -8,12 +8,13 @@ class Adapter(object): self.script = None def early(self, p, world, script, *args) -> bool: - """return True to cede control back to webui""" + """make changes before any worker request objects are created. return True to cede control back to webui""" self.script = script return False def late(self, p, world, payload, *args): + """make changes after the worker request object has been created and workloads have been manipulated""" # payload['alwayson_scripts'] guaranteed to exist, but may not be populated pass diff --git a/scripts/spartan/control_net.py b/scripts/spartan/control_net.py index a029aea..5958f8b 100644 --- a/scripts/spartan/control_net.py +++ b/scripts/spartan/control_net.py @@ -32,8 +32,8 @@ def pack_control_net(cn_units) -> dict: for i in range(0, len(cn_units)): if cn_units[i].enabled: cn_args.append(copy.deepcopy(cn_units[i].__dict__)) - else: - logger.debug(f"controlnet unit {i} is not enabled (ignoring)") + # else: + # logger.debug(f"controlnet unit {i} is not enabled (ignoring)") for i in range(0, len(cn_args)): unit = cn_args[i] diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index f9dda26..f9cc48d 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -26,12 +26,27 @@ except ImportError: # webui 95821f0132f5437ef30b0dbcac7c51e55818c18f and newer from .pmodels import Worker_Model -class InvalidWorkerResponse(Exception): - """ - Should be raised when an invalid or unexpected response is received from a worker request. - """ - pass +class WorkerException(Exception): + def __init__(self, message, worker, exception=None): + super().__init__(message) # no-op? + error_msg = message + error_msg += f"\n{repr(worker)}" + if isinstance(worker.response, dict): + temp = copy.deepcopy(worker.response) + temp['images'] = f"{len(temp['images'])} image(s) (truncated)" + error_msg += f"\n\nworker.response\n{temp}" + logger.error(error_msg) + if exception is not None: # there is a nested exception + if isinstance(exception, WorkerException): + raise exception + elif isinstance(exception, requests.RequestException): + worker.set_state(State.UNAVAILABLE) + else: + worker.set_state(State.IDLE) + + # logger.exception(exception) + raise exception class State(Enum): IDLE = 1 @@ -146,7 +161,7 @@ class Worker: if address.endswith('/'): address = address[:-1] else: - raise InvalidWorkerResponse("Worker address cannot be None") + raise WorkerException("Worker address cannot be None") # auth self.user = str(user) # casting these "prevents future issues with requests" @@ -162,7 +177,7 @@ class Worker: return f"{self.address}:{self.port}" def __repr__(self): - return f"'{self.label}'@{self.address}:{self.port}, speed: {self.avg_ipm} ipm, state: {self.state}" + return f"'{self.label}'@{self.address}:{self.port}, speed: {self.avg_ipm:.2f} ipm, state: {self.state}" def __eq__(self, other): if isinstance(other, Worker) and other.label == self.label: @@ -297,7 +312,6 @@ class Worker: eta = None try: - if self.jobs_requested != 0: # prevent potential hang at startup # if state is already WORKING then weights may be loading on worker # prevents issue where model override loads a large model and consecutive requests timeout @@ -348,60 +362,60 @@ class Worker: f"{payload['batch_size'] * payload['n_iter']} image(s) " f"at a speed of {self.avg_ipm:.2f} ipm\n") - try: - # remove anything that is not serializable - # s_tmax can be float('inf') which is not serializable, so we convert it to the max float value - s_tmax = payload.get('s_tmax', 0.0) - if s_tmax is not None and s_tmax > 1e308: - payload['s_tmax'] = 1e308 - # remove unserializable caches - payload.pop('cached_uc', None) - payload.pop('cached_c', None) - payload.pop('uc', None) - payload.pop('c', None) - payload.pop('cached_hr_c', None) - payload.pop('cached_hr_uc', None) - # if img2img then we need to b64 encode the init images - init_images = payload.get('init_images', None) - mode = 'txt2img' - if init_images is not None: - mode = 'img2img' # for use in checking script compat - images = [] - for image in init_images: - images.append(pil_to_64(image)) - payload['init_images'] = images + # remove anything that is not serializable + # s_tmax can be float('inf') which is not serializable, so we convert it to the max float value + s_tmax = payload.get('s_tmax', 0.0) + if s_tmax is not None and s_tmax > 1e308: + payload['s_tmax'] = 1e308 + # remove unserializable caches + payload.pop('cached_uc', None) + payload.pop('cached_c', None) + payload.pop('uc', None) + payload.pop('c', None) + payload.pop('cached_hr_c', None) + payload.pop('cached_hr_uc', None) - alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example - if alwayson_scripts is not None: - if len(self.supported_scripts) <= 0: - payload['alwayson_scripts'] = {} - else: - matching_scripts = {} - missing_scripts = [] - remote_scripts = self.supported_scripts[mode] - for local_script in alwayson_scripts: - match = False - for remote_script in remote_scripts: - if str.lower(local_script) == str.lower(remote_script): - matching_scripts[local_script] = alwayson_scripts[local_script] - match = True - if not match and str.lower(local_script) != 'distribute': - missing_scripts.append(local_script) + # if img2img then we need to b64 encode the init images + init_images = payload.get('init_images', None) + mode = 'txt2img' + if init_images is not None: + mode = 'img2img' # for use in checking script compat + images = [] + for image in init_images: + images.append(pil_to_64(image)) + payload['init_images'] = images - if len(missing_scripts) > 0: # warn about node to node script/extension mismatching - message = "local script(s): " - for script in range(0, len(missing_scripts)): - message += f"\[{missing_scripts[script]}]" - if script < len(missing_scripts) - 1: - message += ', ' - message += f" seem to be unsupported by worker '{self.label}'\n" - if LOG_LEVEL == 'DEBUG': # only warn once per session unless at debug log level - logger.debug(message) - elif self.jobs_requested < 1: - logger.warning(message) + alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example + if alwayson_scripts is not None: + if len(self.supported_scripts) <= 0: + payload['alwayson_scripts'] = {} + else: + matching_scripts = {} + missing_scripts = [] + remote_scripts = self.supported_scripts[mode] + for local_script in alwayson_scripts: + match = False + for remote_script in remote_scripts: + if str.lower(local_script) == str.lower(remote_script): + matching_scripts[local_script] = alwayson_scripts[local_script] + match = True + if not match and str.lower(local_script) != 'distribute': + missing_scripts.append(local_script) - payload['alwayson_scripts'] = matching_scripts + if len(missing_scripts) > 0: # warn about node to node script/extension mismatching + message = "local script(s): " + for script in range(0, len(missing_scripts)): + message += f"\[{missing_scripts[script]}]" + if script < len(missing_scripts) - 1: + message += ', ' + message += f" seem to be unsupported by worker '{self.label}'\n" + if LOG_LEVEL == 'DEBUG': # only warn once per session unless at debug log level + logger.debug(message) + elif self.jobs_requested < 1: + logger.warning(message) + + payload['alwayson_scripts'] = matching_scripts # if an image mask is present image_mask = payload.get('image_mask', None) @@ -467,11 +481,8 @@ class Worker: second_attempt.join() return - logger.error( - f"'{self.label}' response: Code <{response.status_code}> " - f"{str(response.content, 'utf-8')}") self.response = None - raise InvalidWorkerResponse() + raise WorkerException(f"bad response: Code <{response.status_code}> ", worker=self) # update list of ETA accuracy if state is valid if self.benchmarked and not self.state == State.INTERRUPTED: @@ -492,13 +503,8 @@ class Worker: else: logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n") - except Exception as e: - self.set_state(State.IDLE) - raise InvalidWorkerResponse(e) - - except requests.RequestException: - self.set_state(State.UNAVAILABLE) - return + except Exception as e: + raise WorkerException('', worker=self, exception=e) self.set_state(State.IDLE) self.jobs_requested += 1 @@ -540,22 +546,20 @@ class Worker: if self.state == State.UNAVAILABLE: return 0 - try: # if the worker is unreachable/offline then handle that here - elapsed = None + # if the worker is unreachable/offline then handle that here + elapsed = None - if not callable(sample_function): - start = time.time() - t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,), - name=f"{self.label}_benchmark_request") - t.start() - t.join() - elapsed = time.time() - start - else: - elapsed = sample_function() + if not callable(sample_function): + start = time.time() + t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,), + name=f"{self.label}_benchmark_request") + t.start() + t.join() + elapsed = time.time() - start + else: + elapsed = sample_function() - sample_ipm = ipm(elapsed) - except InvalidWorkerResponse as e: - raise e + sample_ipm = ipm(elapsed) if i >= warmup_samples: logger.info(f"Sample {i - warmup_samples + 1}: Worker '{self.label}'({self}) "