rebase patch for dev branch
parent
700be797e0
commit
3348fb0046
|
|
@ -0,0 +1,4 @@
|
|||
|
||||
*.pyc
|
||||
workers.json
|
||||
config.json
|
||||
13
preload.py
13
preload.py
|
|
@ -1,3 +1,8 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from inspect import getsourcefile
|
||||
from os.path import abspath
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument(
|
||||
"--distributed-remotes",
|
||||
|
|
@ -23,3 +28,11 @@ def preload(parser):
|
|||
help="Enable debug information",
|
||||
action="store_true"
|
||||
)
|
||||
extension_path = Path(abspath(getsourcefile(lambda: 0))).parent.parent
|
||||
config_path = extension_path.joinpath('config.json')
|
||||
# add config file
|
||||
parser.add_argument(
|
||||
"--distributed-config",
|
||||
help="config file to load / save, default: $EXTENSION_PATH/config.json",
|
||||
default=config_path
|
||||
)
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ class Script(scripts.Script):
|
|||
world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
|
||||
|
||||
world.load_config()
|
||||
assert world.has_any_workers, "No workers are available. (Try using `--distributed-remotes`?)"
|
||||
|
||||
def title(self):
|
||||
return "Distribute"
|
||||
|
|
@ -233,6 +234,7 @@ class Script(scripts.Script):
|
|||
|
||||
# strip scripts that aren't yet supported and warn user
|
||||
packed_script_args: List[dict] = [] # list of api formatted per-script argument objects
|
||||
# { "script_name": { "args": ["value1", "value2", ...] }
|
||||
for script in p.scripts.scripts:
|
||||
if script.alwayson is not True:
|
||||
continue
|
||||
|
|
@ -253,6 +255,12 @@ class Script(scripts.Script):
|
|||
|
||||
continue
|
||||
else:
|
||||
# other scripts to pack
|
||||
args_script_pack = {}
|
||||
args_script_pack[title] = {"args": []}
|
||||
for arg in p.script_args[script.args_from:script.args_to]:
|
||||
args_script_pack[title]["args"].append(arg)
|
||||
packed_script_args.append(args_script_pack)
|
||||
# https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514
|
||||
if Script.runs_since_init < 1:
|
||||
logger.warning(f"Distributed doesn't yet support '{title}'")
|
||||
|
|
|
|||
|
|
@ -230,7 +230,7 @@ class UI:
|
|||
thin_client_cbx = gradio.Checkbox(
|
||||
label='Thin-client mode (experimental)',
|
||||
info="Only generate images using remote workers. There will be no previews when enabled.",
|
||||
value=self.world.thin_client_mode
|
||||
value=False
|
||||
)
|
||||
job_timeout = gradio.Number(
|
||||
label='Job timeout', value=self.world.job_timeout,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import io
|
|||
|
||||
import gradio
|
||||
import requests
|
||||
from typing import List, Union
|
||||
from typing import List, Tuple, Union
|
||||
import math
|
||||
import copy
|
||||
import time
|
||||
|
|
@ -47,29 +47,33 @@ class Worker:
|
|||
queried (bool): Whether this worker's memory status has been polled yet. Defaults to False.
|
||||
verify_remotes (bool): Whether to verify the validity of remote worker certificates. Defaults to False.
|
||||
master (bool): Whether this worker is the master node. Defaults to False.
|
||||
auth (str|None): The username and password used to authenticate with the worker. Defaults to None. (username:password)
|
||||
benchmarked (bool): Whether this worker has been benchmarked. Defaults to False.
|
||||
# TODO should be the last MPE from the last session
|
||||
eta_percent_error (List[float]): A runtime list of ETA percent errors for this worker. Empty by default
|
||||
last_mpe (float): The last mean percent error for this worker. Defaults to None.
|
||||
response (requests.Response): The last response from this worker. Defaults to None.
|
||||
|
||||
Raises:
|
||||
InvalidWorkerResponse: If the worker responds with an invalid or unexpected response.
|
||||
"""
|
||||
|
||||
address: str = None
|
||||
port: int = None
|
||||
avg_ipm: float = None
|
||||
uuid: str = None
|
||||
address: Union[str, None] = None
|
||||
port: int = 80
|
||||
avg_ipm: Union[float, None] = None
|
||||
uuid: Union[str, None] = None
|
||||
queried: bool = False # whether this worker has been connected to yet
|
||||
free_vram: Union[bytes, int] = 0
|
||||
verify_remotes: bool = False
|
||||
master: bool = False
|
||||
benchmarked: bool = False
|
||||
eta_percent_error: List[float] = []
|
||||
last_mpe: float = None
|
||||
response: requests.Response = None
|
||||
loaded_model: str = None
|
||||
loaded_vae: str = None
|
||||
state: State = None
|
||||
last_mpe: Union[float,None] = None
|
||||
response: Union[requests.Response, None] = None
|
||||
loaded_model: Union[str, None] = None
|
||||
loaded_vae: Union[str, None] = None
|
||||
state: Union[State, None] = None
|
||||
tls: bool = False
|
||||
|
||||
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
|
||||
# We compare to euler a because that is what we currently benchmark each node with.
|
||||
other_to_euler_a = {
|
||||
|
|
@ -93,8 +97,18 @@ class Worker:
|
|||
"PLMS": 9.31
|
||||
}
|
||||
|
||||
def __init__(self, address: str = None, port: int = None, uuid: str = None, verify_remotes: bool = None,
|
||||
master: bool = False, tls: bool = False):
|
||||
def __init__(self, address: Union[str, None] = None, port: int = 80, uuid: Union[str, None] = None, verify_remotes: bool = True,
|
||||
master: bool = False, tls: bool = False, auth: Union[str, None, Tuple] = None):
|
||||
"""
|
||||
Creates a new worker object.
|
||||
|
||||
param address: The address of the worker node. Can be an ip or a FQDN. Defaults to None. do NOT include sdapi/v1 in the address.
|
||||
param port: The port number used by the worker node. Defaults to 80. (http) or 443 (https)
|
||||
param uuid: The unique identifier/name of the worker node. Defaults to None.
|
||||
param verify_remotes: Whether to verify the validity of remote worker certificates. Defaults to True.
|
||||
param master: Whether this worker is the master node. Defaults to False.
|
||||
param auth: The username and password used to authenticate with the worker. Defaults to None. (username:password)
|
||||
"""
|
||||
if master is True:
|
||||
self.master = master
|
||||
self.uuid = 'master'
|
||||
|
|
@ -106,7 +120,19 @@ class Worker:
|
|||
else:
|
||||
self.port = cmd_opts.port
|
||||
return
|
||||
|
||||
# strip http:// or https:// from address if present
|
||||
self.tls = tls
|
||||
if address is not None:
|
||||
if address.startswith("http://"):
|
||||
address = address[7:]
|
||||
elif address.startswith("https://"):
|
||||
address = address[8:]
|
||||
self.tls = True
|
||||
self.port = 443
|
||||
# remove '/' from end of address if present
|
||||
if address is not None:
|
||||
if address.endswith('/'):
|
||||
address = address[:-1]
|
||||
self.address = address
|
||||
self.port = port
|
||||
self.verify_remotes = verify_remotes
|
||||
|
|
@ -114,13 +140,34 @@ class Worker:
|
|||
self.loaded_model = ''
|
||||
self.loaded_vae = ''
|
||||
self.state = State.IDLE
|
||||
self.tls = tls
|
||||
self.model_override: str = None
|
||||
|
||||
if auth is not None:
|
||||
if isinstance(auth, str):
|
||||
self.user = auth.split(':')[0]
|
||||
self.password = auth.split(':')[1]
|
||||
elif isinstance(auth, tuple):
|
||||
self.user = auth[0]
|
||||
self.password = auth[1]
|
||||
else:
|
||||
raise ValueError(f"Invalid auth value: {auth}")
|
||||
self.auth: Union[Tuple[str, str] , None] = (self.user, self.password) if self.user is not None else None
|
||||
if uuid is not None:
|
||||
self.uuid = uuid
|
||||
self.session = requests.Session()
|
||||
self.session.auth = self.auth
|
||||
logger.debug(f"worker '{self.uuid}' created with address '{self.full_url('')}'")
|
||||
if self.verify_remotes:
|
||||
# check user/ GET response
|
||||
response = self.session.get(
|
||||
self.full_url("memory"),
|
||||
verify=self.verify_remotes
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise InvalidWorkerResponse(f"Worker '{self.uuid}' responded with status code {response.status_code}")
|
||||
|
||||
def __str__(self):
|
||||
if self.port is None or self.port == 80:
|
||||
return f"{self.address}"
|
||||
return f"{self.address}:{self.port}"
|
||||
|
||||
def info(self) -> dict:
|
||||
|
|
@ -163,7 +210,6 @@ class Worker:
|
|||
Returns:
|
||||
str: The full url.
|
||||
"""
|
||||
|
||||
protocol = 'http' if not self.tls else 'https'
|
||||
return f"{protocol}://{self.__str__()}/sdapi/v1/{route}"
|
||||
|
||||
|
|
@ -255,7 +301,7 @@ class Worker:
|
|||
option_payload (dict): The options payload.
|
||||
sync_options (bool): Whether to attempt to synchronize the worker's loaded models with the locals'
|
||||
"""
|
||||
eta = None
|
||||
eta = 0
|
||||
|
||||
# TODO detect remote out of memory exception and restart or garbage collect instance using api?
|
||||
try:
|
||||
|
|
@ -264,10 +310,11 @@ class Worker:
|
|||
# query memory available on worker and store for future reference
|
||||
if self.queried is False:
|
||||
self.queried = True
|
||||
memory_response = requests.get(
|
||||
memory_response = self.session.get(
|
||||
self.full_url("memory"),
|
||||
verify=self.verify_remotes
|
||||
)
|
||||
#curl -X GET "http://localhost:7860/memory" -H "accept: application/json"
|
||||
memory_response = memory_response.json()
|
||||
try:
|
||||
memory_response = memory_response['cuda']['system'] # all in bytes
|
||||
|
|
@ -335,7 +382,7 @@ class Worker:
|
|||
response_queue = queue.Queue()
|
||||
def preemptable_request(response_queue):
|
||||
try:
|
||||
response = requests.post(
|
||||
response = self.session.post(
|
||||
self.full_url("txt2img") if init_images is None else self.full_url("img2img"),
|
||||
json=payload,
|
||||
verify=self.verify_remotes
|
||||
|
|
@ -401,7 +448,7 @@ class Worker:
|
|||
self.state = State.IDLE
|
||||
return
|
||||
|
||||
def benchmark(self) -> int:
|
||||
def benchmark(self) -> float:
|
||||
"""
|
||||
given a worker, run a small benchmark and return its performance in images/minute
|
||||
makes standard request(s) of 512x512 images and averages them to get the result
|
||||
|
|
@ -454,26 +501,26 @@ class Worker:
|
|||
|
||||
# average the sample results for accuracy
|
||||
ipm_sum = 0
|
||||
for ipm in results:
|
||||
ipm_sum += ipm
|
||||
avg_ipm = ipm_sum / samples
|
||||
for ipm_result in results:
|
||||
ipm_sum += ipm_result
|
||||
avg_ipm_result = ipm_sum / samples
|
||||
|
||||
logger.debug(f"Worker '{self.uuid}' average ipm: {avg_ipm}")
|
||||
self.avg_ipm = avg_ipm
|
||||
logger.debug(f"Worker '{self.uuid}' average ipm: {avg_ipm_result}")
|
||||
self.avg_ipm = avg_ipm_result
|
||||
# noinspection PyTypeChecker
|
||||
self.response = None
|
||||
self.benchmarked = True
|
||||
self.state = State.IDLE
|
||||
return avg_ipm
|
||||
return avg_ipm_result
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
try:
|
||||
model_response = requests.post(
|
||||
model_response = self.session.post(
|
||||
self.full_url('refresh-checkpoints'),
|
||||
json={},
|
||||
verify=self.verify_remotes
|
||||
)
|
||||
lora_response = requests.post(
|
||||
lora_response = self.session.post(
|
||||
self.full_url('refresh-loras'),
|
||||
json={},
|
||||
verify=self.verify_remotes
|
||||
|
|
@ -489,7 +536,7 @@ class Worker:
|
|||
|
||||
def interrupt(self):
|
||||
try:
|
||||
response = requests.post(
|
||||
response = self.session.post(
|
||||
self.full_url('interrupt'),
|
||||
json={},
|
||||
verify=self.verify_remotes
|
||||
|
|
@ -504,7 +551,7 @@ class Worker:
|
|||
def reachable(self) -> bool:
|
||||
"""returns false if worker is unreachable"""
|
||||
try:
|
||||
response = requests.get(
|
||||
response = self.session.get(
|
||||
self.full_url("memory"),
|
||||
verify=self.verify_remotes,
|
||||
timeout=3
|
||||
|
|
|
|||
|
|
@ -9,15 +9,15 @@ import copy
|
|||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
from threading import Thread
|
||||
from inspect import getsourcefile
|
||||
from os.path import abspath
|
||||
from pathlib import Path
|
||||
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
|
||||
import modules.shared as shared
|
||||
from scripts.spartan.Worker import Worker, State
|
||||
from scripts.spartan.shared import logger, warmup_samples
|
||||
from scripts.spartan.Worker import InvalidWorkerResponse, Worker, State
|
||||
from scripts.spartan.shared import logger, warmup_samples, benchmark_payload
|
||||
import scripts.spartan.shared as sh
|
||||
|
||||
|
||||
|
|
@ -70,7 +70,7 @@ class World:
|
|||
|
||||
# I'd rather keep the sdwui root directory clean.
|
||||
extension_path = Path(abspath(getsourcefile(lambda: 0))).parent.parent.parent
|
||||
config_path = extension_path.joinpath('config.json')
|
||||
config_path = shared.cmd_opts.distributed_config
|
||||
|
||||
def __init__(self, initial_payload, verify_remotes: bool = True):
|
||||
self.master_worker = Worker(master=True)
|
||||
|
|
@ -82,6 +82,7 @@ class World:
|
|||
self.verify_remotes = verify_remotes
|
||||
self.initial_payload = copy.copy(initial_payload)
|
||||
self.thin_client_mode = False
|
||||
self.has_any_workers = False # whether any workers have been added to the world
|
||||
|
||||
def __getitem__(self, label: str) -> Worker:
|
||||
for worker in self._workers:
|
||||
|
|
@ -141,9 +142,11 @@ class World:
|
|||
for job in self.jobs:
|
||||
if job.worker.master:
|
||||
return job
|
||||
|
||||
raise Exception("Master job not found")
|
||||
|
||||
# TODO better way of merging/updating workers
|
||||
def add_worker(self, uuid: str, address: str, port: int, tls: bool = False):
|
||||
def add_worker(self, uuid: str, address: str, port: int, auth: Union[str,None] = None, tls: bool = False):
|
||||
"""
|
||||
Registers a worker with the world.
|
||||
|
||||
|
|
@ -151,10 +154,15 @@ class World:
|
|||
uuid (str): The name or unique identifier.
|
||||
address (str): The ip or FQDN.
|
||||
port (int): The port number.
|
||||
|
||||
Returns:
|
||||
Worker: The worker object.
|
||||
|
||||
Raises:
|
||||
InvalidWorkerResponse: If the worker is not valid.
|
||||
"""
|
||||
|
||||
original = None
|
||||
new = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes, tls=tls)
|
||||
new = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes, tls=tls, auth=auth)
|
||||
|
||||
for w in self._workers:
|
||||
if w.uuid == uuid:
|
||||
|
|
@ -162,6 +170,7 @@ class World:
|
|||
|
||||
if original is None:
|
||||
self._workers.append(new)
|
||||
self.has_any_workers = True
|
||||
return new
|
||||
else:
|
||||
original.address = address
|
||||
|
|
@ -169,7 +178,6 @@ class World:
|
|||
original.tls = tls
|
||||
|
||||
return original
|
||||
|
||||
def interrupt_remotes(self):
|
||||
|
||||
for worker in self.get_workers():
|
||||
|
|
@ -368,8 +376,8 @@ class World:
|
|||
self.jobs.append(Job(worker=worker, batch_size=batch_size))
|
||||
|
||||
def get_workers(self):
|
||||
filtered = []
|
||||
for worker in self._workers:
|
||||
filtered:List[Worker] = []
|
||||
for worker in self.__workers:
|
||||
if worker.avg_ipm is not None and worker.avg_ipm <= 0:
|
||||
logger.warning(f"config reports invalid speed (0 ipm) for worker '{worker.uuid}', setting default of 1 ipm.\nplease re-benchmark")
|
||||
worker.avg_ipm = 1
|
||||
|
|
@ -504,18 +512,21 @@ class World:
|
|||
worker = self.add_worker(
|
||||
uuid=label,
|
||||
address=w['address'],
|
||||
port=w['port'],
|
||||
tls=w['tls']
|
||||
port=w.get('port', 80),
|
||||
tls=w.get('tls', False),
|
||||
auth =w.get('auth', None)
|
||||
)
|
||||
worker.address = w['address']
|
||||
worker.port = w['port']
|
||||
worker.last_mpe = w['last_mpe']
|
||||
worker.avg_ipm = w['avg_ipm']
|
||||
worker.master = w['master']
|
||||
worker.port = w.get('port', 80)
|
||||
worker.last_mpe = w.get('last_mpe', None)
|
||||
worker.avg_ipm = w.get('avg_ipm', None)
|
||||
worker.master = w.get('master', False)
|
||||
except KeyError as e:
|
||||
raise e
|
||||
logger.error(f"invalid configuration in file for worker {w}... ignoring")
|
||||
continue
|
||||
except InvalidWorkerResponse as e:
|
||||
logger.error(f"worker {w} is invalid... ignoring")
|
||||
continue
|
||||
logger.debug("loaded config")
|
||||
|
||||
def save_config(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue