refactoring and cleanup
parent
c4abe4d1f1
commit
de18217aaf
|
|
@ -24,13 +24,12 @@ from modules.shared import state as webui_state
|
|||
from scripts.spartan.control_net import pack_control_net
|
||||
from scripts.spartan.shared import logger
|
||||
from scripts.spartan.ui import UI
|
||||
from scripts.spartan.world import World, WorldAlreadyInitialized
|
||||
from scripts.spartan.world import World
|
||||
|
||||
old_sigint_handler = signal.getsignal(signal.SIGINT)
|
||||
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
||||
|
||||
|
||||
# TODO implement advertisement of some sort in sdwui api to allow extension to automatically discover workers?
|
||||
# noinspection PyMissingOrEmptyDocstring
|
||||
class Script(scripts.Script):
|
||||
worker_threads: List[Thread] = []
|
||||
|
|
@ -50,17 +49,8 @@ class Script(scripts.Script):
|
|||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
# build world
|
||||
world = World(initial_payload=None, verify_remotes=verify_remotes)
|
||||
# add workers to the world
|
||||
world = World(verify_remotes=verify_remotes)
|
||||
world.load_config()
|
||||
if cmd_opts.distributed_remotes is not None and len(cmd_opts.distributed_remotes) > 0:
|
||||
logger.warning(f"--distributed-remotes is deprecated and may be removed in the future\n"
|
||||
f"gui/external modification of {world.config_path} will be prioritized going forward")
|
||||
|
||||
for worker in cmd_opts.distributed_remotes:
|
||||
world.add_worker(uuid=worker[0], address=worker[1], port=worker[2], tls=False)
|
||||
world.save_config()
|
||||
# do an early check to see which workers are online
|
||||
logger.info("doing initial ping sweep to see which workers are reachable")
|
||||
world.ping_remotes(indiscriminate=True)
|
||||
|
||||
|
|
@ -71,7 +61,7 @@ class Script(scripts.Script):
|
|||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
self.world.load_config()
|
||||
Script.world.load_config()
|
||||
extension_ui = UI(script=Script, world=Script.world)
|
||||
# root, api_exposed = extension_ui.create_ui()
|
||||
components = extension_ui.create_ui()
|
||||
|
|
@ -128,7 +118,7 @@ class Script(scripts.Script):
|
|||
if p.n_iter > 1: # if splitting by batch count
|
||||
num_remote_images *= p.n_iter - 1
|
||||
|
||||
logger.debug(f"image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, "
|
||||
logger.debug(f"image {true_image_pos + 1}/{Script.world.p.batch_size * p.n_iter}, "
|
||||
f"info-index: {info_index}")
|
||||
|
||||
if Script.world.thin_client_mode:
|
||||
|
|
@ -225,30 +215,15 @@ class Script(scripts.Script):
|
|||
p.batch_size = len(processed.images)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def initialize(initial_payload):
|
||||
# get default batch size
|
||||
try:
|
||||
batch_size = initial_payload.batch_size
|
||||
except AttributeError:
|
||||
batch_size = 1
|
||||
|
||||
try:
|
||||
Script.world.initialize(batch_size)
|
||||
Script.world.initial_payload = initial_payload
|
||||
logger.debug(f"World initialized!")
|
||||
except WorldAlreadyInitialized:
|
||||
Script.world.update_world(total_batch_size=batch_size)
|
||||
|
||||
# p's type is
|
||||
# "modules.processing.StableDiffusionProcessingTxt2Img"
|
||||
def before_process(self, p, *args):
|
||||
if not self.world.enabled:
|
||||
# "modules.processing.StableDiffusionProcessing*"
|
||||
@staticmethod
|
||||
def before_process(p, *args):
|
||||
if not Script.world.enabled:
|
||||
logger.debug("extension is disabled")
|
||||
return
|
||||
|
||||
current_thread().name = "distributed_main"
|
||||
Script.initialize(initial_payload=p)
|
||||
Script.world.update(p)
|
||||
|
||||
# save original process_images_inner function for later if we monkeypatch it
|
||||
Script.original_process_images_inner = processing.process_images_inner
|
||||
|
|
@ -367,22 +342,12 @@ class Script(scripts.Script):
|
|||
started_jobs.append(job)
|
||||
|
||||
# if master batch size was changed again due to optimization change it to the updated value
|
||||
if not self.world.thin_client_mode:
|
||||
if not Script.world.thin_client_mode:
|
||||
p.batch_size = Script.world.master_job().batch_size
|
||||
Script.master_start = time.time()
|
||||
|
||||
# generate images assigned to local machine
|
||||
p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
|
||||
if Script.world.thin_client_mode:
|
||||
p.batch_size = 0
|
||||
processed = Processed(p=p, images_list=[])
|
||||
processed.all_prompts = []
|
||||
processed.all_seeds = []
|
||||
processed.all_subseeds = []
|
||||
processed.all_negative_prompts = []
|
||||
processed.infotexts = []
|
||||
processed.prompt = None
|
||||
|
||||
Script.runs_since_init += 1
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
|
|
|
|||
|
|
@ -30,13 +30,6 @@ class NotBenchmarked(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class WorldAlreadyInitialized(Exception):
|
||||
"""
|
||||
Raised when attempting to initialize the World when it has already been initialized.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Job:
|
||||
"""
|
||||
Keeps track of how much work a given worker should handle.
|
||||
|
|
@ -78,7 +71,7 @@ class World:
|
|||
The frame or "world" which holds all workers (including the local machine).
|
||||
|
||||
Args:
|
||||
initial_payload: The original txt2img payload created by the user initiating the generation request on master.
|
||||
p: The original processing state object created by the user initiating the generation request on master.
|
||||
verify_remotes (bool): Whether to validate remote worker certificates.
|
||||
"""
|
||||
|
||||
|
|
@ -86,15 +79,14 @@ class World:
|
|||
config_path = shared.cmd_opts.distributed_config
|
||||
old_config_path = worker_info_path = extension_path.joinpath('workers.json')
|
||||
|
||||
def __init__(self, initial_payload, verify_remotes: bool = True):
|
||||
def __init__(self, verify_remotes: bool = True):
|
||||
self.p = None
|
||||
self.master_worker = Worker(master=True)
|
||||
self.total_batch_size: int = 0
|
||||
self._workers: List[Worker] = [self.master_worker]
|
||||
self.jobs: List[Job] = []
|
||||
self.job_timeout: int = 3 # seconds
|
||||
self.initialized: bool = False
|
||||
self.verify_remotes = verify_remotes
|
||||
self.initial_payload = copy.copy(initial_payload)
|
||||
self.thin_client_mode = False
|
||||
self.enabled = True
|
||||
self.is_dropdown_handler_injected = False
|
||||
|
|
@ -109,32 +101,12 @@ class World:
|
|||
def __repr__(self):
|
||||
return f"{len(self._workers)} workers"
|
||||
|
||||
def update_world(self, total_batch_size):
|
||||
"""
|
||||
Updates the world with information vital to handling the local generation request after
|
||||
the world has already been initialized.
|
||||
|
||||
Args:
|
||||
total_batch_size (int): The total number of images requested by the local/master sdwui instance.
|
||||
"""
|
||||
|
||||
self.total_batch_size = total_batch_size
|
||||
self.update_jobs()
|
||||
|
||||
def initialize(self, total_batch_size):
|
||||
"""should be called before a world instance is used for anything"""
|
||||
if self.initialized:
|
||||
raise WorldAlreadyInitialized("This world instance was already initialized")
|
||||
|
||||
self.benchmark()
|
||||
self.update_world(total_batch_size=total_batch_size)
|
||||
self.initialized = True
|
||||
|
||||
def default_batch_size(self) -> int:
|
||||
"""the amount of images/total images requested that a worker would compute if conditions were perfect and
|
||||
each worker generated at the same speed. assumes one batch only"""
|
||||
|
||||
return self.total_batch_size // self.size()
|
||||
return self.p.batch_size // self.size()
|
||||
|
||||
def size(self) -> int:
|
||||
"""
|
||||
|
|
@ -396,6 +368,18 @@ class World:
|
|||
|
||||
self.jobs.append(Job(worker=worker, batch_size=batch_size))
|
||||
|
||||
def update(self, p):
|
||||
"""preps world for another run"""
|
||||
if not self.initialized:
|
||||
self.benchmark()
|
||||
self.initialized = True
|
||||
logger.debug("world initialized!")
|
||||
else:
|
||||
logger.debug("world was already initialized")
|
||||
|
||||
self.p = p
|
||||
self.update_jobs()
|
||||
|
||||
def get_workers(self):
|
||||
filtered: List[Worker] = []
|
||||
for worker in self._workers:
|
||||
|
|
@ -431,7 +415,7 @@ class World:
|
|||
|
||||
logger.debug(f"worker '{job.worker.label}' would stall the image gallery by ~{lag:.2f}s\n")
|
||||
job.complementary = True
|
||||
if deferred_images + images_checked + payload['batch_size'] > self.total_batch_size:
|
||||
if deferred_images + images_checked + payload['batch_size'] > self.p.batch_size:
|
||||
logger.debug(f"would go over actual requested size")
|
||||
else:
|
||||
deferred_images += payload['batch_size']
|
||||
|
|
@ -474,9 +458,9 @@ class World:
|
|||
#######################
|
||||
|
||||
# when total number of requested images was not cleanly divisible by world size then we tack the remainder on
|
||||
remainder_images = self.total_batch_size - self.get_current_output_size()
|
||||
remainder_images = self.p.batch_size - self.get_current_output_size()
|
||||
if remainder_images >= 1:
|
||||
logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
|
||||
logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
|
||||
|
||||
realtime_jobs = self.realtime_jobs()
|
||||
realtime_jobs.sort(key=lambda x: x.batch_size)
|
||||
|
|
@ -555,9 +539,9 @@ class World:
|
|||
|
||||
iterations = payload['n_iter']
|
||||
num_returning = self.get_current_output_size()
|
||||
num_complementary = num_returning - self.total_batch_size
|
||||
num_complementary = num_returning - self.p.batch_size
|
||||
distro_summary = "Job distribution:\n"
|
||||
distro_summary += f"{self.total_batch_size} * {iterations} iteration(s)"
|
||||
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
|
||||
if num_complementary > 0:
|
||||
distro_summary += f" + {num_complementary} complementary"
|
||||
distro_summary += f": {num_returning} images total\n"
|
||||
|
|
@ -577,7 +561,7 @@ class World:
|
|||
processed.infotexts = []
|
||||
processed.prompt = None
|
||||
|
||||
self.initial_payload.scripts.postprocess(p, processed)
|
||||
self.p.scripts.postprocess(p, processed)
|
||||
return processed
|
||||
processing.process_images_inner = process_images_inner_bypass
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue