import asyncio
import base64
import copy
import io
import json
import math
import queue
import re
import time
from enum import Enum
from threading import Thread
from typing import List, Union

import requests

from modules.api.api import encode_pil_to_base64
from modules.shared import cmd_opts
from modules.shared import state as master_state
from . import shared as sh
from .shared import logger, warmup_samples, LOG_LEVEL

try:
    from webui import server_name
except ImportError:  # webui 95821f0132f5437ef30b0dbcac7c51e55818c18f and newer
    from modules.initialize_util import gradio_server_name
    server_name = gradio_server_name()
from .pmodels import Worker_Model


class InvalidWorkerResponse(Exception):
    """
    Should be raised when an invalid or unexpected response is received from a worker request.
    """
    pass


class State(Enum):
    """Lifecycle state of a worker node."""
    IDLE = 1
    WORKING = 2
    INTERRUPTED = 3
    UNAVAILABLE = 4
    DISABLED = 5


class Worker:
    """
    This class represents a worker node in a distributed computing setup.

    Attributes:
        address (str): The address of the worker node. Can be an ip or a FQDN. Defaults to None.
        port (int): The port number used by the worker node. Defaults to None.
        avg_ipm (int): The average images per minute of the node. Defaults to None.
        label (str): The name of the worker node. Defaults to None.
        queried (bool): Whether this worker's memory status has been polled yet. Defaults to False.
        verify_remotes (bool): Whether to verify the validity of remote worker certificates. Defaults to False.
        master (bool): Whether this worker is the master node. Defaults to False.
        auth (str|None): The username and password used to authenticate with the worker.
            Defaults to None. (username:password)
        benchmarked (bool): Whether this worker has been benchmarked. Defaults to False.
        eta_percent_error (List[float]): A runtime list of ETA percent errors for this worker. Empty by default.
        response (requests.Response): The last response from this worker. Defaults to None.

    Raises:
        InvalidWorkerResponse: If the worker responds with an invalid or unexpected response.
    """

    # Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
    # We compare to euler a because that is what we currently benchmark each node with.
    other_to_euler_a = {
        "DPM++ 2S a Karras": -45.87,
        "Euler": 4.92,
        "LMS": 12.66,
        "Heun": -40.24,
        "DPM2": -42.50,
        "DPM2 a": -46.60,
        "DPM++ 2S a": -37.10,
        "DPM++ 2M": 7.46,
        "DPM++ SDE": -39.45,
        "DPM fast": 15.54,
        "DPM adaptive": -61.40,
        "LMS Karras": 5,
        "DPM2 Karras": -41,
        "DPM2 a Karras": -38.81,
        "DPM++ 2M Karras": 16.20,
        "DPM++ SDE Karras": -39.71,
        "DDIM": 0,
        "PLMS": 9.31
    }

    def __init__(self,
                 address: Union[str, None] = None,
                 port: int = 7860,
                 label: Union[str, None] = None,
                 verify_remotes: bool = True,
                 master: bool = False,
                 tls: bool = False,
                 state: State = State.IDLE,
                 avg_ipm: float = 0.0,
                 eta_percent_error=None,
                 user: Union[str, None] = None,
                 password: Union[str, None] = None,
                 pixel_cap: int = -1
                 ):
        # avoid the shared-mutable-default pitfall for the error history
        if eta_percent_error is None:
            self.eta_percent_error = []
        else:
            self.eta_percent_error = eta_percent_error

        self.avg_ipm = avg_ipm
        # accept either a State member or its raw value (e.g. loaded from config)
        self.state = state if isinstance(state, State) else State(state)
        self.address = address
        self.port = port
        self.response_time = None
        self.loaded_model = ''
        self.loaded_vae = ''
        self.supported_scripts = {}
        self.label = label
        self.tls = tls
        self.model_override: Union[str, None] = None
        self.free_vram: int = 0
        self.response = None
        self.queried = False
        self.benchmarked = False
        self.pixel_cap = pixel_cap  # ex. limit, 2 512x512 images at once: (2*(512*512)) = 524288 px
        self.jobs_requested = 0

        # master specific setup
        if master is True:
            self.master = master
            self.label = 'master'  # right now this is really only for clarity while debugging:
            self.address = server_name if server_name is not None else 'localhost'
            if cmd_opts.port is None:
                self.port = 7860
            else:
                self.port = cmd_opts.port
            # NOTE(review): the master worker intentionally skips the auth/session setup below,
            # so it never gets a `session`/`user`/`password` attribute
            return
        else:
            self.master = False

        # strip http:// or https:// from address if present
        if address is not None:
            if address.startswith("http://"):
                address = address[7:]
            elif address.startswith("https://"):
                address = address[8:]
                self.tls = True
                self.port = 443
            if address.endswith('/'):
                address = address[:-1]
            # fixed: previously only the local variable was normalized, leaving
            # self.address with the scheme prefix and breaking full_url()
            self.address = address
        else:
            raise InvalidWorkerResponse("Worker address cannot be None")

        # auth
        self.user = str(user)  # casting these "prevents future issues with requests"
        self.password = str(password)

        # requests session
        self.session = requests.Session()
        self.session.auth = (self.user, self.password)  # sometimes breaks: https://github.com/psf/requests/issues/2255
        self.session.verify = verify_remotes

    def __str__(self):
        return f"{self.address}:{self.port}"

    def __repr__(self):
        return f"'{self.label}'@{self.address}:{self.port}, speed: {self.avg_ipm} ipm, state: {self.state}"

    def __eq__(self, other):
        # workers are identified solely by label
        if isinstance(other, Worker) and other.label == self.label:
            return True
        return False

    @property
    def model(self) -> Worker_Model:
        """Serializable view of this worker built from its instance attributes."""
        return Worker_Model(**self.__dict__)

    def eta_mpe(self):
        """
        Returns the mean percent error using all the currently stored eta percent errors.

        Returns:
            mpe (float): The mean percent error of a worker's calculation estimates.
        """
        if len(self.eta_percent_error) == 0:
            return 0

        return sum(self.eta_percent_error) / len(self.eta_percent_error)

    def full_url(self, route: str) -> str:
        """
        Gets the full url used for making requests of sdwui at a given route.

        Args:
            route (str): The sdwui api route to send the request to.

        Returns:
            str: The full url.
        """
        protocol = 'http' if not self.tls else 'https'
        return f"{protocol}://{self.__str__()}/sdapi/v1/{route}"

    def batch_eta_hr(self, payload: dict) -> float:
        """
        Takes a normal payload and returns the eta of a pseudo payload which mirrors the
        hr-fix parameters. This returns the eta of how long it would take to run hr-fix
        on the original image.
        """
        pseudo_payload = copy.copy(payload)
        pseudo_payload['enable_hr'] = False  # prevent overflow in self.batch_eta
        res_ratio = pseudo_payload['hr_scale']
        original_steps = pseudo_payload['steps']
        second_pass_steps = pseudo_payload['hr_second_pass_steps']

        # if hires steps is set to zero then pseudo steps should = orig steps
        if second_pass_steps == 0:
            pseudo_payload['steps'] = original_steps
        else:
            pseudo_payload['steps'] = second_pass_steps

        pseudo_width = math.floor(pseudo_payload['width'] * res_ratio)
        pseudo_height = math.floor(pseudo_payload['height'] * res_ratio)
        pseudo_payload['width'] = pseudo_width
        pseudo_payload['height'] = pseudo_height

        eta = self.batch_eta(payload=pseudo_payload, quiet=True)
        return eta

    def batch_eta(self, payload: dict, quiet: bool = False, batch_size: int = None) -> float:
        """
        Estimate how long it will take to generate images on a worker in seconds.

        Args:
            payload: Sdwui api formatted payload
            quiet: Whether to print error correction information
            batch_size: Overrides the batch_size parameter of the payload
        """
        steps = payload['steps']
        num_images = payload['batch_size'] if batch_size is None else batch_size

        # NOTE(review): if the worker has not yet been benchmarked avg_ipm is 0 and this
        # divides by zero — callers appear to benchmark first; confirm
        eta = (num_images / self.avg_ipm) * 60
        # show effect of increased step size
        real_steps_to_benched = steps / sh.benchmark_payload.steps
        eta = eta * real_steps_to_benched

        # show effect of high-res fix
        hr = payload.get('enable_hr', False)
        if hr:
            eta += self.batch_eta_hr(payload=payload)

        # show effect of image size
        real_pix_to_benched = (payload['width'] * payload['height']) \
            / (sh.benchmark_payload.width * sh.benchmark_payload.height)
        eta = eta * real_pix_to_benched

        # show effect of using a sampler other than euler a
        sampler = payload.get('sampler_name', 'Euler a')
        if sampler != 'Euler a':
            try:
                percent_difference = self.other_to_euler_a[payload['sampler_name']]
                if percent_difference > 0:
                    eta -= (eta * abs((percent_difference / 100)))
                else:
                    eta += (eta * abs((percent_difference / 100)))
            except KeyError:
                logger.warning(f"Efficiency of sampler '{payload['sampler_name']}' has not been recorded.\n")
                # in this case the sampler will be treated as having the same efficiency as Euler a

        # adjust for a known inaccuracy in our estimation of this worker using average percent error
        if len(self.eta_percent_error) > 0:
            correction = eta * (self.eta_mpe() / 100)
            if not quiet:
                logger.debug(f"worker '{self.label}'s last ETA was off by {correction:.2f}%")
                correction_summary = f"correcting '{self.label}'s ETA: {eta:.2f}s -> "
            # do regression
            eta -= correction
            if not quiet:
                correction_summary += f"{eta:.2f}s"
                logger.debug(correction_summary)

        return eta

    def request(self, payload: dict, option_payload: dict, sync_options: bool):
        """
        Sends an arbitrary amount of requests to a sdwui api depending on the context.

        Args:
            payload (dict): The txt2img payload.
            option_payload (dict): The options payload.
            sync_options (bool): Whether to attempt to synchronize the worker's loaded models with the locals'
        """
        eta = None
        try:
            if self.jobs_requested != 0:  # prevent potential hang at startup
                # if state is already WORKING then weights may be loading on worker
                # prevents issue where model override loads a large model and consecutive requests timeout
                max_wait = 30
                waited = 0
                while self.state == State.WORKING:
                    if waited >= max_wait:
                        break
                    time.sleep(1)
                    waited += 1
                if waited != 0:
                    logger.debug(f"waited {waited}s for worker '{self.label}' to IDLE before consecutive request")
                    if waited >= (0.85 * max_wait):
                        logger.warning("this seems long, so if you see this message often, consider reporting an issue")

            self.state = State.WORKING

            # query memory available on worker and store for future reference
            if self.queried is False:
                self.queried = True
                memory_response = self.session.get(
                    self.full_url("memory")
                )
                memory_response = memory_response.json()

                try:
                    memory_response = memory_response['cuda']['system']  # all in bytes
                    free_vram = int(memory_response['free']) / (1024 * 1024 * 1024)
                    total_vram = int(memory_response['total']) / (1024 * 1024 * 1024)
                    logger.debug(f"Worker '{self.label}' {free_vram:.2f}/{total_vram:.2f} GB VRAM free\n")
                    self.free_vram = memory_response['free']
                except KeyError:
                    try:
                        error = memory_response['cuda']['error']
                        msg = f"CUDA seems unavailable for worker '{self.label}'\nError: {error}"
                        logger.warning(msg)
                        # gradio.Warning("Distributed: "+msg)
                    except KeyError:
                        logger.error(f"An error occurred querying memory statistics from worker '{self.label}'\n"
                                     f"{memory_response}")

            if sync_options is True:
                self.load_options(model=option_payload['sd_model_checkpoint'], vae=option_payload['sd_vae'])

            if self.benchmarked:
                eta = self.batch_eta(payload=payload) * payload['n_iter']
                logger.debug(f"worker '{self.label}' predicts it will take {eta:.3f}s to generate "
                             f"{payload['batch_size'] * payload['n_iter']} image(s) "
                             f"at a speed of {self.avg_ipm:.2f} ipm\n")

            try:
                # remove anything that is not serializable
                # s_tmax can be float('inf') which is not serializable, so we convert it to the max float value
                s_tmax = payload.get('s_tmax', 0.0)
                if s_tmax > 1e308:
                    payload['s_tmax'] = 1e308
                # remove unserializable caches
                payload.pop('cached_uc', None)
                payload.pop('cached_c', None)
                payload.pop('uc', None)
                payload.pop('c', None)
                payload.pop('cached_hr_c', None)
                payload.pop('cached_hr_uc', None)

                # if img2img then we need to b64 encode the init images
                init_images = payload.get('init_images', None)
                mode = 'txt2img'
                if init_images is not None:
                    mode = 'img2img'  # for use in checking script compat
                    images = []
                    for image in init_images:
                        buffer = io.BytesIO()
                        image.save(buffer, format="PNG")
                        image = 'data:image/png;base64,' + str(base64.b64encode(buffer.getvalue()), 'utf-8')
                        images.append(image)
                    payload['init_images'] = images

                # key may not always exist, benchmarking being one example
                alwayson_scripts = payload.get('alwayson_scripts', None)
                if alwayson_scripts is not None:
                    if len(self.supported_scripts) <= 0:
                        payload['alwayson_scripts'] = {}
                    else:
                        matching_scripts = {}
                        missing_scripts = []
                        remote_scripts = self.supported_scripts[mode]

                        for local_script in alwayson_scripts:
                            match = False
                            for remote_script in remote_scripts:
                                if str.lower(local_script) == str.lower(remote_script):
                                    matching_scripts[local_script] = alwayson_scripts[local_script]
                                    match = True
                            if not match and str.lower(local_script) != 'distribute':
                                missing_scripts.append(local_script)

                        if len(missing_scripts) > 0:
                            # warn about node to node script/extension mismatching
                            message = "local script(s): "
                            # fixed: `\[` was an invalid escape sequence; `\\[` emits the same text
                            for idx, script_name in enumerate(missing_scripts):
                                message += f"\\[{script_name}]"
                                if idx < len(missing_scripts) - 1:
                                    message += ', '
                            message += f" seem to be unsupported by worker '{self.label}'\n"

                            if LOG_LEVEL == 'DEBUG':  # only warn once per session unless at debug log level
                                logger.debug(message)
                            elif self.jobs_requested < 1:
                                logger.warning(message)

                        payload['alwayson_scripts'] = matching_scripts

                # if an image mask is present
                image_mask = payload.get('image_mask', None)
                if image_mask is not None:
                    image_b64 = encode_pil_to_base64(image_mask)
                    image_b64 = str(image_b64, 'utf-8')
                    payload['mask'] = image_b64
                    del payload['image_mask']

                # see if there is anything else wrong with serializing to payload
                try:
                    json.dumps(payload)
                except Exception as e:
                    logger.error(f"Failed to serialize payload: \n{payload}")
                    # gradio.Info("Distributed: failed to serialize payload")
                    raise e

                # the main api requests sent to either the txt2img or img2img route
                response_queue = queue.Queue()

                def preemptible_request(response_queue):
                    # TODO shouldn't be this way
                    sampler_index = payload.get('sampler_index', None)
                    sampler_name = payload.get('sampler_name', None)
                    if sampler_index is None:
                        if sampler_name is not None:
                            logger.debug("had to substitute sampler index with name")
                            payload['sampler_index'] = sampler_name

                    try:
                        response = self.session.post(
                            self.full_url("txt2img") if init_images is None else self.full_url("img2img"),
                            json=payload
                        )
                        response_queue.put(response)
                    except Exception as e:
                        response_queue.put(e)  # forwarding thrown exceptions to parent thread

                request_thread = Thread(target=preemptible_request, args=(response_queue,))
                interrupting = False
                start = time.time()
                request_thread.start()
                # poll so a local interrupt can be forwarded to the worker mid-request
                while request_thread.is_alive():
                    if interrupting is False and master_state.interrupted is True:
                        self.interrupt()
                        interrupting = True
                    time.sleep(0.5)

                result = response_queue.get()
                if isinstance(result, Exception):
                    raise result
                response = result

                self.response = response.json()
                if response.status_code != 200:
                    # try again when remote doesn't support the selected sampler by falling back to Euler a
                    if response.status_code == 404 and self.response['detail'] == "Sampler not found":
                        logger.warning(f"falling back to Euler A sampler for worker {self.label}\n"
                                       f"this may mean you should update this worker")
                        payload['sampler_index'] = 'Euler a'
                        payload['sampler_name'] = 'Euler a'
                        second_attempt = Thread(target=self.request, args=(payload, option_payload, sync_options,))
                        second_attempt.start()
                        second_attempt.join()
                        return

                    logger.error(
                        f"'{self.label}' response: Code <{response.status_code}> "
                        f"{str(response.content, 'utf-8')}")
                    self.response = None
                    raise InvalidWorkerResponse()

                # update list of ETA accuracy if state is valid
                if self.benchmarked and not self.state == State.INTERRUPTED:
                    self.response_time = time.time() - start
                    variance = ((eta - self.response_time) / self.response_time) * 100

                    logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%\n"
                                 f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n")

                    # if the variance is greater than 500% then we ignore it to prevent variation inflation
                    if abs(variance) < 500:
                        # keep a rolling window of the 5 most recent samples so the estimate
                        # adjusts when the user changes tasks
                        # (fixed: previously the new sample was dropped whenever the oldest was evicted)
                        if len(self.eta_percent_error) > 4:
                            self.eta_percent_error.pop(0)
                        self.eta_percent_error.append(variance)
                    else:
                        logger.warning(f"Variance of {variance:.2f}% exceeds threshold of 500%. Ignoring...\n")

            except Exception as e:
                self.state = State.IDLE
                if payload['batch_size'] == 0:
                    raise InvalidWorkerResponse("Tried to request a null amount of images")
                else:
                    raise InvalidWorkerResponse(e)

        # NOTE(review): this only catches RequestExceptions raised before the inner try
        # (memory query / load_options); failures of the main POST are wrapped in
        # InvalidWorkerResponse above — confirm that is intended
        except requests.RequestException:
            self.mark_unreachable()
            return

        self.state = State.IDLE
        self.jobs_requested += 1
        return

    def benchmark(self) -> float:
        """
        Given a worker, run a small benchmark and return its performance in images/minute.
        Makes standard request(s) of 512x512 images and averages them to get the result.
        """
        t: Thread
        samples = 2  # number of times to benchmark the remote / accuracy

        if self.state in (State.DISABLED, State.UNAVAILABLE):
            logger.debug(f"worker '{self.label}' is unavailable or disabled, refusing to benchmark")
            return 0

        if self.master is True:
            return -1

        def ipm(seconds: float) -> float:
            """
            Determines the rate of images per minute.

            Args:
                seconds (float): How many seconds it took to generate benchmark_payload['batch_size'] amount of images.

            Returns:
                float: Images per minute
            """
            return sh.benchmark_payload.batch_size / (seconds / 60)

        results: List[float] = []
        # it used to be lower for the first couple of generations
        # this was due to something torch does at startup according to auto and is now done at sdwui startup
        for i in range(0, samples + warmup_samples):  # run some extra times so that the remote can "warm up"
            if self.state == State.UNAVAILABLE:
                self.response = None
                return 0

            t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,),
                       name=f"{self.label}_benchmark_request")
            try:  # if the worker is unreachable/offline then handle that here
                t.start()
                start = time.time()
                t.join()
                elapsed = time.time() - start
                sample_ipm = ipm(elapsed)
            except InvalidWorkerResponse as e:
                raise e

            if i >= warmup_samples:
                logger.info(f"Sample {i - warmup_samples + 1}: Worker '{self.label}'({self}) "
                            f"- {sample_ipm:.2f} image(s) per minute\n")
                results.append(sample_ipm)
            elif i == warmup_samples - 1:
                logger.debug(f"{self.label} finished warming up\n")

        # average the sample results for accuracy
        avg_ipm_result = sum(results) / samples
        logger.debug(f"Worker '{self.label}' average ipm: {avg_ipm_result:.2f}")
        self.avg_ipm = avg_ipm_result
        self.response = None
        self.benchmarked = True
        self.state = State.IDLE

        return avg_ipm_result

    def refresh_checkpoints(self):
        """Ask the worker to rescan its model and LORA directories; logs on failure."""
        # gradio.Info("refreshing checkpoints")
        try:
            model_response = self.session.post(self.full_url('refresh-checkpoints'))
            lora_response = self.session.post(self.full_url('refresh-loras'))

            if model_response.status_code != 200:
                msg = f"Failed to refresh models for worker '{self.label}'\nCode <{model_response.status_code}>"
                logger.error(msg)
                # gradio.Warning("Distributed: "+msg)
            if lora_response.status_code != 200:
                msg = f"Failed to refresh LORA's for worker '{self.label}'\nCode <{lora_response.status_code}>"
                logger.error(msg)
                # gradio.Warning("Distributed: "+msg)
        except requests.exceptions.ConnectionError:
            self.mark_unreachable()

    def interrupt(self):
        """Ask the worker to interrupt its current job; marks it INTERRUPTED on success."""
        try:
            response = self.session.post(self.full_url('interrupt'))
            if response.status_code == 200:
                self.state = State.INTERRUPTED
                logger.debug(f"successfully interrupted worker {self.label}")
        except requests.exceptions.ConnectionError:
            self.mark_unreachable()

    def reachable(self) -> bool:
        """returns false if worker is unreachable"""
        try:
            response = self.session.get(
                self.full_url("memory"),
                timeout=3
            )
            return response.status_code == 200
        # fixed: a 3s timeout raises requests.exceptions.Timeout, which the previous
        # ConnectionError-only handler let escape
        except requests.RequestException as e:
            logger.error(e)
            return False

    def mark_unreachable(self):
        """Flag this worker UNAVAILABLE so it is avoided until it reconnects."""
        if self.state == State.DISABLED:
            logger.debug(f"worker '{self.label}' is disabled... refusing to mark as unavailable")
        else:
            msg = f"worker '{self.label}' at {self} was unreachable and will be avoided until reconnection"
            logger.error(msg)
            # gradio.Warning("Distributed: "+msg)
            self.state = State.UNAVAILABLE
            # invalidate models cache so that if/when worker reconnects, a new POST is sent to resync loaded models
            self.loaded_model = None
            self.loaded_vae = None

    def available_models(self) -> List[str]:
        """Return the model titles advertised by the worker, or [] on any failure."""
        if self.state == State.UNAVAILABLE or self.state == State.DISABLED or self.master:
            return []

        url = self.full_url('sd-models')
        try:
            response = self.session.get(
                url=url,
                timeout=5
            )

            if response.status_code != 200:
                logger.error(f"request to {url} returned {response.status_code}")
                if response.status_code == 404:
                    logger.error(f"did you enable --api for '{self.label}'?")
                return []

            titles = [model['title'] for model in response.json()]
            return titles
        except requests.RequestException:
            self.mark_unreachable()
            return []

    def load_options(self, model, vae=None):
        """POST the given checkpoint (and optional vae) to the worker's options route."""
        if self.master:
            return

        if self.model_override is not None:
            model = self.model_override

        # strip the trailing "[hash]" suffix from the checkpoint title if present
        model_name = re.sub(r'\s?\[[^]]*]$', '', model)
        payload = {
            "sd_model_checkpoint": model_name
        }
        if vae is not None:
            payload['sd_vae'] = vae

        self.state = State.WORKING
        start = time.time()
        response = self.session.post(
            self.full_url("options"),
            json=payload
        )
        elapsed = time.time() - start
        self.state = State.IDLE

        if response.status_code != 200:
            logger.debug(f"failed to load options for worker '{self.label}'")
        else:
            logger.debug(f"worker '{self.label}' loaded weights in {elapsed:.2f}s")
            self.loaded_model = model_name
            if vae is not None:
                self.loaded_vae = vae

        return response

    def restart(self) -> bool:
        """Request a server restart from the worker. Returns True on (assumed) success."""
        err_msg = f"could not restart worker '{self.label}'"
        success_msg = f"worker '{self.label}' is restarting"

        if self.master:  # shouldn't really need to restart master (unless for convenience at some point)
            return True

        response = None
        try:
            response = self.session.post(self.full_url("server-restart"), timeout=3)
        except requests.ConnectionError:  # the successful case (kinda)
            # have to assume that the worker is actually restarting because currently sdwui does not gracefully close
            # the connection
            logger.info(success_msg)
            return True
        except requests.RequestException as e:
            logger.error(f"{err_msg}:\n{e}")
            return False

        if response.status_code == 200:
            logger.info(success_msg)
            return True
        elif response.status_code == 404:
            logger.error(f"try adding --api-server-stop to '{self.label}'s launch arguments (couldn't restart)\n"
                         "*requires webui version 1.5(5be6c02) or later")
            return False

        logger.error(f"{err_msg}: {response}")
        return False