diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index f9cc48d..b334cc6 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -300,6 +300,90 @@ class Worker: return eta + def matching_scripts(self, payload: dict, mode) -> dict: + alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example + if alwayson_scripts is not None: + if len(self.supported_scripts) <= 0: + payload['alwayson_scripts'] = {} + else: + matching_scripts = {} + missing_scripts = [] + remote_scripts = self.supported_scripts[mode] + for local_script in alwayson_scripts: + match = False + for remote_script in remote_scripts: + if str.lower(local_script) == str.lower(remote_script): + matching_scripts[local_script] = alwayson_scripts[local_script] + match = True + if not match and str.lower(local_script) != 'distribute': + missing_scripts.append(local_script) + + if len(missing_scripts) > 0: # warn about node to node script/extension mismatching + message = "local script(s): " + for script in range(0, len(missing_scripts)): + message += f"\[{missing_scripts[script]}]" + if script < len(missing_scripts) - 1: + message += ', ' + message += f" seem to be unsupported by worker '{self.label}'\n" + if LOG_LEVEL == 'DEBUG': # only warn once per session unless at debug log level + logger.debug(message) + elif self.jobs_requested < 1: + logger.warning(message) + + return matching_scripts + + def query_memory(self): + if self.queried is False: + self.queried = True + response = self.session.get(self.full_url("memory")).json() + try: + response = response['cuda']['system'] # all in bytes + free_vram = int(response['free']) / (1024 * 1024 * 1024) + total_vram = int(response['total']) / (1024 * 1024 * 1024) + logger.debug(f"Worker '{self.label}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n") + return response['free'] + except KeyError: + try: + error = response['cuda']['error'] + msg = f"CUDA seems unavailable for worker '{self.label}'\nError: {error}" + logger.warning(msg) + except KeyError: + logger.error(f"An error occurred querying memory statistics from worker '{self.label}'\n" + f"{response}") + + def prepare_payload(self, payload: dict): + # clean anything that is unserializable + s_tmax = payload.get('s_tmax', 0.0) + if s_tmax is not None and s_tmax > 1e308: # s_tmax can be float('inf') + payload['s_tmax'] = 1e308 + key_blacklist = ['cached_uc', 'cached_c', 'uc', 'c', 'cached_hr_c', 'cached_hr_uc'] + for k in key_blacklist: + payload.pop(k, None) + + # if img2img then we need to b64 encode the init images + init_images = payload.get('init_images', None) + mode = 'txt2img' + if init_images is not None: + mode = 'img2img' # for use in checking script compat + images = [] + for image in init_images: + images.append(pil_to_64(image)) + payload['init_images'] = images + # if an image mask is present + image_mask = payload.get('image_mask', None) + if image_mask is not None: + payload['mask'] = pil_to_64(image_mask) + del payload['image_mask'] + + payload['alwayson_scripts'] = self.matching_scripts(payload, mode) + # see if there is anything else wrong with serializing to payload + try: + json.dumps(payload) + except Exception: + if payload.get('init_images', None): + payload['init_images'] = 'TRUNCATED' + logger.error(f"Failed to serialize payload: \n{payload}") + def request(self, payload: dict, option_payload: dict, sync_options: bool): """ Sends an arbitrary amount of requests to a sdwui api depending on the context. @@ -310,6 +394,7 @@ class Worker: sync_options (bool): Whether to attempt to synchronize the worker's loaded models with the locals' """ eta = None + response_queue = queue.Queue() try: if self.jobs_requested != 0: # prevent potential hang at startup @@ -327,32 +412,8 @@ class Worker: logger.debug(f"waited {waited}s for worker '{self.label}' to IDLE before consecutive request") if waited >= (0.85 * max_wait): logger.warning("this seems long, so if you see this message often, consider reporting an issue") - self.set_state(State.WORKING) - - # query memory available on worker and store for future reference - if self.queried is False: - self.queried = True - memory_response = self.session.get( - self.full_url("memory") - ) - memory_response = memory_response.json() - try: - memory_response = memory_response['cuda']['system'] # all in bytes - free_vram = int(memory_response['free']) / (1024 * 1024 * 1024) - total_vram = int(memory_response['total']) / (1024 * 1024 * 1024) - logger.debug(f"Worker '{self.label}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n") - self.free_vram = memory_response['free'] - except KeyError: - try: - error = memory_response['cuda']['error'] - msg = f"CUDA seems unavailable for worker '{self.label}'\nError: {error}" - logger.warning(msg) - # gradio.Warning("Distributed: "+msg) - except KeyError: - logger.error(f"An error occurred querying memory statistics from worker '{self.label}'\n" - f"{memory_response}") - + self.free_vram = self.query_memory() if sync_options is True: self.load_options(model=option_payload['sd_model_checkpoint'], vae=option_payload['sd_vae']) @@ -362,146 +423,76 @@ class Worker: f"{payload['batch_size'] * payload['n_iter']} image(s) " f"at a speed of {self.avg_ipm:.2f} ipm\n") + self.prepare_payload(payload) - # remove anything that is not serializable - # s_tmax can be float('inf') which is not serializable, so we convert it to the max float value - s_tmax = payload.get('s_tmax', 0.0) - if s_tmax is not None and s_tmax > 1e308: - payload['s_tmax'] = 1e308 - # remove unserializable caches - payload.pop('cached_uc', None) - payload.pop('cached_c', None) - payload.pop('uc', None) - payload.pop('c', None) - payload.pop('cached_hr_c', None) - payload.pop('cached_hr_uc', None) + def interruptible_request(response_queue): + # TODO shouldn't be this way + sampler_index = payload.get('sampler_index', None) + sampler_name = payload.get('sampler_name', None) + if sampler_index is None: + if sampler_name is not None: + payload['sampler_index'] = sampler_name - # if img2img then we need to b64 encode the init images - init_images = payload.get('init_images', None) - mode = 'txt2img' - if init_images is not None: - mode = 'img2img' # for use in checking script compat - images = [] - for image in init_images: - images.append(pil_to_64(image)) - payload['init_images'] = images - - alwayson_scripts = payload.get('alwayson_scripts', None) # key may not always exist, benchmarking being one example - if alwayson_scripts is not None: - if len(self.supported_scripts) <= 0: - payload['alwayson_scripts'] = {} - else: - matching_scripts = {} - missing_scripts = [] - remote_scripts = self.supported_scripts[mode] - for local_script in alwayson_scripts: - match = False - for remote_script in remote_scripts: - if str.lower(local_script) == str.lower(remote_script): - matching_scripts[local_script] = alwayson_scripts[local_script] - match = True - if not match and str.lower(local_script) != 'distribute': - missing_scripts.append(local_script) - - if len(missing_scripts) > 0: # warn about node to node script/extension mismatching - message = "local script(s): " - for script in range(0, len(missing_scripts)): - message += f"\[{missing_scripts[script]}]" - if script < len(missing_scripts) - 1: - message += ', ' - message += f" seem to be unsupported by worker '{self.label}'\n" - if LOG_LEVEL == 'DEBUG': # only warn once per session unless at debug log level - logger.debug(message) - elif self.jobs_requested < 1: - logger.warning(message) - - payload['alwayson_scripts'] = matching_scripts - - # if an image mask is present - image_mask = payload.get('image_mask', None) - if image_mask is not None: - payload['mask'] = pil_to_64(image_mask) - del payload['image_mask'] - - # see if there is anything else wrong with serializing to payload try: - json.dumps(payload) - except Exception: - if payload.get('init_images', None): - payload['init_images'] = 'TRUNCATED' - logger.error(f"Failed to serialize payload: \n{payload}") - # gradio.Info("Distributed: failed to serialize payload") - - # the main api requests sent to either the txt2img or img2img route - response_queue = queue.Queue() - - def preemptible_request(response_queue): - # TODO shouldn't be this way - sampler_index = payload.get('sampler_index', None) - sampler_name = payload.get('sampler_name', None) - if sampler_index is None: - if sampler_name is not None: - payload['sampler_index'] = sampler_name - - try: - response = self.session.post( - self.full_url("txt2img") if init_images is None else self.full_url("img2img"), + response_queue.put( + self.session.post( + self.full_url("txt2img") if payload.get('init_images', None) is None else self.full_url("img2img"), json=payload ) - response_queue.put(response) - except Exception as e: - response_queue.put(e) # forwarding thrown exceptions to parent thread + ) + except Exception as e: + response_queue.put(e) # forwarding thrown exceptions to parent thread - request_thread = Thread(target=preemptible_request, args=(response_queue,)) - interrupting = False - start = time.time() - request_thread.start() - while request_thread.is_alive(): - if interrupting is False and master_state.interrupted is True: - self.interrupt() - interrupting = True - time.sleep(0.5) + request_thread = Thread(target=interruptible_request, args=(response_queue,)) + interrupting = False + start = time.time() + request_thread.start() + while request_thread.is_alive(): + if interrupting is False and master_state.interrupted is True: + self.interrupt() + interrupting = True + time.sleep(0.5) - result = response_queue.get() - if isinstance(result, Exception): - raise result - response = result + result = response_queue.get() + if isinstance(result, Exception): + raise result + response = result - self.response = response.json() - if response.status_code != 200: - # try again when remote doesn't support the selected sampler by falling back to Euler a - if response.status_code == 404 and self.response['detail'] == "Sampler not found": - logger.warning(f"falling back to Euler A sampler for worker {self.label}\n" - f"this may mean you should update this worker") - payload['sampler_index'] = 'Euler a' - payload['sampler_name'] = 'Euler a' + self.response = response.json() + if response.status_code != 200: + # try again when remote doesn't support the selected sampler by falling back to Euler a + if response.status_code == 404 and self.response['detail'] == "Sampler not found": + logger.warning(f"falling back to Euler A sampler for worker {self.label}\n" + f"this may mean you should update this worker") + payload['sampler_index'] = 'Euler a' + payload['sampler_name'] = 'Euler a' - second_attempt = Thread(target=self.request, args=(payload, option_payload, sync_options,)) - second_attempt.start() - second_attempt.join() - return + second_attempt = Thread(target=self.request, args=(payload, option_payload, sync_options,)) + second_attempt.start() + second_attempt.join() + return - self.response = None - raise WorkerException(f"bad response: Code <{response.status_code}> ", worker=self) + self.response = None + raise WorkerException(f"bad response: Code <{response.status_code}> ", worker=self) - # update list of ETA accuracy if state is valid - if self.benchmarked and not self.state == State.INTERRUPTED: - self.response_time = time.time() - start - variance = ((eta - self.response_time) / self.response_time) * 100 + # update list of ETA accuracy if state is valid + if self.benchmarked and not self.state == State.INTERRUPTED: + self.response_time = time.time() - start + variance = ((eta - self.response_time) / self.response_time) * 100 - logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%\n" - f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n") + logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%\n" + f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n") - # if the variance is greater than 500% then we ignore it to prevent variation inflation - if abs(variance) < 500: - # check if there are already 5 samples and if so, remove the oldest - # this should help adjust to the user changing tasks - if len(self.eta_percent_error) > 4: - self.eta_percent_error.pop(0) - else: # normal case - self.eta_percent_error.append(variance) - else: - logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n") + # if the variance is greater than 500% then we ignore it to prevent variation inflation + if abs(variance) < 500: + # check if there are already 5 samples and if so, remove the oldest + # this should help adjust to the user changing tasks + if len(self.eta_percent_error) > 4: + self.eta_percent_error.pop(0) + else: # normal case + self.eta_percent_error.append(variance) + else: + logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n") except Exception as e: raise WorkerException('', worker=self, exception=e)