some formatting
parent
587843f89f
commit
f786860cbc
60
preload.py
60
preload.py
|
|
@ -1,38 +1,38 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from inspect import getsourcefile
|
||||
from os.path import abspath
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument(
|
||||
"--distributed-remotes",
|
||||
nargs="+",
|
||||
help="Enter n pairs of sockets",
|
||||
type=lambda t: t.split(":")
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed-remotes",
|
||||
nargs="+",
|
||||
help="Enter n pairs of sockets",
|
||||
type=lambda t: t.split(":")
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--distributed-skip-verify-remotes",
|
||||
help="Disable verification of remote worker TLS certificates",
|
||||
action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed-skip-verify-remotes",
|
||||
help="Disable verification of remote worker TLS certificates",
|
||||
action="store_true"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--distributed-remotes-autosave",
|
||||
help="Enable auto-saving of remote worker generations",
|
||||
action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed-remotes-autosave",
|
||||
help="Enable auto-saving of remote worker generations",
|
||||
action="store_true"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--distributed-debug",
|
||||
help="Enable debug information",
|
||||
action="store_true"
|
||||
)
|
||||
extension_path = Path(abspath(getsourcefile(lambda: 0))).parent
|
||||
config_path = extension_path.joinpath('distributed-config.json')
|
||||
# add config file
|
||||
parser.add_argument(
|
||||
"--distributed-config",
|
||||
help="config file to load / save, default: $WEBUI_PATH/distributed-config.json",
|
||||
default=config_path
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed-debug",
|
||||
help="Enable debug information",
|
||||
action="store_true"
|
||||
)
|
||||
extension_path = Path(abspath(getsourcefile(lambda: 0))).parent
|
||||
config_path = extension_path.joinpath('distributed-config.json')
|
||||
# add config file
|
||||
parser.add_argument(
|
||||
"--distributed-config",
|
||||
help="config file to load / save, default: $WEBUI_PATH/distributed-config.json",
|
||||
default=config_path
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,28 +3,28 @@ https://github.com/papuSpartan/stable-diffusion-webui-distributed
|
|||
"""
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from modules import scripts
|
||||
from modules import processing
|
||||
from threading import Thread, current_thread
|
||||
from PIL import Image
|
||||
from typing import List, Callable
|
||||
import urllib3
|
||||
import copy
|
||||
from modules.images import save_image
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.shared import state as webui_state
|
||||
import time
|
||||
from scripts.spartan.World import World, WorldAlreadyInitialized
|
||||
from scripts.spartan.UI import UI
|
||||
from scripts.spartan.shared import logger
|
||||
from scripts.spartan.control_net import pack_control_net
|
||||
from modules.processing import fix_seed, Processed
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from threading import Thread, current_thread
|
||||
from typing import List
|
||||
import gradio
|
||||
import urllib3
|
||||
from PIL import Image
|
||||
from modules import processing
|
||||
from modules import scripts
|
||||
from modules.images import save_image
|
||||
from modules.processing import fix_seed, Processed
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.shared import state as webui_state
|
||||
from scripts.spartan.control_net import pack_control_net
|
||||
from scripts.spartan.shared import logger
|
||||
from scripts.spartan.ui import UI
|
||||
from scripts.spartan.world import World, WorldAlreadyInitialized
|
||||
|
||||
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
||||
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||
|
|
@ -35,7 +35,7 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
|||
class Script(scripts.Script):
|
||||
worker_threads: List[Thread] = []
|
||||
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
||||
verify_remotes = False if cmd_opts.distributed_skip_verify_remotes else True
|
||||
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
|
||||
|
||||
is_img2img = True
|
||||
is_txt2img = True
|
||||
|
|
@ -268,13 +268,13 @@ class Script(scripts.Script):
|
|||
packed_script_args.append(pack_control_net(cn_units))
|
||||
|
||||
continue
|
||||
else:
|
||||
# other scripts to pack
|
||||
args_script_pack = {title: {"args": []}}
|
||||
for arg in p.script_args[script.args_from:script.args_to]:
|
||||
args_script_pack[title]["args"].append(arg)
|
||||
packed_script_args.append(args_script_pack)
|
||||
# https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514
|
||||
|
||||
# other scripts to pack
|
||||
args_script_pack = {title: {"args": []}}
|
||||
for arg in p.script_args[script.args_from:script.args_to]:
|
||||
args_script_pack[title]["args"].append(arg)
|
||||
packed_script_args.append(args_script_pack)
|
||||
# https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/issues/12#issuecomment-1480382514
|
||||
|
||||
# encapsulating the request object within a txt2imgreq object is deprecated and no longer works
|
||||
# see test/basic_features/txt2img_test.py for an example
|
||||
|
|
@ -379,7 +379,8 @@ class Script(scripts.Script):
|
|||
def postprocess(p, processed, *args):
|
||||
if not Script.world.enabled:
|
||||
return
|
||||
elif len(processed.images) >= 1 and Script.master_start is not None:
|
||||
|
||||
if len(processed.images) >= 1 and Script.master_start is not None:
|
||||
Script.add_to_gallery(p=p, processed=processed)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from modules.api.api import encode_pil_to_base64
|
||||
from PIL import Image
|
||||
import copy
|
||||
from PIL import Image
|
||||
from modules.api.api import encode_pil_to_base64
|
||||
from scripts.spartan.shared import logger
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class Worker_Model(BaseModel):
|
|||
password: Optional[str] = Field(description="The password to be used when authenticating with this worker")
|
||||
pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit")
|
||||
|
||||
class Config_Model(BaseModel):
|
||||
class ConfigModel(BaseModel):
|
||||
workers: List[Dict[str, Worker_Model]]
|
||||
benchmark_payload: Dict = Field(
|
||||
default=Benchmark_Payload,
|
||||
|
|
|
|||
|
|
@ -1,18 +1,18 @@
|
|||
import logging
|
||||
from inspect import getsourcefile
|
||||
from logging import Handler
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from inspect import getsourcefile
|
||||
from os.path import abspath
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from rich.logging import RichHandler
|
||||
from modules.shared import cmd_opts
|
||||
from pydantic import BaseModel, Field
|
||||
from os.path import abspath
|
||||
from rich.logging import RichHandler
|
||||
|
||||
from pathlib import Path
|
||||
extension_path = Path(abspath(getsourcefile(lambda: 0))).parent.parent.parent
|
||||
|
||||
# https://rich.readthedocs.io/en/stable/logging.html
|
||||
log_level = 'DEBUG' if cmd_opts.distributed_debug else 'INFO'
|
||||
LOG_LEVEL = 'DEBUG' if cmd_opts.distributed_debug else 'INFO'
|
||||
logger = logging.getLogger("distributed")
|
||||
rich_handler = RichHandler(
|
||||
rich_tracebacks=True,
|
||||
|
|
@ -21,7 +21,7 @@ rich_handler = RichHandler(
|
|||
keywords=["distributed", "Distributed", "worker", "Worker", "world", "World"]
|
||||
)
|
||||
logger.propagate = False # prevent log duplication by webui since it now uses the logging module
|
||||
logger.setLevel(log_level)
|
||||
logger.setLevel(LOG_LEVEL)
|
||||
log_path = extension_path.joinpath('distributed.log')
|
||||
file_handler = RotatingFileHandler(filename=log_path, maxBytes=10_000_000, backupCount=1)
|
||||
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
|
|
|||
|
|
@ -1,14 +1,12 @@
|
|||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import gradio
|
||||
|
||||
from .shared import logger, log_level, gui_handler
|
||||
from .Worker import Worker, State
|
||||
from modules.shared import state as webui_state
|
||||
from modules.shared import opts
|
||||
from typing import List
|
||||
from threading import Thread
|
||||
import gradio
|
||||
from modules.shared import opts
|
||||
from modules.shared import state as webui_state
|
||||
from .shared import logger, LOG_LEVEL, gui_handler
|
||||
from .worker import State
|
||||
|
||||
worker_select_dropdown = None
|
||||
|
||||
|
|
@ -174,7 +172,7 @@ class UI:
|
|||
worker.session.auth = (user, password)
|
||||
self.world.save_config()
|
||||
|
||||
def main_toggle_btn(self, toggle):
|
||||
def main_toggle_btn(self):
|
||||
self.world.enabled = not self.world.enabled
|
||||
self.world.save_config()
|
||||
|
||||
|
|
@ -245,7 +243,7 @@ class UI:
|
|||
redo_benchmarks_btn.style(full_width=False)
|
||||
redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[])
|
||||
|
||||
if log_level == 'DEBUG':
|
||||
if LOG_LEVEL == 'DEBUG':
|
||||
clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop')
|
||||
clear_queue_btn.style(full_width=False)
|
||||
clear_queue_btn.click(self.clear_queue_btn)
|
||||
|
|
@ -1,19 +1,18 @@
|
|||
import io
|
||||
# import gradio
|
||||
import requests
|
||||
from typing import List, Tuple, Union
|
||||
import math
|
||||
import copy
|
||||
import time
|
||||
from threading import Thread
|
||||
from modules.shared import cmd_opts
|
||||
from enum import Enum
|
||||
import json
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import queue
|
||||
from modules.shared import state as master_state
|
||||
from modules.api.api import encode_pil_to_base64
|
||||
import re
|
||||
import time
|
||||
from enum import Enum
|
||||
from threading import Thread
|
||||
from typing import List, Union
|
||||
import requests
|
||||
from modules.api.api import encode_pil_to_base64
|
||||
from modules.shared import cmd_opts
|
||||
from modules.shared import state as master_state
|
||||
from . import shared as sh
|
||||
from .shared import logger, warmup_samples
|
||||
|
||||
|
|
@ -438,8 +437,7 @@ class Worker:
|
|||
result = response_queue.get()
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
else:
|
||||
response = result
|
||||
response = result
|
||||
|
||||
self.response = response.json()
|
||||
if response.status_code != 200:
|
||||
|
|
@ -600,10 +598,7 @@ class Worker:
|
|||
timeout=3,
|
||||
verify=not self.verify_remotes
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return response.status_code == 200
|
||||
|
||||
except requests.exceptions.ConnectionError:
|
||||
return False
|
||||
|
|
@ -9,15 +9,15 @@ import copy
|
|||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import List, Dict, Union
|
||||
from threading import Thread
|
||||
from typing import List
|
||||
import gradio
|
||||
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
|
||||
import modules.shared as shared
|
||||
from .Worker import Worker, State
|
||||
from .shared import logger, warmup_samples, extension_path
|
||||
from .pmodels import Config_Model, Benchmark_Payload
|
||||
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
|
||||
from . import shared as sh
|
||||
from .pmodels import ConfigModel, Benchmark_Payload
|
||||
from .shared import logger, warmup_samples, extension_path
|
||||
from .worker import Worker, State
|
||||
|
||||
|
||||
class NotBenchmarked(Exception):
|
||||
|
|
@ -66,9 +66,8 @@ class Job:
|
|||
if pixels <= self.worker.pixel_cap:
|
||||
self.batch_size += batch_size
|
||||
return True
|
||||
else:
|
||||
logger.debug(f"worker {self.worker.label} hit pixel cap ({pixels} > cap: {self.worker.pixel_cap})")
|
||||
return False
|
||||
logger.debug(f"worker {self.worker.label} hit pixel cap ({pixels} > cap: {self.worker.pixel_cap})")
|
||||
return False
|
||||
|
||||
|
||||
class World:
|
||||
|
|
@ -370,7 +369,7 @@ class World:
|
|||
master_bench_payload.do_not_save_samples = True
|
||||
|
||||
# "warm up" due to initial generation lag
|
||||
for i in range(warmup_samples):
|
||||
for _ in range(warmup_samples):
|
||||
process_images(master_bench_payload)
|
||||
|
||||
# get actual sample
|
||||
|
|
@ -610,7 +609,7 @@ class World:
|
|||
sh.benchmark_payload = Benchmark_Payload()
|
||||
return
|
||||
|
||||
config = Config_Model(**config_raw)
|
||||
config = ConfigModel(**config_raw)
|
||||
|
||||
# saves config schema to <extension>/distributed-config.schema.json
|
||||
# print(models.Config.schema_json())
|
||||
|
|
@ -634,7 +633,7 @@ class World:
|
|||
Saves the config file.
|
||||
"""
|
||||
|
||||
config = Config_Model(
|
||||
config = ConfigModel(
|
||||
workers=[{worker.label: worker.model.dict()} for worker in self._workers],
|
||||
benchmark_payload=sh.benchmark_payload,
|
||||
job_timeout=self.job_timeout,
|
||||
Loading…
Reference in New Issue