refactoring and cleanup

master
unknown 2024-05-09 00:36:37 -05:00
parent c4abe4d1f1
commit de18217aaf
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
3 changed files with 32 additions and 84 deletions

View File

@ -24,13 +24,12 @@ from modules.shared import state as webui_state
from scripts.spartan.control_net import pack_control_net
from scripts.spartan.shared import logger
from scripts.spartan.ui import UI
from scripts.spartan.world import World, WorldAlreadyInitialized
from scripts.spartan.world import World
old_sigint_handler = signal.getsignal(signal.SIGINT)
old_sigterm_handler = signal.getsignal(signal.SIGTERM)
# TODO implement advertisement of some sort in sdwui api to allow extension to automatically discover workers?
# noinspection PyMissingOrEmptyDocstring
class Script(scripts.Script):
worker_threads: List[Thread] = []
@ -50,17 +49,8 @@ class Script(scripts.Script):
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# build world
world = World(initial_payload=None, verify_remotes=verify_remotes)
# add workers to the world
world = World(verify_remotes=verify_remotes)
world.load_config()
if cmd_opts.distributed_remotes is not None and len(cmd_opts.distributed_remotes) > 0:
logger.warning(f"--distributed-remotes is deprecated and may be removed in the future\n"
f"gui/external modification of {world.config_path} will be prioritized going forward")
for worker in cmd_opts.distributed_remotes:
world.add_worker(uuid=worker[0], address=worker[1], port=worker[2], tls=False)
world.save_config()
# do an early check to see which workers are online
logger.info("doing initial ping sweep to see which workers are reachable")
world.ping_remotes(indiscriminate=True)
@ -71,7 +61,7 @@ class Script(scripts.Script):
return scripts.AlwaysVisible
def ui(self, is_img2img):
self.world.load_config()
Script.world.load_config()
extension_ui = UI(script=Script, world=Script.world)
# root, api_exposed = extension_ui.create_ui()
components = extension_ui.create_ui()
@ -128,7 +118,7 @@ class Script(scripts.Script):
if p.n_iter > 1: # if splitting by batch count
num_remote_images *= p.n_iter - 1
logger.debug(f"image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, "
logger.debug(f"image {true_image_pos + 1}/{Script.world.p.batch_size * p.n_iter}, "
f"info-index: {info_index}")
if Script.world.thin_client_mode:
@ -225,30 +215,15 @@ class Script(scripts.Script):
p.batch_size = len(processed.images)
return
@staticmethod
def initialize(initial_payload):
# Prepares the shared Script.world for a generation request: initializes it on first use,
# otherwise only refreshes its total batch size.
# read the requested batch size; fall back to 1 when the payload has no batch_size attribute
try:
batch_size = initial_payload.batch_size
except AttributeError:
batch_size = 1
# World.initialize raises WorldAlreadyInitialized once the world has been set up,
# so on every later call only update_world runs with the new batch size
try:
Script.world.initialize(batch_size)
Script.world.initial_payload = initial_payload
logger.debug(f"World initialized!")
except WorldAlreadyInitialized:
Script.world.update_world(total_batch_size=batch_size)
# p's type is
# "modules.processing.StableDiffusionProcessingTxt2Img"
def before_process(self, p, *args):
if not self.world.enabled:
# "modules.processing.StableDiffusionProcessing*"
@staticmethod
def before_process(p, *args):
if not Script.world.enabled:
logger.debug("extension is disabled")
return
current_thread().name = "distributed_main"
Script.initialize(initial_payload=p)
Script.world.update(p)
# save original process_images_inner function for later if we monkeypatch it
Script.original_process_images_inner = processing.process_images_inner
@ -367,22 +342,12 @@ class Script(scripts.Script):
started_jobs.append(job)
# if master batch size was changed again due to optimization change it to the updated value
if not self.world.thin_client_mode:
if not Script.world.thin_client_mode:
p.batch_size = Script.world.master_job().batch_size
Script.master_start = time.time()
# generate images assigned to local machine
p.do_not_save_grid = True # don't generate grid from master as we are doing this later.
if Script.world.thin_client_mode:
p.batch_size = 0
processed = Processed(p=p, images_list=[])
processed.all_prompts = []
processed.all_seeds = []
processed.all_subseeds = []
processed.all_negative_prompts = []
processed.infotexts = []
processed.prompt = None
Script.runs_since_init += 1
return

View File

@ -1,4 +1,3 @@
import asyncio
import base64
import copy
import io

View File

@ -30,13 +30,6 @@ class NotBenchmarked(Exception):
pass
class WorldAlreadyInitialized(Exception):
"""
Raised by World.initialize() when the world instance has already been initialized.
Script.initialize() catches it and calls World.update_world() instead.
"""
pass
class Job:
"""
Keeps track of how much work a given worker should handle.
@ -78,7 +71,7 @@ class World:
The frame or "world" which holds all workers (including the local machine).
Args:
initial_payload: The original txt2img payload created by the user initiating the generation request on master.
p: The original processing state object created by the user initiating the generation request on master.
verify_remotes (bool): Whether to validate remote worker certificates.
"""
@ -86,15 +79,14 @@ class World:
config_path = shared.cmd_opts.distributed_config
old_config_path = worker_info_path = extension_path.joinpath('workers.json')
def __init__(self, initial_payload, verify_remotes: bool = True):
def __init__(self, verify_remotes: bool = True):
self.p = None
self.master_worker = Worker(master=True)
self.total_batch_size: int = 0
self._workers: List[Worker] = [self.master_worker]
self.jobs: List[Job] = []
self.job_timeout: int = 3 # seconds
self.initialized: bool = False
self.verify_remotes = verify_remotes
self.initial_payload = copy.copy(initial_payload)
self.thin_client_mode = False
self.enabled = True
self.is_dropdown_handler_injected = False
@ -109,32 +101,12 @@ class World:
def __repr__(self):
# short summary: the number of entries in self._workers (which __init__ seeds with the master worker)
return f"{len(self._workers)} workers"
def update_world(self, total_batch_size):
"""
Updates the world with information vital to handling the local generation request after
the world has already been initialized.
Stores the new total on the instance and then recomputes jobs via update_jobs().
Args:
total_batch_size (int): The total number of images requested by the local/master sdwui instance.
"""
# the total must be stored before update_jobs() runs
# NOTE(review): presumably update_jobs() reads self.total_batch_size (e.g. via default_batch_size) - confirm
self.total_batch_size = total_batch_size
self.update_jobs()
def initialize(self, total_batch_size):
"""
One-time setup; should be called before a world instance is used for anything.
Runs benchmark(), applies the batch size through update_world() and then marks the world as initialized.
Args:
total_batch_size (int): The total number of images requested, forwarded to update_world().
Raises:
WorldAlreadyInitialized: if this instance has already been initialized.
"""
# refuse a second initialization; callers are expected to use update_world() from then on
if self.initialized:
raise WorldAlreadyInitialized("This world instance was already initialized")
self.benchmark()
self.update_world(total_batch_size=total_batch_size)
# only flagged as initialized after benchmark() and update_world() have returned
self.initialized = True
def default_batch_size(self) -> int:
"""the number of images/total images requested that a worker would compute if conditions were perfect and
each worker generated at the same speed. assumes one batch only"""
# integer division: any remainder is dropped here
# NOTE(review): two return lines appear because this listing is a diff; presumably the first
# (self.total_batch_size) is the replaced line and the second (self.p.batch_size) the new one - confirm
return self.total_batch_size // self.size()
# NOTE(review): self.p is None until World.update(p) has been called, so this line would fail before then - confirm call order
return self.p.batch_size // self.size()
def size(self) -> int:
"""
@ -396,6 +368,18 @@ class World:
self.jobs.append(Job(worker=worker, batch_size=batch_size))
def update(self, p):
"""
Preps the world for another run.
On the first call the world is benchmarked and marked as initialized; every call stores p on the
instance and recomputes jobs via update_jobs().
Args:
p: the processing object of the current generation request; kept as self.p.
"""
# benchmark() only runs the first time; self.initialized guards against repeating it
if not self.initialized:
self.benchmark()
self.initialized = True
logger.debug("world initialized!")
else:
logger.debug("world was already initialized")
# NOTE(review): indentation is lost in this listing; presumably the two lines below run on both branches - confirm
# NOTE(review): p is stored by reference, whereas the removed __init__ kept copy.copy(initial_payload);
# later changes to p.batch_size by the caller will therefore be visible through self.p - confirm this is intended
self.p = p
self.update_jobs()
def get_workers(self):
filtered: List[Worker] = []
for worker in self._workers:
@ -431,7 +415,7 @@ class World:
logger.debug(f"worker '{job.worker.label}' would stall the image gallery by ~{lag:.2f}s\n")
job.complementary = True
if deferred_images + images_checked + payload['batch_size'] > self.total_batch_size:
if deferred_images + images_checked + payload['batch_size'] > self.p.batch_size:
logger.debug(f"would go over actual requested size")
else:
deferred_images += payload['batch_size']
@ -474,9 +458,9 @@ class World:
#######################
# when total number of requested images was not cleanly divisible by world size then we tack the remainder on
remainder_images = self.total_batch_size - self.get_current_output_size()
remainder_images = self.p.batch_size - self.get_current_output_size()
if remainder_images >= 1:
logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed")
realtime_jobs = self.realtime_jobs()
realtime_jobs.sort(key=lambda x: x.batch_size)
@ -555,9 +539,9 @@ class World:
iterations = payload['n_iter']
num_returning = self.get_current_output_size()
num_complementary = num_returning - self.total_batch_size
num_complementary = num_returning - self.p.batch_size
distro_summary = "Job distribution:\n"
distro_summary += f"{self.total_batch_size} * {iterations} iteration(s)"
distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)"
if num_complementary > 0:
distro_summary += f" + {num_complementary} complementary"
distro_summary += f": {num_returning} images total\n"
@ -577,7 +561,7 @@ class World:
processed.infotexts = []
processed.prompt = None
self.initial_payload.scripts.postprocess(p, processed)
self.p.scripts.postprocess(p, processed)
return processed
processing.process_images_inner = process_images_inner_bypass