rebase patch for dev branch

pull/17/head
aria1th 2023-07-27 05:14:53 +09:00
parent 700be797e0
commit 3348fb0046
6 changed files with 132 additions and 49 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
*.pyc
workers.json
config.json

View File

@ -1,3 +1,8 @@
import os
from pathlib import Path
from inspect import getsourcefile
from os.path import abspath
def preload(parser):
parser.add_argument(
"--distributed-remotes",
@ -23,3 +28,11 @@ def preload(parser):
help="Enable debug information",
action="store_true"
)
extension_path = Path(abspath(getsourcefile(lambda: 0))).parent.parent
config_path = extension_path.joinpath('config.json')
# add config file
parser.add_argument(
"--distributed-config",
help="config file to load / save, default: $EXTENSION_PATH/config.json",
default=config_path
)

View File

@ -54,6 +54,7 @@ class Script(scripts.Script):
world.add_worker(uuid=worker[0], address=worker[1], port=worker[2])
world.load_config()
assert world.has_any_workers, "No workers are available. (Try using `--distributed-remotes`?)"
def title(self):
    """Return the display name of this script as shown in the WebUI."""
    script_name = "Distribute"
    return script_name
@ -233,6 +234,7 @@ class Script(scripts.Script):
# strip scripts that aren't yet supported and warn user
packed_script_args: List[dict] = [] # list of api formatted per-script argument objects
# { "script_name": { "args": ["value1", "value2", ...] }
for script in p.scripts.scripts:
if script.alwayson is not True:
continue
@ -253,6 +255,12 @@ class Script(scripts.Script):
continue
else:
# other scripts to pack
args_script_pack = {}
args_script_pack[title] = {"args": []}
for arg in p.script_args[script.args_from:script.args_to]:
args_script_pack[title]["args"].append(arg)
packed_script_args.append(args_script_pack)
# https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514
if Script.runs_since_init < 1:
logger.warning(f"Distributed doesn't yet support '{title}'")

View File

@ -230,7 +230,7 @@ class UI:
thin_client_cbx = gradio.Checkbox(
label='Thin-client mode (experimental)',
info="Only generate images using remote workers. There will be no previews when enabled.",
value=self.world.thin_client_mode
value=False
)
job_timeout = gradio.Number(
label='Job timeout', value=self.world.job_timeout,

View File

@ -2,7 +2,7 @@ import io
import gradio
import requests
from typing import List, Union
from typing import List, Tuple, Union
import math
import copy
import time
@ -47,29 +47,33 @@ class Worker:
queried (bool): Whether this worker's memory status has been polled yet. Defaults to False.
verify_remotes (bool): Whether to verify the validity of remote worker certificates. Defaults to False.
master (bool): Whether this worker is the master node. Defaults to False.
auth (str|None): The username and password used to authenticate with the worker. Defaults to None. (username:password)
benchmarked (bool): Whether this worker has been benchmarked. Defaults to False.
# TODO should be the last MPE from the last session
eta_percent_error (List[float]): A runtime list of ETA percent errors for this worker. Empty by default
last_mpe (float): The last mean percent error for this worker. Defaults to None.
response (requests.Response): The last response from this worker. Defaults to None.
Raises:
InvalidWorkerResponse: If the worker responds with an invalid or unexpected response.
"""
address: str = None
port: int = None
avg_ipm: float = None
uuid: str = None
address: Union[str, None] = None
port: int = 80
avg_ipm: Union[float, None] = None
uuid: Union[str, None] = None
queried: bool = False # whether this worker has been connected to yet
free_vram: Union[bytes, int] = 0
verify_remotes: bool = False
master: bool = False
benchmarked: bool = False
eta_percent_error: List[float] = []
last_mpe: float = None
response: requests.Response = None
loaded_model: str = None
loaded_vae: str = None
state: State = None
last_mpe: Union[float,None] = None
response: Union[requests.Response, None] = None
loaded_model: Union[str, None] = None
loaded_vae: Union[str, None] = None
state: Union[State, None] = None
tls: bool = False
# Percentages representing (roughly) how much faster a given sampler is in comparison to Euler A.
# We compare to euler a because that is what we currently benchmark each node with.
other_to_euler_a = {
@ -93,8 +97,18 @@ class Worker:
"PLMS": 9.31
}
def __init__(self, address: str = None, port: int = None, uuid: str = None, verify_remotes: bool = None,
master: bool = False, tls: bool = False):
def __init__(self, address: Union[str, None] = None, port: int = 80, uuid: Union[str, None] = None, verify_remotes: bool = True,
master: bool = False, tls: bool = False, auth: Union[str, None, Tuple] = None):
"""
Creates a new worker object.
param address: The address of the worker node. Can be an ip or a FQDN. Defaults to None. do NOT include sdapi/v1 in the address.
param port: The port number used by the worker node. Defaults to 80. (http) or 443 (https)
param uuid: The unique identifier/name of the worker node. Defaults to None.
param verify_remotes: Whether to verify the validity of remote worker certificates. Defaults to True.
param master: Whether this worker is the master node. Defaults to False.
param auth: The username and password used to authenticate with the worker. Defaults to None. (username:password)
"""
if master is True:
self.master = master
self.uuid = 'master'
@ -106,7 +120,19 @@ class Worker:
else:
self.port = cmd_opts.port
return
# strip http:// or https:// from address if present
self.tls = tls
if address is not None:
if address.startswith("http://"):
address = address[7:]
elif address.startswith("https://"):
address = address[8:]
self.tls = True
self.port = 443
# remove '/' from end of address if present
if address is not None:
if address.endswith('/'):
address = address[:-1]
self.address = address
self.port = port
self.verify_remotes = verify_remotes
@ -114,13 +140,34 @@ class Worker:
self.loaded_model = ''
self.loaded_vae = ''
self.state = State.IDLE
self.tls = tls
self.model_override: str = None
if auth is not None:
if isinstance(auth, str):
self.user = auth.split(':')[0]
self.password = auth.split(':')[1]
elif isinstance(auth, tuple):
self.user = auth[0]
self.password = auth[1]
else:
raise ValueError(f"Invalid auth value: {auth}")
self.auth: Union[Tuple[str, str] , None] = (self.user, self.password) if self.user is not None else None
if uuid is not None:
self.uuid = uuid
self.session = requests.Session()
self.session.auth = self.auth
logger.debug(f"worker '{self.uuid}' created with address '{self.full_url('')}'")
if self.verify_remotes:
# check user/ GET response
response = self.session.get(
self.full_url("memory"),
verify=self.verify_remotes
)
if response.status_code != 200:
raise InvalidWorkerResponse(f"Worker '{self.uuid}' responded with status code {response.status_code}")
def __str__(self):
    """Render the worker as 'address' or 'address:port'.

    The port suffix is omitted when the port is unset or is the HTTP
    default (80), matching how URLs conventionally elide port 80.
    """
    # str(self.address) — a None address renders as the literal "None",
    # same as the original f-string behavior.
    host = f"{self.address}"
    if self.port in (None, 80):
        return host
    return f"{host}:{self.port}"
def info(self) -> dict:
@ -163,7 +210,6 @@ class Worker:
Returns:
str: The full url.
"""
protocol = 'http' if not self.tls else 'https'
return f"{protocol}://{self.__str__()}/sdapi/v1/{route}"
@ -255,7 +301,7 @@ class Worker:
option_payload (dict): The options payload.
sync_options (bool): Whether to attempt to synchronize the worker's loaded models with the locals'
"""
eta = None
eta = 0
# TODO detect remote out of memory exception and restart or garbage collect instance using api?
try:
@ -264,10 +310,11 @@ class Worker:
# query memory available on worker and store for future reference
if self.queried is False:
self.queried = True
memory_response = requests.get(
memory_response = self.session.get(
self.full_url("memory"),
verify=self.verify_remotes
)
#curl -X GET "http://localhost:7860/memory" -H "accept: application/json"
memory_response = memory_response.json()
try:
memory_response = memory_response['cuda']['system'] # all in bytes
@ -335,7 +382,7 @@ class Worker:
response_queue = queue.Queue()
def preemptable_request(response_queue):
try:
response = requests.post(
response = self.session.post(
self.full_url("txt2img") if init_images is None else self.full_url("img2img"),
json=payload,
verify=self.verify_remotes
@ -401,7 +448,7 @@ class Worker:
self.state = State.IDLE
return
def benchmark(self) -> int:
def benchmark(self) -> float:
"""
given a worker, run a small benchmark and return its performance in images/minute
makes standard request(s) of 512x512 images and averages them to get the result
@ -454,26 +501,26 @@ class Worker:
# average the sample results for accuracy
ipm_sum = 0
for ipm in results:
ipm_sum += ipm
avg_ipm = ipm_sum / samples
for ipm_result in results:
ipm_sum += ipm_result
avg_ipm_result = ipm_sum / samples
logger.debug(f"Worker '{self.uuid}' average ipm: {avg_ipm}")
self.avg_ipm = avg_ipm
logger.debug(f"Worker '{self.uuid}' average ipm: {avg_ipm_result}")
self.avg_ipm = avg_ipm_result
# noinspection PyTypeChecker
self.response = None
self.benchmarked = True
self.state = State.IDLE
return avg_ipm
return avg_ipm_result
def refresh_checkpoints(self):
try:
model_response = requests.post(
model_response = self.session.post(
self.full_url('refresh-checkpoints'),
json={},
verify=self.verify_remotes
)
lora_response = requests.post(
lora_response = self.session.post(
self.full_url('refresh-loras'),
json={},
verify=self.verify_remotes
@ -489,7 +536,7 @@ class Worker:
def interrupt(self):
try:
response = requests.post(
response = self.session.post(
self.full_url('interrupt'),
json={},
verify=self.verify_remotes
@ -504,7 +551,7 @@ class Worker:
def reachable(self) -> bool:
"""returns false if worker is unreachable"""
try:
response = requests.get(
response = self.session.get(
self.full_url("memory"),
verify=self.verify_remotes,
timeout=3

View File

@ -9,15 +9,15 @@ import copy
import json
import os
import time
from typing import List
from typing import List, Union
from threading import Thread
from inspect import getsourcefile
from os.path import abspath
from pathlib import Path
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
import modules.shared as shared
from scripts.spartan.Worker import Worker, State
from scripts.spartan.shared import logger, warmup_samples
from scripts.spartan.Worker import InvalidWorkerResponse, Worker, State
from scripts.spartan.shared import logger, warmup_samples, benchmark_payload
import scripts.spartan.shared as sh
@ -70,7 +70,7 @@ class World:
# I'd rather keep the sdwui root directory clean.
extension_path = Path(abspath(getsourcefile(lambda: 0))).parent.parent.parent
config_path = extension_path.joinpath('config.json')
config_path = shared.cmd_opts.distributed_config
def __init__(self, initial_payload, verify_remotes: bool = True):
self.master_worker = Worker(master=True)
@ -82,6 +82,7 @@ class World:
self.verify_remotes = verify_remotes
self.initial_payload = copy.copy(initial_payload)
self.thin_client_mode = False
self.has_any_workers = False # whether any workers have been added to the world
def __getitem__(self, label: str) -> Worker:
for worker in self._workers:
@ -141,9 +142,11 @@ class World:
for job in self.jobs:
if job.worker.master:
return job
raise Exception("Master job not found")
# TODO better way of merging/updating workers
def add_worker(self, uuid: str, address: str, port: int, tls: bool = False):
def add_worker(self, uuid: str, address: str, port: int, auth: Union[str,None] = None, tls: bool = False):
"""
Registers a worker with the world.
@ -151,10 +154,15 @@ class World:
uuid (str): The name or unique identifier.
address (str): The ip or FQDN.
port (int): The port number.
Returns:
Worker: The worker object.
Raises:
InvalidWorkerResponse: If the worker is not valid.
"""
original = None
new = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes, tls=tls)
new = Worker(uuid=uuid, address=address, port=port, verify_remotes=self.verify_remotes, tls=tls, auth=auth)
for w in self._workers:
if w.uuid == uuid:
@ -162,6 +170,7 @@ class World:
if original is None:
self._workers.append(new)
self.has_any_workers = True
return new
else:
original.address = address
@ -169,7 +178,6 @@ class World:
original.tls = tls
return original
def interrupt_remotes(self):
for worker in self.get_workers():
@ -368,8 +376,8 @@ class World:
self.jobs.append(Job(worker=worker, batch_size=batch_size))
def get_workers(self):
filtered = []
for worker in self._workers:
filtered:List[Worker] = []
for worker in self.__workers:
if worker.avg_ipm is not None and worker.avg_ipm <= 0:
logger.warning(f"config reports invalid speed (0 ipm) for worker '{worker.uuid}', setting default of 1 ipm.\nplease re-benchmark")
worker.avg_ipm = 1
@ -504,18 +512,21 @@ class World:
worker = self.add_worker(
uuid=label,
address=w['address'],
port=w['port'],
tls=w['tls']
port=w.get('port', 80),
tls=w.get('tls', False),
auth =w.get('auth', None)
)
worker.address = w['address']
worker.port = w['port']
worker.last_mpe = w['last_mpe']
worker.avg_ipm = w['avg_ipm']
worker.master = w['master']
worker.port = w.get('port', 80)
worker.last_mpe = w.get('last_mpe', None)
worker.avg_ipm = w.get('avg_ipm', None)
worker.master = w.get('master', False)
except KeyError as e:
raise e
logger.error(f"invalid configuration in file for worker {w}... ignoring")
continue
except InvalidWorkerResponse as e:
logger.error(f"worker {w} is invalid... ignoring")
continue
logger.debug("loaded config")
def save_config(self):