some formatting
parent
587843f89f
commit
f786860cbc
|
|
@ -1,8 +1,8 @@
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from inspect import getsourcefile
|
from inspect import getsourcefile
|
||||||
from os.path import abspath
|
from os.path import abspath
|
||||||
|
|
||||||
|
|
||||||
def preload(parser):
|
def preload(parser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--distributed-remotes",
|
"--distributed-remotes",
|
||||||
|
|
|
||||||
|
|
@ -3,28 +3,28 @@ https://github.com/papuSpartan/stable-diffusion-webui-distributed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
import copy
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from modules import scripts
|
|
||||||
from modules import processing
|
|
||||||
from threading import Thread, current_thread
|
|
||||||
from PIL import Image
|
|
||||||
from typing import List, Callable
|
|
||||||
import urllib3
|
|
||||||
import copy
|
|
||||||
from modules.images import save_image
|
|
||||||
from modules.shared import opts, cmd_opts
|
|
||||||
from modules.shared import state as webui_state
|
|
||||||
import time
|
|
||||||
from scripts.spartan.World import World, WorldAlreadyInitialized
|
|
||||||
from scripts.spartan.UI import UI
|
|
||||||
from scripts.spartan.shared import logger
|
|
||||||
from scripts.spartan.control_net import pack_control_net
|
|
||||||
from modules.processing import fix_seed, Processed
|
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
|
from threading import Thread, current_thread
|
||||||
|
from typing import List
|
||||||
import gradio
|
import gradio
|
||||||
|
import urllib3
|
||||||
|
from PIL import Image
|
||||||
|
from modules import processing
|
||||||
|
from modules import scripts
|
||||||
|
from modules.images import save_image
|
||||||
|
from modules.processing import fix_seed, Processed
|
||||||
|
from modules.shared import opts, cmd_opts
|
||||||
|
from modules.shared import state as webui_state
|
||||||
|
from scripts.spartan.control_net import pack_control_net
|
||||||
|
from scripts.spartan.shared import logger
|
||||||
|
from scripts.spartan.ui import UI
|
||||||
|
from scripts.spartan.world import World, WorldAlreadyInitialized
|
||||||
|
|
||||||
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
||||||
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||||
|
|
@ -35,7 +35,7 @@ old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
worker_threads: List[Thread] = []
|
worker_threads: List[Thread] = []
|
||||||
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
# Whether to verify worker certificates. Can be useful if your remotes are self-signed.
|
||||||
verify_remotes = False if cmd_opts.distributed_skip_verify_remotes else True
|
verify_remotes = not cmd_opts.distributed_skip_verify_remotes
|
||||||
|
|
||||||
is_img2img = True
|
is_img2img = True
|
||||||
is_txt2img = True
|
is_txt2img = True
|
||||||
|
|
@ -268,7 +268,7 @@ class Script(scripts.Script):
|
||||||
packed_script_args.append(pack_control_net(cn_units))
|
packed_script_args.append(pack_control_net(cn_units))
|
||||||
|
|
||||||
continue
|
continue
|
||||||
else:
|
|
||||||
# other scripts to pack
|
# other scripts to pack
|
||||||
args_script_pack = {title: {"args": []}}
|
args_script_pack = {title: {"args": []}}
|
||||||
for arg in p.script_args[script.args_from:script.args_to]:
|
for arg in p.script_args[script.args_from:script.args_to]:
|
||||||
|
|
@ -379,7 +379,8 @@ class Script(scripts.Script):
|
||||||
def postprocess(p, processed, *args):
|
def postprocess(p, processed, *args):
|
||||||
if not Script.world.enabled:
|
if not Script.world.enabled:
|
||||||
return
|
return
|
||||||
elif len(processed.images) >= 1 and Script.master_start is not None:
|
|
||||||
|
if len(processed.images) >= 1 and Script.master_start is not None:
|
||||||
Script.add_to_gallery(p=p, processed=processed)
|
Script.add_to_gallery(p=p, processed=processed)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from modules.api.api import encode_pil_to_base64
|
|
||||||
from PIL import Image
|
|
||||||
import copy
|
import copy
|
||||||
|
from PIL import Image
|
||||||
|
from modules.api.api import encode_pil_to_base64
|
||||||
from scripts.spartan.shared import logger
|
from scripts.spartan.shared import logger
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ class Worker_Model(BaseModel):
|
||||||
password: Optional[str] = Field(description="The password to be used when authenticating with this worker")
|
password: Optional[str] = Field(description="The password to be used when authenticating with this worker")
|
||||||
pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit")
|
pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit")
|
||||||
|
|
||||||
class Config_Model(BaseModel):
|
class ConfigModel(BaseModel):
|
||||||
workers: List[Dict[str, Worker_Model]]
|
workers: List[Dict[str, Worker_Model]]
|
||||||
benchmark_payload: Dict = Field(
|
benchmark_payload: Dict = Field(
|
||||||
default=Benchmark_Payload,
|
default=Benchmark_Payload,
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,18 @@
|
||||||
import logging
|
import logging
|
||||||
|
from inspect import getsourcefile
|
||||||
from logging import Handler
|
from logging import Handler
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
from inspect import getsourcefile
|
from os.path import abspath
|
||||||
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from rich.logging import RichHandler
|
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from os.path import abspath
|
from rich.logging import RichHandler
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
extension_path = Path(abspath(getsourcefile(lambda: 0))).parent.parent.parent
|
extension_path = Path(abspath(getsourcefile(lambda: 0))).parent.parent.parent
|
||||||
|
|
||||||
# https://rich.readthedocs.io/en/stable/logging.html
|
# https://rich.readthedocs.io/en/stable/logging.html
|
||||||
log_level = 'DEBUG' if cmd_opts.distributed_debug else 'INFO'
|
LOG_LEVEL = 'DEBUG' if cmd_opts.distributed_debug else 'INFO'
|
||||||
logger = logging.getLogger("distributed")
|
logger = logging.getLogger("distributed")
|
||||||
rich_handler = RichHandler(
|
rich_handler = RichHandler(
|
||||||
rich_tracebacks=True,
|
rich_tracebacks=True,
|
||||||
|
|
@ -21,7 +21,7 @@ rich_handler = RichHandler(
|
||||||
keywords=["distributed", "Distributed", "worker", "Worker", "world", "World"]
|
keywords=["distributed", "Distributed", "worker", "Worker", "world", "World"]
|
||||||
)
|
)
|
||||||
logger.propagate = False # prevent log duplication by webui since it now uses the logging module
|
logger.propagate = False # prevent log duplication by webui since it now uses the logging module
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(LOG_LEVEL)
|
||||||
log_path = extension_path.joinpath('distributed.log')
|
log_path = extension_path.joinpath('distributed.log')
|
||||||
file_handler = RotatingFileHandler(filename=log_path, maxBytes=10_000_000, backupCount=1)
|
file_handler = RotatingFileHandler(filename=log_path, maxBytes=10_000_000, backupCount=1)
|
||||||
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,12 @@
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import gradio
|
|
||||||
|
|
||||||
from .shared import logger, log_level, gui_handler
|
|
||||||
from .Worker import Worker, State
|
|
||||||
from modules.shared import state as webui_state
|
|
||||||
from modules.shared import opts
|
|
||||||
from typing import List
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
import gradio
|
||||||
|
from modules.shared import opts
|
||||||
|
from modules.shared import state as webui_state
|
||||||
|
from .shared import logger, LOG_LEVEL, gui_handler
|
||||||
|
from .worker import State
|
||||||
|
|
||||||
worker_select_dropdown = None
|
worker_select_dropdown = None
|
||||||
|
|
||||||
|
|
@ -174,7 +172,7 @@ class UI:
|
||||||
worker.session.auth = (user, password)
|
worker.session.auth = (user, password)
|
||||||
self.world.save_config()
|
self.world.save_config()
|
||||||
|
|
||||||
def main_toggle_btn(self, toggle):
|
def main_toggle_btn(self):
|
||||||
self.world.enabled = not self.world.enabled
|
self.world.enabled = not self.world.enabled
|
||||||
self.world.save_config()
|
self.world.save_config()
|
||||||
|
|
||||||
|
|
@ -245,7 +243,7 @@ class UI:
|
||||||
redo_benchmarks_btn.style(full_width=False)
|
redo_benchmarks_btn.style(full_width=False)
|
||||||
redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[])
|
redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[])
|
||||||
|
|
||||||
if log_level == 'DEBUG':
|
if LOG_LEVEL == 'DEBUG':
|
||||||
clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop')
|
clear_queue_btn = gradio.Button(value='Clear local webui queue', variant='stop')
|
||||||
clear_queue_btn.style(full_width=False)
|
clear_queue_btn.style(full_width=False)
|
||||||
clear_queue_btn.click(self.clear_queue_btn)
|
clear_queue_btn.click(self.clear_queue_btn)
|
||||||
|
|
@ -1,19 +1,18 @@
|
||||||
import io
|
|
||||||
# import gradio
|
|
||||||
import requests
|
|
||||||
from typing import List, Tuple, Union
|
|
||||||
import math
|
|
||||||
import copy
|
|
||||||
import time
|
|
||||||
from threading import Thread
|
|
||||||
from modules.shared import cmd_opts
|
|
||||||
from enum import Enum
|
|
||||||
import json
|
|
||||||
import base64
|
import base64
|
||||||
|
import copy
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import math
|
||||||
import queue
|
import queue
|
||||||
from modules.shared import state as master_state
|
|
||||||
from modules.api.api import encode_pil_to_base64
|
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
|
from enum import Enum
|
||||||
|
from threading import Thread
|
||||||
|
from typing import List, Union
|
||||||
|
import requests
|
||||||
|
from modules.api.api import encode_pil_to_base64
|
||||||
|
from modules.shared import cmd_opts
|
||||||
|
from modules.shared import state as master_state
|
||||||
from . import shared as sh
|
from . import shared as sh
|
||||||
from .shared import logger, warmup_samples
|
from .shared import logger, warmup_samples
|
||||||
|
|
||||||
|
|
@ -438,7 +437,6 @@ class Worker:
|
||||||
result = response_queue.get()
|
result = response_queue.get()
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
raise result
|
raise result
|
||||||
else:
|
|
||||||
response = result
|
response = result
|
||||||
|
|
||||||
self.response = response.json()
|
self.response = response.json()
|
||||||
|
|
@ -600,10 +598,7 @@ class Worker:
|
||||||
timeout=3,
|
timeout=3,
|
||||||
verify=not self.verify_remotes
|
verify=not self.verify_remotes
|
||||||
)
|
)
|
||||||
if response.status_code == 200:
|
return response.status_code == 200
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
return False
|
return False
|
||||||
|
|
@ -9,15 +9,15 @@ import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import List, Dict, Union
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
from typing import List
|
||||||
import gradio
|
import gradio
|
||||||
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from .Worker import Worker, State
|
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
|
||||||
from .shared import logger, warmup_samples, extension_path
|
|
||||||
from .pmodels import Config_Model, Benchmark_Payload
|
|
||||||
from . import shared as sh
|
from . import shared as sh
|
||||||
|
from .pmodels import ConfigModel, Benchmark_Payload
|
||||||
|
from .shared import logger, warmup_samples, extension_path
|
||||||
|
from .worker import Worker, State
|
||||||
|
|
||||||
|
|
||||||
class NotBenchmarked(Exception):
|
class NotBenchmarked(Exception):
|
||||||
|
|
@ -66,7 +66,6 @@ class Job:
|
||||||
if pixels <= self.worker.pixel_cap:
|
if pixels <= self.worker.pixel_cap:
|
||||||
self.batch_size += batch_size
|
self.batch_size += batch_size
|
||||||
return True
|
return True
|
||||||
else:
|
|
||||||
logger.debug(f"worker {self.worker.label} hit pixel cap ({pixels} > cap: {self.worker.pixel_cap})")
|
logger.debug(f"worker {self.worker.label} hit pixel cap ({pixels} > cap: {self.worker.pixel_cap})")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -370,7 +369,7 @@ class World:
|
||||||
master_bench_payload.do_not_save_samples = True
|
master_bench_payload.do_not_save_samples = True
|
||||||
|
|
||||||
# "warm up" due to initial generation lag
|
# "warm up" due to initial generation lag
|
||||||
for i in range(warmup_samples):
|
for _ in range(warmup_samples):
|
||||||
process_images(master_bench_payload)
|
process_images(master_bench_payload)
|
||||||
|
|
||||||
# get actual sample
|
# get actual sample
|
||||||
|
|
@ -610,7 +609,7 @@ class World:
|
||||||
sh.benchmark_payload = Benchmark_Payload()
|
sh.benchmark_payload = Benchmark_Payload()
|
||||||
return
|
return
|
||||||
|
|
||||||
config = Config_Model(**config_raw)
|
config = ConfigModel(**config_raw)
|
||||||
|
|
||||||
# saves config schema to <extension>/distributed-config.schema.json
|
# saves config schema to <extension>/distributed-config.schema.json
|
||||||
# print(models.Config.schema_json())
|
# print(models.Config.schema_json())
|
||||||
|
|
@ -634,7 +633,7 @@ class World:
|
||||||
Saves the config file.
|
Saves the config file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config = Config_Model(
|
config = ConfigModel(
|
||||||
workers=[{worker.label: worker.model.dict()} for worker in self._workers],
|
workers=[{worker.label: worker.model.dict()} for worker in self._workers],
|
||||||
benchmark_payload=sh.benchmark_payload,
|
benchmark_payload=sh.benchmark_payload,
|
||||||
job_timeout=self.job_timeout,
|
job_timeout=self.job_timeout,
|
||||||
Loading…
Reference in New Issue