some formatting

master
unknown 2024-01-11 01:34:34 -06:00
parent 587843f89f
commit f786860cbc
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
8 changed files with 97 additions and 104 deletions

View File

@ -1,38 +1,38 @@
import os
from pathlib import Path from pathlib import Path
from inspect import getsourcefile from inspect import getsourcefile
from os.path import abspath from os.path import abspath
def preload(parser): def preload(parser):
parser.add_argument( parser.add_argument(
"--distributed-remotes", "--distributed-remotes",
nargs="+", nargs="+",
help="Enter n pairs of sockets", help="Enter n pairs of sockets",
type=lambda t: t.split(":") type=lambda t: t.split(":")
) )
parser.add_argument( parser.add_argument(
"--distributed-skip-verify-remotes", "--distributed-skip-verify-remotes",
help="Disable verification of remote worker TLS certificates", help="Disable verification of remote worker TLS certificates",
action="store_true" action="store_true"
) )
parser.add_argument( parser.add_argument(
"--distributed-remotes-autosave", "--distributed-remotes-autosave",
help="Enable auto-saving of remote worker generations", help="Enable auto-saving of remote worker generations",
action="store_true" action="store_true"
) )
parser.add_argument( parser.add_argument(
"--distributed-debug", "--distributed-debug",
help="Enable debug information", help="Enable debug information",
action="store_true" action="store_true"
) )
extension_path = Path(abspath(getsourcefile(lambda: 0))).parent extension_path = Path(abspath(getsourcefile(lambda: 0))).parent
config_path = extension_path.joinpath('distributed-config.json') config_path = extension_path.joinpath('distributed-config.json')
# add config file # add config file
parser.add_argument( parser.add_argument(
"--distributed-config", "--distributed-config",
help="config file to load / save, default: $WEBUI_PATH/distributed-config.json", help="config file to load / save, default: $WEBUI_PATH/distributed-config.json",
default=config_path default=config_path
) )

View File

@ -3,28 +3,28 @@ https://github.com/papuSpartan/stable-diffusion-webui-distributed
""" """
import base64 import base64
import copy
import io import io
import json import json
import re import re
from modules import scripts
from modules import processing
from threading import Thread, current_thread
from PIL import Image
from typing import List, Callable
import urllib3
import copy
from modules.images import save_image
from modules.shared import opts, cmd_opts
from modules.shared import state as webui_state
import time
from scripts.spartan.World import World, WorldAlreadyInitialized
from scripts.spartan.UI import UI
from scripts.spartan.shared import logger
from scripts.spartan.control_net import pack_control_net
from modules.processing import fix_seed, Processed
import signal import signal
import sys import sys
import time
from threading import Thread, current_thread
from typing import List
import gradio import gradio
import urllib3
from PIL import Image
from modules import processing
from modules import scripts
from modules.images import save_image
from modules.processing import fix_seed, Processed
from modules.shared import opts, cmd_opts
from modules.shared import state as webui_state
from scripts.spartan.control_net import pack_control_net
from scripts.spartan.shared import logger
from scripts.spartan.ui import UI
from scripts.spartan.world import World, WorldAlreadyInitialized
old_sigint_handler = signal.getsignal(signal.SIGINT) old_sigint_handler = signal.getsignal(signal.SIGINT)
old_sigterm_handler = signal.getsignal(signal.SIGTERM) old_sigterm_handler = signal.getsignal(signal.SIGTERM)
@ -35,7 +35,7 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM)
class Script(scripts.Script): class Script(scripts.Script):
worker_threads: List[Thread] = [] worker_threads: List[Thread] = []
# Whether to verify worker certificates. Can be useful if your remotes are self-signed. # Whether to verify worker certificates. Can be useful if your remotes are self-signed.
verify_remotes = False if cmd_opts.distributed_skip_verify_remotes else True verify_remotes = not cmd_opts.distributed_skip_verify_remotes
is_img2img = True is_img2img = True
is_txt2img = True is_txt2img = True
@ -268,13 +268,13 @@ class Script(scripts.Script):
packed_script_args.append(pack_control_net(cn_units)) packed_script_args.append(pack_control_net(cn_units))
continue continue
else:
# other scripts to pack # other scripts to pack
args_script_pack = {title: {"args": []}} args_script_pack = {title: {"args": []}}
for arg in p.script_args[script.args_from:script.args_to]: for arg in p.script_args[script.args_from:script.args_to]:
args_script_pack[title]["args"].append(arg) args_script_pack[title]["args"].append(arg)
packed_script_args.append(args_script_pack) packed_script_args.append(args_script_pack)
# https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514 # https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works # encapsulating the request object within a txt2imgreq object is deprecated and no longer works
# see test/basic_features/txt2img_test.py for an example # see test/basic_features/txt2img_test.py for an example
@ -379,7 +379,8 @@ class Script(scripts.Script):
def postprocess(p, processed, *args): def postprocess(p, processed, *args):
if not Script.world.enabled: if not Script.world.enabled:
return return
elif len(processed.images) >= 1 and Script.master_start is not None:
if len(processed.images) >= 1 and Script.master_start is not None:
Script.add_to_gallery(p=p, processed=processed) Script.add_to_gallery(p=p, processed=processed)
@staticmethod @staticmethod

View File

@ -1,6 +1,6 @@
from modules.api.api import encode_pil_to_base64
from PIL import Image
import copy import copy
from PIL import Image
from modules.api.api import encode_pil_to_base64
from scripts.spartan.shared import logger from scripts.spartan.shared import logger

View File

@ -33,11 +33,11 @@ class Worker_Model(BaseModel):
password: Optional[str] = Field(description="The password to be used when authenticating with this worker") password: Optional[str] = Field(description="The password to be used when authenticating with this worker")
pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit") pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit")
class Config_Model(BaseModel): class ConfigModel(BaseModel):
workers: List[Dict[str, Worker_Model]] workers: List[Dict[str, Worker_Model]]
benchmark_payload: Dict = Field( benchmark_payload: Dict = Field(
default=Benchmark_Payload, default=Benchmark_Payload,
description='the payload used when benchmarking a node' description='the payload used when benchmarking a node'
) )
job_timeout: Optional[int] = Field(default=3) job_timeout: Optional[int] = Field(default=3)
enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True) enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True)

View File

@ -1,18 +1,18 @@
import logging import logging
from inspect import getsourcefile
from logging import Handler from logging import Handler
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
from inspect import getsourcefile from os.path import abspath
from pathlib import Path
from typing import Union from typing import Union
from rich.logging import RichHandler
from modules.shared import cmd_opts from modules.shared import cmd_opts
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from os.path import abspath from rich.logging import RichHandler
from pathlib import Path
extension_path = Path(abspath(getsourcefile(lambda: 0))).parent.parent.parent extension_path = Path(abspath(getsourcefile(lambda: 0))).parent.parent.parent
# https://rich.readthedocs.io/en/stable/logging.html # https://rich.readthedocs.io/en/stable/logging.html
log_level = 'DEBUG' if cmd_opts.distributed_debug else 'INFO' LOG_LEVEL = 'DEBUG' if cmd_opts.distributed_debug else 'INFO'
logger = logging.getLogger("distributed") logger = logging.getLogger("distributed")
rich_handler = RichHandler( rich_handler = RichHandler(
rich_tracebacks=True, rich_tracebacks=True,
@ -21,7 +21,7 @@ rich_handler = RichHandler(
keywords=["distributed", "Distributed", "worker", "Worker", "world", "World"] keywords=["distributed", "Distributed", "worker", "Worker", "world", "World"]
) )
logger.propagate = False # prevent log duplication by webui since it now uses the logging module logger.propagate = False # prevent log duplication by webui since it now uses the logging module
logger.setLevel(log_level) logger.setLevel(LOG_LEVEL)
log_path = extension_path.joinpath('distributed.log') log_path = extension_path.joinpath('distributed.log')
file_handler = RotatingFileHandler(filename=log_path, maxBytes=10_000_000, backupCount=1) file_handler = RotatingFileHandler(filename=log_path, maxBytes=10_000_000, backupCount=1)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

View File

@ -1,14 +1,12 @@
import os import os
import subprocess import subprocess
from pathlib import Path from pathlib import Path
import gradio
from .shared import logger, log_level, gui_handler
from .Worker import Worker, State
from modules.shared import state as webui_state
from modules.shared import opts
from typing import List
from threading import Thread from threading import Thread
import gradio
from modules.shared import opts
from modules.shared import state as webui_state
from .shared import logger, LOG_LEVEL, gui_handler
from .worker import State
worker_select_dropdown = None worker_select_dropdown = None
@ -174,7 +172,7 @@ class UI:
worker.session.auth = (user, password) worker.session.auth = (user, password)
self.world.save_config() self.world.save_config()
def main_toggle_btn(self, toggle): def main_toggle_btn(self):
self.world.enabled = not self.world.enabled self.world.enabled = not self.world.enabled
self.world.save_config() self.world.save_config()
@ -245,7 +243,7 @@ class UI:
redo_benchmarks_btn.style(full_width=False) redo_benchmarks_btn.style(full_width=False)
redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[]) redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[])
if log_level == 'DEBUG': if LOG_LEVEL == 'DEBUG':
clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop') clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop')
clear_queue_btn.style(full_width=False) clear_queue_btn.style(full_width=False)
clear_queue_btn.click(self.clear_queue_btn) clear_queue_btn.click(self.clear_queue_btn)

View File

@ -1,19 +1,18 @@
import io
# import gradio
import requests
from typing import List, Tuple, Union
import math
import copy
import time
from threading import Thread
from modules.shared import cmd_opts
from enum import Enum
import json
import base64 import base64
import copy
import io
import json
import math
import queue import queue
from modules.shared import state as master_state
from modules.api.api import encode_pil_to_base64
import re import re
import time
from enum import Enum
from threading import Thread
from typing import List, Union
import requests
from modules.api.api import encode_pil_to_base64
from modules.shared import cmd_opts
from modules.shared import state as master_state
from . import shared as sh from . import shared as sh
from .shared import logger, warmup_samples from .shared import logger, warmup_samples
@ -438,8 +437,7 @@ class Worker:
result = response_queue.get() result = response_queue.get()
if isinstance(result, Exception): if isinstance(result, Exception):
raise result raise result
else: response = result
response = result
self.response = response.json() self.response = response.json()
if response.status_code != 200: if response.status_code != 200:
@ -600,10 +598,7 @@ class Worker:
timeout=3, timeout=3,
verify=not self.verify_remotes verify=not self.verify_remotes
) )
if response.status_code == 200: return response.status_code == 200
return True
else:
return False
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
return False return False

View File

@ -9,15 +9,15 @@ import copy
import json import json
import os import os
import time import time
from typing import List, Dict, Union
from threading import Thread from threading import Thread
from typing import List
import gradio import gradio
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
import modules.shared as shared import modules.shared as shared
from .Worker import Worker, State from modules.processing import process_images, StableDiffusionProcessingTxt2Img
from .shared import logger, warmup_samples, extension_path
from .pmodels import Config_Model, Benchmark_Payload
from . import shared as sh from . import shared as sh
from .pmodels import ConfigModel, Benchmark_Payload
from .shared import logger, warmup_samples, extension_path
from .worker import Worker, State
class NotBenchmarked(Exception): class NotBenchmarked(Exception):
@ -66,9 +66,8 @@ class Job:
if pixels <= self.worker.pixel_cap: if pixels <= self.worker.pixel_cap:
self.batch_size += batch_size self.batch_size += batch_size
return True return True
else: logger.debug(f"worker {self.worker.label} hit pixel cap ({pixels} > cap: {self.worker.pixel_cap})")
logger.debug(f"worker {self.worker.label} hit pixel cap ({pixels} > cap: {self.worker.pixel_cap})") return False
return False
class World: class World:
@ -370,7 +369,7 @@ class World:
master_bench_payload.do_not_save_samples = True master_bench_payload.do_not_save_samples = True
# "warm up" due to initial generation lag # "warm up" due to initial generation lag
for i in range(warmup_samples): for _ in range(warmup_samples):
process_images(master_bench_payload) process_images(master_bench_payload)
# get actual sample # get actual sample
@ -610,7 +609,7 @@ class World:
sh.benchmark_payload = Benchmark_Payload() sh.benchmark_payload = Benchmark_Payload()
return return
config = Config_Model(**config_raw) config = ConfigModel(**config_raw)
# saves config schema to <extension>/distributed-config.schema.json # saves config schema to <extension>/distributed-config.schema.json
# print(models.Config.schema_json()) # print(models.Config.schema_json())
@ -634,7 +633,7 @@ class World:
Saves the config file. Saves the config file.
""" """
config = Config_Model( config = ConfigModel(
workers=[{worker.label: worker.model.dict()} for worker in self._workers], workers=[{worker.label: worker.model.dict()} for worker in self._workers],
benchmark_payload=sh.benchmark_payload, benchmark_payload=sh.benchmark_payload,
job_timeout=self.job_timeout, job_timeout=self.job_timeout,