diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6a387c3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ + +*.pyc +workers.json +config.json \ No newline at end of file diff --git a/preload.py b/preload.py index 89be003..4348caa 100644 --- a/preload.py +++ b/preload.py @@ -1,3 +1,8 @@ +import os +from pathlib import Path +from inspect import getsourcefile +from os.path import abspath + def preload(parser): parser.add_argument( "--distributed-remotes", @@ -23,3 +28,11 @@ def preload(parser): help="Enable debug information", action="store_true" ) + extension_path = Path(abspath(getsourcefile(lambda: 0))).parent.parent + config_path = extension_path.joinpath('config.json') + # add config file + parser.add_argument( + "--distributed-config", + help="config file to load / save, default: $EXTENSION_PATH/config.json", + default=config_path + ) diff --git a/scripts/extension.py b/scripts/extension.py index a870231..8f32dca 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -54,6 +54,7 @@ class Script(scripts.Script): world.add_worker(uuid=worker[0], address=worker[1], port=worker[2]) world.load_config() + assert world.has_any_workers, "No workers are available. (Try using `--distributed-remotes`?)" def title(self): return "Distribute" @@ -233,6 +234,7 @@ class Script(scripts.Script): # strip scripts that aren't yet supported and warn user packed_script_args: List[dict] = [] # list of api formatted per-script argument objects + # { "script_name": { "args": ["value1", "value2", ...] } for script in p.scripts.scripts: if script.alwayson is not True: continue @@ -253,6 +255,12 @@ class Script(scripts.Script): continue else: + # other scripts to pack + args_script_pack = {} + args_script_pack[title] = {"args": []} + for arg in p.script_args[script.args_from:script.args_to]: + args_script_pack[title]["args"].append(arg) + packed_script_args.append(args_script_pack) # https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514 if Script.runs_since_init < 1: logger.warning(f"Distributed doesn't yet support '{title}'") diff --git a/scripts/spartan/UI.py b/scripts/spartan/UI.py index cfc8b25..e55defb 100644 --- a/scripts/spartan/UI.py +++ b/scripts/spartan/UI.py @@ -230,7 +230,7 @@ class UI: thin_client_cbx = gradio.Checkbox( label='Thin-client mode (experimental)', info="Only generate images using remote workers. There will be no previews when enabled.", - value=self.world.thin_client_mode + value=False ) job_timeout = gradio.Number( label='Job timeout', value=self.world.job_timeout, diff --git a/scripts/spartan/Worker.py b/scripts/spartan/Worker.py index 7d83402..9d8c89c 100644 --- a/scripts/spartan/Worker.py +++ b/scripts/spartan/Worker.py @@ -2,7 +2,7 @@ import io import gradio import requests -from typing import List, Union +from typing import List, Tuple, Union import math import copy import time @@ -47,29 +47,33 @@ class Worker: queried (bool): Whether this worker's memory status has been polled yet. Defaults to False. verify_remotes (bool): Whether to verify the validity of remote worker certificates. Defaults to False. master (bool): Whether this worker is the master node. Defaults to False. + auth (str|None): The username and password used to authenticate with the worker. Defaults to None. (username:password) benchmarked (bool): Whether this worker has been benchmarked. Defaults to False. # TODO should be the last MPE from the last session eta_percent_error (List[float]): A runtime list of ETA percent errors for this worker. Empty by default last_mpe (float): The last mean percent error for this worker. Defaults to None. response (requests.Response): The last response from this worker. Defaults to None. + + Raises: + InvalidWorkerResponse: If the worker responds with an invalid or unexpected response. """ - address: str = None - port: int = None - avg_ipm: float = None - uuid: str = None + address: Union[str, None] = None + port: int = 80 + avg_ipm: Union[float, None] = None + uuid: Union[str, None] = None queried: bool = False # whether this worker has been connected to yet + free_vram: Union[bytes, int] = 0 verify_remotes: bool = False master: bool = False benchmarked: bool = False eta_percent_error: List[float] = [] - last_mpe: float = None - response: requests.Response = None - loaded_model: str = None - loaded_vae: str = None - state: State = None + last_mpe: Union[float,None] = None + response: Union[requests.Response, None] = None + loaded_model: Union[str, None] = None + loaded_vae: Union[str, None] = None + state: Union[State, None] = None tls: bool = False - # Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A. # We compare to euler a because that is what we currently benchmark each node with. other_to_euler_a = { @@ -93,8 +97,18 @@ class Worker: "PLMS": 9.31 } - def __init__(self, address: str = None, port: int = None, uuid: str = None, verify_remotes: bool = None, - master: bool = False, tls: bool = False): + def __init__(self, address: Union[str, None] = None, port: int = 80, uuid: Union[str, None] = None, verify_remotes: bool = True, + master: bool = False, tls: bool = False, auth: Union[str, None, Tuple] = None): + """ + Creates a new worker object. + + param address: The address of the worker node. Can be an ip or a FQDN. Defaults to None. do NOT include sdapi/v1 in the address. + param port: The port number used by the worker node. Defaults to 80. (http) or 443 (https) + param uuid: The unique identifier/name of the worker node. Defaults to None. + param verify_remotes: Whether to verify the validity of remote worker certificates. Defaults to True. + param master: Whether this worker is the master node. Defaults to False. + param auth: The username and password used to authenticate with the worker. Defaults to None. (username:password) + """ if master is True: self.master = master self.uuid = 'master' @@ -106,7 +120,19 @@ class Worker: else: self.port = cmd_opts.port return - + # strip http:// or https:// from address if present + self.tls = tls + if address is not None: + if address.startswith("http://"): + address = address[7:] + elif address.startswith("https://"): + address = address[8:] + self.tls = True + self.port = 443 + # remove '/' from end of address if present + if address is not None: + if address.endswith('/'): + address = address[:-1] self.address = address self.port = port self.verify_remotes = verify_remotes @@ -114,13 +140,34 @@ class Worker: self.loaded_model = '' self.loaded_vae = '' self.state = State.IDLE - self.tls = tls self.model_override: str = None - + if auth is not None: + if isinstance(auth, str): + self.user = auth.split(':')[0] + self.password = auth.split(':')[1] + elif isinstance(auth, tuple): + self.user = auth[0] + self.password = auth[1] + else: + raise ValueError(f"Invalid auth value: {auth}") + self.auth: Union[Tuple[str, str] , None] = (self.user, self.password) if self.user is not None else None if uuid is not None: self.uuid = uuid + self.session = requests.Session() + self.session.auth = self.auth + logger.debug(f"worker '{self.uuid}' created with address '{self.full_url('')}'") + if self.verify_remotes: + # check user/ GET response + response = self.session.get( + self.full_url("memory"), + verify=self.verify_remotes + ) + if response.status_code != 200: + raise InvalidWorkerResponse(f"Worker '{self.uuid}' responded with status code {response.status_code}") def __str__(self): + if self.port is None or self.port == 80: + return f"{self.address}" return f"{self.address}:{self.port}" def info(self) -> dict: @@ -163,7 +210,6 @@ class Worker: Returns: str: The full url. """ - protocol = 'http' if not self.tls else 'https' return f"{protocol}://{self.__str__()}/sdapi/v1/{route}" @@ -255,7 +301,7 @@ class Worker: option_payload (dict): The options payload. sync_options (bool): Whether to attempt to synchronize the worker's loaded models with the locals' """ - eta = None + eta = 0 # TODO detect remote out of memory exception and restart or garbage collect instance using api? try: @@ -264,10 +310,11 @@ class Worker: # query memory available on worker and store for future reference if self.queried is False: self.queried = True - memory_response = requests.get( + memory_response = self.session.get( self.full_url("memory"), verify=self.verify_remotes ) + #curl -X GET "http://localhost:7860/memory" -H "accept: application/json" memory_response = memory_response.json() try: memory_response = memory_response['cuda']['system'] # all in bytes @@ -335,7 +382,7 @@ class Worker: response_queue = queue.Queue() def preemptable_request(response_queue): try: - response = requests.post( + response = self.session.post( self.full_url("txt2img") if init_images is None else self.full_url("img2img"), json=payload, verify=self.verify_remotes @@ -401,7 +448,7 @@ class Worker: self.state = State.IDLE return - def benchmark(self) -> int: + def benchmark(self) -> float: """ given a worker, run a small benchmark and return its performance in images/minute makes standard request(s) of 512x512 images and averages them to get the result @@ -454,26 +501,26 @@ class Worker: # average the sample results for accuracy ipm_sum = 0 - for ipm in results: - ipm_sum += ipm - avg_ipm = ipm_sum / samples + for ipm_result in results: + ipm_sum += ipm_result + avg_ipm_result = ipm_sum / samples - logger.debug(f"Worker '{self.uuid}' average ipm: {avg_ipm}") - self.avg_ipm = avg_ipm + logger.debug(f"Worker '{self.uuid}' average ipm: {avg_ipm_result}") + self.avg_ipm = avg_ipm_result # noinspection PyTypeChecker self.response = None self.benchmarked = True self.state = State.IDLE - return avg_ipm + return avg_ipm_result def refresh_checkpoints(self): try: - model_response = requests.post( + model_response = self.session.post( self.full_url('refresh-checkpoints'), json={}, verify=self.verify_remotes ) - lora_response = requests.post( + lora_response = self.session.post( self.full_url('refresh-loras'), json={}, verify=self.verify_remotes @@ -489,7 +536,7 @@ class Worker: def interrupt(self): try: - response = requests.post( + response = self.session.post( self.full_url('interrupt'), json={}, verify=self.verify_remotes @@ -504,7 +551,7 @@ class Worker: def reachable(self) -> bool: """returns false if worker is unreachable""" try: - response = requests.get( + response = self.session.get( self.full_url("memory"), verify=self.verify_remotes, timeout=3 diff --git a/scripts/spartan/World.py b/scripts/spartan/World.py index 71bf939..cc92234 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/World.py @@ -9,15 +9,15 @@ import copy import json import os import time -from typing import List +from typing import List, Union from threading import Thread from inspect import getsourcefile from os.path import abspath from pathlib import Path from modules.processing import process_images, StableDiffusionProcessingTxt2Img import modules.shared as shared -from scripts.spartan.Worker import Worker, State -from scripts.spartan.shared import logger, warmup_samples +from scripts.spartan.Worker import InvalidWorkerResponse, Worker, State +from scripts.spartan.shared import logger, warmup_samples, benchmark_payload import scripts.spartan.shared as sh @@ -70,7 +70,7 @@ class World: # I'd rather keep the sdwui root directory clean. extension_path = Path(abspath(getsourcefile(lambda: 0))).parent.parent.parent - config_path = extension_path.joinpath('config.json') + config_path = shared.cmd_opts.distributed_config def __init__(self, initial_payload, verify_remotes: bool = True): self.master_worker = Worker(master=True) @@ -82,6 +82,7 @@ class World: self.verify_remotes = verify_remotes self.initial_payload = copy.copy(initial_payload) self.thin_client_mode = False + self.has_any_workers = False # whether any workers have been added to the world def __getitem__(self, label: str) -> Worker: for worker in self._workers: @@ -141,9 +142,11 @@ class World: for job in self.jobs: if job.worker.master: return job + + raise Exception("Master job not found") # TODO better way of merging/updating workers - def add_worker(self, uuid: str, address: str, port: int, tls: bool = False): + def add_worker(self, uuid: str, address: str, port: int, auth: Union[str,None] = None, tls: bool = False): """ Registers a worker with the world. @@ -151,10 +154,15 @@ class World: uuid (str): The name or unique identifier. address (str): The ip or FQDN. port (int): The port number. + + Returns: + Worker: The worker object. + + Raises: + InvalidWorkerResponse: If the worker is not valid. """ - original = None - new = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes, tls=tls) + new = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes, tls=tls, auth=auth) for w in self._workers: if w.uuid == uuid: @@ -162,6 +170,7 @@ class World: if original is None: self._workers.append(new) + self.has_any_workers = True return new else: original.address = address @@ -169,7 +178,6 @@ class World: original.tls = tls return original - def interrupt_remotes(self): for worker in self.get_workers(): @@ -368,8 +376,8 @@ class World: self.jobs.append(Job(worker=worker, batch_size=batch_size)) def get_workers(self): - filtered = [] - for worker in self._workers: + filtered:List[Worker] = [] + for worker in self.__workers: if worker.avg_ipm is not None and worker.avg_ipm <= 0: logger.warning(f"config reports invalid speed (0 ipm) for worker '{worker.uuid}', setting default of 1 ipm.\nplease re-benchmark") worker.avg_ipm = 1 @@ -504,18 +512,21 @@ class World: worker = self.add_worker( uuid=label, address=w['address'], - port=w['port'], - tls=w['tls'] + port=w.get('port', 80), + tls=w.get('tls', False), + auth =w.get('auth', None) ) worker.address = w['address'] - worker.port = w['port'] - worker.last_mpe = w['last_mpe'] - worker.avg_ipm = w['avg_ipm'] - worker.master = w['master'] + worker.port = w.get('port', 80) + worker.last_mpe = w.get('last_mpe', None) + worker.avg_ipm = w.get('avg_ipm', None) + worker.master = w.get('master', False) except KeyError as e: - raise e logger.error(f"invalid configuration in file for worker {w}... ignoring") continue + except InvalidWorkerResponse as e: + logger.error(f"worker {w} is invalid... ignoring") + continue logger.debug("loaded config") def save_config(self):