diff --git a/preload.py b/preload.py index a685525..281f146 100644 --- a/preload.py +++ b/preload.py @@ -1,38 +1,38 @@ -import os from pathlib import Path from inspect import getsourcefile from os.path import abspath + def preload(parser): - parser.add_argument( - "--distributed-remotes", - nargs="+", - help="Enter n pairs of sockets", - type=lambda t: t.split(":") - ) + parser.add_argument( + "--distributed-remotes", + nargs="+", + help="Enter n pairs of sockets", + type=lambda t: t.split(":") + ) - parser.add_argument( - "--distributed-skip-verify-remotes", - help="Disable verification of remote worker TLS certificates", - action="store_true" - ) + parser.add_argument( + "--distributed-skip-verify-remotes", + help="Disable verification of remote worker TLS certificates", + action="store_true" + ) - parser.add_argument( - "--distributed-remotes-autosave", - help="Enable auto-saving of remote worker generations", - action="store_true" - ) + parser.add_argument( + "--distributed-remotes-autosave", + help="Enable auto-saving of remote worker generations", + action="store_true" + ) - parser.add_argument( - "--distributed-debug", - help="Enable debug information", - action="store_true" - ) - extension_path = Path(abspath(getsourcefile(lambda: 0))).parent - config_path = extension_path.joinpath('distributed-config.json') - # add config file - parser.add_argument( - "--distributed-config", - help="config file to load / save, default: $WEBUI_PATH/distributed-config.json", - default=config_path - ) + parser.add_argument( + "--distributed-debug", + help="Enable debug information", + action="store_true" + ) + extension_path = Path(abspath(getsourcefile(lambda: 0))).parent + config_path = extension_path.joinpath('distributed-config.json') + # add config file + parser.add_argument( + "--distributed-config", + help="config file to load / save, default: $WEBUI_PATH/distributed-config.json", + default=config_path + ) diff --git a/scripts/extension.py b/scripts/extension.py index 392619b..9a99d74 100644 --- a/scripts/extension.py +++ b/scripts/extension.py @@ -3,28 +3,28 @@ https://github.com/papuSpartan/stable-diffusion-webui-distributed """ import base64 +import copy import io import json import re -from modules import scripts -from modules import processing -from threading import Thread, current_thread -from PIL import Image -from typing import List, Callable -import urllib3 -import copy -from modules.images import save_image -from modules.shared import opts, cmd_opts -from modules.shared import state as webui_state -import time -from scripts.spartan.World import World, WorldAlreadyInitialized -from scripts.spartan.UI import UI -from scripts.spartan.shared import logger -from scripts.spartan.control_net import pack_control_net -from modules.processing import fix_seed, Processed import signal import sys +import time +from threading import Thread, current_thread +from typing import List import gradio +import urllib3 +from PIL import Image +from modules import processing +from modules import scripts +from modules.images import save_image +from modules.processing import fix_seed, Processed +from modules.shared import opts, cmd_opts +from modules.shared import state as webui_state +from scripts.spartan.control_net import pack_control_net +from scripts.spartan.shared import logger +from scripts.spartan.ui import UI +from scripts.spartan.world import World, WorldAlreadyInitialized old_sigint_handler = signal.getsignal(signal.SIGINT) old_sigterm_handler = signal.getsignal(signal.SIGTERM) @@ -35,7 +35,7 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM) class Script(scripts.Script): worker_threads: List[Thread] = [] # Whether to verify worker certificates. Can be useful if your remotes are self-signed. - verify_remotes = False if cmd_opts.distributed_skip_verify_remotes else True + verify_remotes = not cmd_opts.distributed_skip_verify_remotes is_img2img = True is_txt2img = True @@ -268,13 +268,13 @@ class Script(scripts.Script): packed_script_args.append(pack_control_net(cn_units)) continue - else: - # other scripts to pack - args_script_pack = {title: {"args": []}} - for arg in p.script_args[script.args_from:script.args_to]: - args_script_pack[title]["args"].append(arg) - packed_script_args.append(args_script_pack) - # https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514 + + # other scripts to pack + args_script_pack = {title: {"args": []}} + for arg in p.script_args[script.args_from:script.args_to]: + args_script_pack[title]["args"].append(arg) + packed_script_args.append(args_script_pack) + # https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514 # encapsulating the request object within a txt2imgreq object is deprecated and no longer works # see test/basic_features/txt2img_test.py for an example @@ -379,7 +379,8 @@ class Script(scripts.Script): def postprocess(p, processed, *args): if not Script.world.enabled: return - elif len(processed.images) >= 1 and Script.master_start is not None: + + if len(processed.images) >= 1 and Script.master_start is not None: Script.add_to_gallery(p=p, processed=processed) @staticmethod diff --git a/scripts/spartan/control_net.py b/scripts/spartan/control_net.py index fe2d99a..18a240e 100644 --- a/scripts/spartan/control_net.py +++ b/scripts/spartan/control_net.py @@ -1,6 +1,6 @@ -from modules.api.api import encode_pil_to_base64 -from PIL import Image import copy +from PIL import Image +from modules.api.api import encode_pil_to_base64 from scripts.spartan.shared import logger diff --git a/scripts/spartan/pmodels.py b/scripts/spartan/pmodels.py index b97b70a..92b5059 100644 --- a/scripts/spartan/pmodels.py +++ b/scripts/spartan/pmodels.py @@ -33,11 +33,11 @@ class Worker_Model(BaseModel): password: Optional[str] = Field(description="The password to be used when authenticating with this worker") pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit") -class Config_Model(BaseModel): +class ConfigModel(BaseModel): workers: List[Dict[str, Worker_Model]] benchmark_payload: Dict = Field( default=Benchmark_Payload, description='the payload used when benchmarking a node' ) job_timeout: Optional[int] = Field(default=3) - enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True) \ No newline at end of file + enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True) diff --git a/scripts/spartan/shared.py b/scripts/spartan/shared.py index 1c89e38..9e91214 100644 --- a/scripts/spartan/shared.py +++ b/scripts/spartan/shared.py @@ -1,18 +1,18 @@ import logging +from inspect import getsourcefile from logging import Handler from logging.handlers import RotatingFileHandler -from inspect import getsourcefile +from os.path import abspath +from pathlib import Path from typing import Union -from rich.logging import RichHandler from modules.shared import cmd_opts from pydantic import BaseModel, Field -from os.path import abspath +from rich.logging import RichHandler -from pathlib import Path extension_path = Path(abspath(getsourcefile(lambda: 0))).parent.parent.parent # https://rich.readthedocs.io/en/stable/logging.html -log_level = 'DEBUG' if cmd_opts.distributed_debug else 'INFO' +LOG_LEVEL = 'DEBUG' if cmd_opts.distributed_debug else 'INFO' logger = logging.getLogger("distributed") rich_handler = RichHandler( rich_tracebacks=True, @@ -21,7 +21,7 @@ rich_handler = RichHandler( keywords=["distributed", "Distributed", "worker", "Worker", "world", "World"] ) logger.propagate = False # prevent log duplication by webui since it now uses the logging module -logger.setLevel(log_level) +logger.setLevel(LOG_LEVEL) log_path = extension_path.joinpath('distributed.log') file_handler = RotatingFileHandler(filename=log_path, maxBytes=10_000_000, backupCount=1) formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') diff --git a/scripts/spartan/UI.py b/scripts/spartan/ui.py similarity index 98% rename from scripts/spartan/UI.py rename to scripts/spartan/ui.py index 05c49af..6082cd5 100644 --- a/scripts/spartan/UI.py +++ b/scripts/spartan/ui.py @@ -1,14 +1,12 @@ import os import subprocess from pathlib import Path -import gradio - -from .shared import logger, log_level, gui_handler -from .Worker import Worker, State -from modules.shared import state as webui_state -from modules.shared import opts -from typing import List from threading import Thread +import gradio +from modules.shared import opts +from modules.shared import state as webui_state +from .shared import logger, LOG_LEVEL, gui_handler +from .worker import State worker_select_dropdown = None @@ -174,7 +172,7 @@ class UI: worker.session.auth = (user, password) self.world.save_config() - def main_toggle_btn(self, toggle): + def main_toggle_btn(self): self.world.enabled = not self.world.enabled self.world.save_config() @@ -245,7 +243,7 @@ class UI: redo_benchmarks_btn.style(full_width=False) redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[]) - if log_level == 'DEBUG': + if LOG_LEVEL == 'DEBUG': clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop') clear_queue_btn.style(full_width=False) clear_queue_btn.click(self.clear_queue_btn) diff --git a/scripts/spartan/Worker.py b/scripts/spartan/worker.py similarity index 99% rename from scripts/spartan/Worker.py rename to scripts/spartan/worker.py index 7083253..9000628 100644 --- a/scripts/spartan/Worker.py +++ b/scripts/spartan/worker.py @@ -1,19 +1,18 @@ -import io -# import gradio -import requests -from typing import List, Tuple, Union -import math -import copy -import time -from threading import Thread -from modules.shared import cmd_opts -from enum import Enum -import json import base64 +import copy +import io +import json +import math import queue -from modules.shared import state as master_state -from modules.api.api import encode_pil_to_base64 import re +import time +from enum import Enum +from threading import Thread +from typing import List, Union +import requests +from modules.api.api import encode_pil_to_base64 +from modules.shared import cmd_opts +from modules.shared import state as master_state from . import shared as sh from .shared import logger, warmup_samples @@ -438,8 +437,7 @@ class Worker: result = response_queue.get() if isinstance(result, Exception): raise result - else: - response = result + response = result self.response = response.json() if response.status_code != 200: @@ -600,10 +598,7 @@ class Worker: timeout=3, verify=not self.verify_remotes ) - if response.status_code == 200: - return True - else: - return False + return response.status_code == 200 except requests.exceptions.ConnectionError: return False diff --git a/scripts/spartan/World.py b/scripts/spartan/world.py similarity index 98% rename from scripts/spartan/World.py rename to scripts/spartan/world.py index 6d174c4..4bd548e 100644 --- a/scripts/spartan/World.py +++ b/scripts/spartan/world.py @@ -9,15 +9,15 @@ import copy import json import os import time -from typing import List, Dict, Union from threading import Thread +from typing import List import gradio -from modules.processing import process_images, StableDiffusionProcessingTxt2Img import modules.shared as shared -from .Worker import Worker, State -from .shared import logger, warmup_samples, extension_path -from .pmodels import Config_Model, Benchmark_Payload +from modules.processing import process_images, StableDiffusionProcessingTxt2Img from . import shared as sh +from .pmodels import ConfigModel, Benchmark_Payload +from .shared import logger, warmup_samples, extension_path +from .worker import Worker, State class NotBenchmarked(Exception): @@ -66,9 +66,8 @@ class Job: if pixels <= self.worker.pixel_cap: self.batch_size += batch_size return True - else: - logger.debug(f"worker {self.worker.label} hit pixel cap ({pixels} > cap: {self.worker.pixel_cap})") - return False + logger.debug(f"worker {self.worker.label} hit pixel cap ({pixels} > cap: {self.worker.pixel_cap})") + return False class World: @@ -370,7 +369,7 @@ class World: master_bench_payload.do_not_save_samples = True # "warm up" due to initial generation lag - for i in range(warmup_samples): + for _ in range(warmup_samples): process_images(master_bench_payload) # get actual sample @@ -610,7 +609,7 @@ class World: sh.benchmark_payload = Benchmark_Payload() return - config = Config_Model(**config_raw) + config = ConfigModel(**config_raw) # saves config schema to /distributed-config.schema.json # print(models.Config.schema_json()) @@ -634,7 +633,7 @@ class World: Saves the config file. """ - config = Config_Model( + config = ConfigModel( workers=[{worker.label: worker.model.dict()} for worker in self._workers], benchmark_payload=sh.benchmark_payload, job_timeout=self.job_timeout,