diff --git a/CHANGELOG.md b/CHANGELOG.md index da76536..bd16d48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,27 @@ # Change Log Formatting: [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), [Semantic Versioning](https://semver.org/spec/v2.0.0.html) +## [2.2.0] - 2024-5-11 + +### Added +- Toggle for allowing automatic step scaling which can increase overall utilization + +### Changed +- Adding workers which have the same socket definition as master will no longer be allowed and an error will show #28 +- Workers in an invalid state should no longer be benchmarked +- The worker port under worker config will now default to 7860 to prevent mishaps +- Config should once again only be loaded once per session startup +- A warning will be shown when trying to use the user script button but no script exists + +### Fixed +- Thin-client mode +- Some problems with sdwui forge branch +- Certificate verification setting sometimes not saving +- Master being assigned no work stopping generation (same problem as thin-client) + +### Removed +- Adding workers using deprecated cmdline argument + ## [2.1.0] - 2024-3-03 ### Added diff --git a/scripts/distributed.py b/scripts/distributed.py index 5733edf..e3d88f8 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -10,7 +10,7 @@ import re import signal import sys import time -from threading import Thread, current_thread +from threading import Thread from typing import List import gradio import urllib3 @@ -18,28 +18,24 @@ from PIL import Image from modules import processing from modules import scripts from modules.images import save_image -from modules.processing import fix_seed, Processed +from modules.processing import fix_seed from modules.shared import opts, cmd_opts from modules.shared import state as webui_state from scripts.spartan.control_net import pack_control_net from scripts.spartan.shared import logger from scripts.spartan.ui import UI -from scripts.spartan.world import World, WorldAlreadyInitialized +from scripts.spartan.world import World old_sigint_handler = signal.getsignal(signal.SIGINT) old_sigterm_handler = signal.getsignal(signal.SIGTERM) -# TODO implement advertisement of some sort in sdwui api to allow extension to automatically discover workers? # noinspection PyMissingOrEmptyDocstring -class Script(scripts.Script): +class DistributedScript(scripts.Script): + # global old_sigterm_handler, old_sigterm_handler worker_threads: List[Thread] = [] # Whether to verify worker certificates. Can be useful if your remotes are self-signed. verify_remotes = not cmd_opts.distributed_skip_verify_remotes - - is_img2img = True - is_txt2img = True - alwayson = True master_start = None runs_since_init = 0 name = "distributed" @@ -50,20 +46,15 @@ class Script(scripts.Script): urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) # build world - world = World(initial_payload=None, verify_remotes=verify_remotes) - # add workers to the world + world = World(verify_remotes=verify_remotes) world.load_config() - if cmd_opts.distributed_remotes is not None and len(cmd_opts.distributed_remotes) > 0: - logger.warning(f"--distributed-remotes is deprecated and may be removed in the future\n" - f"gui/external modification of {world.config_path} will be prioritized going forward") - - for worker in cmd_opts.distributed_remotes: - world.add_worker(uuid=worker[0], address=worker[1], port=worker[2], tls=False) - world.save_config() - # do an early check to see which workers are online logger.info("doing initial ping sweep to see which workers are reachable") world.ping_remotes(indiscriminate=True) + # constructed for both txt2img and img2img + def __init__(self): + super().__init__() + def title(self): return "Distribute" @@ -71,21 +62,18 @@ class Script(scripts.Script): return scripts.AlwaysVisible def ui(self, is_img2img): - self.world.load_config() - extension_ui = UI(script=Script, world=Script.world) + extension_ui = UI(world=self.world) # root, api_exposed = extension_ui.create_ui() components = extension_ui.create_ui() # The first injection of handler for the models dropdown(sd_model_checkpoint) which is often present # in the quick-settings bar of a user. Helps ensure model swaps propagate to all nodes ASAP. - Script.world.inject_model_dropdown_handler() + self.world.inject_model_dropdown_handler() # return some components that should be exposed to the api return components - @staticmethod - def add_to_gallery(processed, p): + def add_to_gallery(self, processed, p): """adds generated images to the image gallery after waiting for all workers to finish""" - webui_state.textinfo = "Distributed - injecting images" def processed_inject_image(image, info_index, save_path_override=None, grid=False, response=None): image_params: json = response['parameters'] @@ -129,10 +117,10 @@ class Script(scripts.Script): if p.n_iter > 1: # if splitting by batch count num_remote_images *= p.n_iter - 1 - logger.debug(f"image {true_image_pos + 1}/{Script.world.total_batch_size * p.n_iter}, " + logger.debug(f"image {true_image_pos + 1}/{self.world.p.batch_size * p.n_iter}, " f"info-index: {info_index}") - if Script.world.thin_client_mode: + if self.world.thin_client_mode: p.all_negative_prompts = processed.all_negative_prompts try: @@ -157,19 +145,21 @@ class Script(scripts.Script): ) # get master ipm by estimating based on worker speed - master_elapsed = time.time() - Script.master_start + master_elapsed = time.time() - self.master_start logger.debug(f"Took master {master_elapsed:.2f}s") # wait for response from all workers - for thread in Script.worker_threads: + webui_state.textinfo = "Distributed - receiving results" + for thread in self.worker_threads: logger.debug(f"waiting for worker thread '{thread.name}'") thread.join() - Script.worker_threads.clear() + self.worker_threads.clear() logger.debug("all worker request threads returned") + webui_state.textinfo = "Distributed - injecting images" # some worker which we know has a good response that we can use for generating the grid donor_worker = None - for job in Script.world.jobs: + for job in self.world.jobs: if job.batch_size < 1 or job.worker.master: continue @@ -177,7 +167,7 @@ class Script(scripts.Script): images: json = job.worker.response["images"] # if we for some reason get more than we asked for if (job.batch_size * p.n_iter) < len(images): - logger.debug(f"Requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}") + logger.debug(f"requested {job.batch_size} image(s) from '{job.worker.label}', got {len(images)}") if donor_worker is None: donor_worker = job.worker @@ -208,7 +198,7 @@ class Script(scripts.Script): # generate and inject grid if opts.return_grid: - grid = processing.images.image_grid(processed.images, len(processed.images)) + grid = images.image_grid(processed.images, len(processed.images)) processed_inject_image( image=grid, info_index=0, @@ -218,35 +208,22 @@ class Script(scripts.Script): ) # cleanup after we're doing using all the responses - for worker in Script.world.get_workers(): + for worker in self.world.get_workers(): worker.response = None p.batch_size = len(processed.images) return - @staticmethod - def initialize(initial_payload): - # get default batch size - try: - batch_size = initial_payload.batch_size - except AttributeError: - batch_size = 1 - - try: - Script.world.initialize(batch_size) - logger.debug(f"World initialized!") - except WorldAlreadyInitialized: - Script.world.update_world(total_batch_size=batch_size) - # p's type is - # "modules.processing.StableDiffusionProcessingTxt2Img" + # "modules.processing.StableDiffusionProcessing*" def before_process(self, p, *args): if not self.world.enabled: logger.debug("extension is disabled") return + self.world.update(p) - current_thread().name = "distributed_main" - Script.initialize(initial_payload=p) + # save original process_images_inner function for later if we monkeypatch it + self.original_process_images_inner = processing.process_images_inner # strip scripts that aren't yet supported and warn user packed_script_args: List[dict] = [] # list of api formatted per-script argument objects @@ -261,8 +238,9 @@ class Script(scripts.Script): # grab all controlnet units cn_units = [] cn_args = p.script_args[script.args_from:script.args_to] + for cn_arg in cn_args: - if type(cn_arg).__name__ == "UiControlNetUnit": + if "ControlNetUnit" in type(cn_arg).__name__: cn_units.append(cn_arg) logger.debug(f"Detected {len(cn_units)} controlnet unit(s)") @@ -281,7 +259,7 @@ class Script(scripts.Script): # encapsulating the request object within a txt2imgreq object is deprecated and no longer works # see test/basic_features/txt2img_test.py for an example payload = copy.copy(p.__dict__) - payload['batch_size'] = Script.world.default_batch_size() + payload['batch_size'] = self.world.default_batch_size() payload['scripts'] = None try: del payload['script_args'] @@ -300,8 +278,9 @@ class Script(scripts.Script): # TODO api for some reason returns 200 even if something failed to be set. # for now we may have to make redundant GET requests to check if actually successful... # https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8146 + name = re.sub(r'\s?\[[^]]*]$', '', opts.data["sd_model_checkpoint"]) - vae = opts.data["sd_vae"] + vae = opts.data.get('sd_vae') option_payload = { "sd_model_checkpoint": name, "sd_vae": vae @@ -309,11 +288,11 @@ class Script(scripts.Script): # start generating images assigned to remote machines sync = False # should only really need to sync once per job - Script.world.optimize_jobs(payload) # optimize work assignment before dispatching + self.world.optimize_jobs(payload) # optimize work assignment before dispatching started_jobs = [] # check if anything even needs to be done - if len(Script.world.jobs) == 1 and Script.world.jobs[0].worker.master: + if len(self.world.jobs) == 1 and self.world.jobs[0].worker.master: if payload['batch_size'] >= 2: msg = f"all remote workers are offline or unreachable" @@ -324,7 +303,7 @@ class Script(scripts.Script): return - for job in Script.world.jobs: + for job in self.world.jobs: payload_temp = copy.copy(payload) del payload_temp['scripts_value'] payload_temp = copy.deepcopy(payload_temp) @@ -339,11 +318,14 @@ class Script(scripts.Script): prior_images += j.batch_size * p.n_iter payload_temp['batch_size'] = job.batch_size + if job.step_override is not None: + payload_temp['steps'] = job.step_override payload_temp['subseed'] += prior_images payload_temp['seed'] += prior_images if payload_temp['subseed_strength'] == 0 else 0 logger.debug( f"'{job.worker.label}' job's given starting seed is " - f"{payload_temp['seed']} with {prior_images} coming before it") + f"{payload_temp['seed']} with {prior_images} coming before it" + ) if job.worker.loaded_model != name or job.worker.loaded_vae != vae: sync = True @@ -354,42 +336,34 @@ class Script(scripts.Script): name=f"{job.worker.label}_request") t.start() - Script.worker_threads.append(t) + self.worker_threads.append(t) started_jobs.append(job) # if master batch size was changed again due to optimization change it to the updated value if not self.world.thin_client_mode: - p.batch_size = Script.world.master_job().batch_size - Script.master_start = time.time() + p.batch_size = self.world.master_job().batch_size + self.master_start = time.time() # generate images assigned to local machine p.do_not_save_grid = True # don't generate grid from master as we are doing this later. - if Script.world.thin_client_mode: - p.batch_size = 0 - processed = Processed(p=p, images_list=[]) - processed.all_prompts = [] - processed.all_seeds = [] - processed.all_subseeds = [] - processed.all_negative_prompts = [] - processed.infotexts = [] - processed.prompt = None - - Script.runs_since_init += 1 + self.runs_since_init += 1 return - @staticmethod - def postprocess(p, processed, *args): - if not Script.world.enabled: + def postprocess(self, p, processed, *args): + if not self.world.enabled: return - if len(processed.images) >= 1 and Script.master_start is not None: - Script.add_to_gallery(p=p, processed=processed) + if self.master_start is not None: + self.add_to_gallery(p=p, processed=processed) + + # restore process_images_inner if it was monkey-patched + processing.process_images_inner = self.original_process_images_inner @staticmethod def signal_handler(sig, frame): logger.debug("handling interrupt signal") # do cleanup - Script.world.save_config() + DistributedScript.world.save_config() if sig == signal.SIGINT: if callable(old_sigint_handler): diff --git a/scripts/spartan/control_net.py b/scripts/spartan/control_net.py index 18a240e..dc55cd9 100644 --- a/scripts/spartan/control_net.py +++ b/scripts/spartan/control_net.py @@ -2,6 +2,16 @@ import copy from PIL import Image from modules.api.api import encode_pil_to_base64 from scripts.spartan.shared import logger +import numpy as np +import json + + +def np_to_b64(image: np.ndarray): + pil = Image.fromarray(image) + image_b64 = str(encode_pil_to_base64(pil), 'utf-8') + image_b64 = 'data:image/png;base64,' + image_b64 + + return image_b64 def pack_control_net(cn_units) -> dict: @@ -17,28 +27,46 @@ def pack_control_net(cn_units) -> dict: cn_args = controlnet['controlnet']['args'] for i in range(0, len(cn_units)): - # copy control net unit to payload - cn_args.append(copy.copy(cn_units[i].__dict__)) + if cn_units[i].enabled: + cn_args.append(copy.deepcopy(cn_units[i].__dict__)) + else: + logger.debug(f"controlnet unit {i} is not enabled (ignoring)") + + for i in range(0, len(cn_args)): unit = cn_args[i] - # if unit isn't enabled then don't bother including - if not unit['enabled']: - del unit['input_mode'] - del unit['image'] - logger.debug(f"Controlnet unit {i} is not enabled. Ignoring") - continue - # serialize image - if unit['image'] is not None: - image = unit['image']['image'] - # mask = unit['image']['mask'] - pil = Image.fromarray(image) - image_b64 = encode_pil_to_base64(pil) - image_b64 = str(image_b64, 'utf-8') - unit['input_image'] = image_b64 + image_pair = unit.get('image') + if image_pair is not None: + image_b64 = np_to_b64(image_pair['image']) + unit['input_image'] = image_b64 # mikubill + unit['image'] = image_b64 # forge + if np.all(image_pair['mask'] == 0): + # stand-alone mask from second gradio component + standalone_mask = unit.get('mask_image') + if standalone_mask is not None: + logger.debug(f"found stand-alone mask for controlnet unit {i}") + mask_b64 = np_to_b64(unit['mask_image']['mask']) + unit['mask'] = mask_b64 # mikubill + unit['mask_image'] = mask_b64 # forge + + else: + # mask from singular gradio image component + logger.debug(f"found mask for controlnet unit {i}") + mask_b64 = np_to_b64(image_pair['mask']) + unit['mask'] = mask_b64 # mikubill + unit['mask_image'] = mask_b64 # forge + + # avoid returning duplicate detection maps since master should return the same one + unit['save_detected_map'] = False # remove anything unserializable del unit['input_mode'] - del unit['image'] + + try: + json.dumps(controlnet) + except Exception as e: + logger.error(f"failed to serialize controlnet\nfirst unit:\n{controlnet['controlnet']['args'][0]}") + return {} return controlnet diff --git a/scripts/spartan/pmodels.py b/scripts/spartan/pmodels.py index 7e2bf54..e349dc8 100644 --- a/scripts/spartan/pmodels.py +++ b/scripts/spartan/pmodels.py @@ -42,3 +42,4 @@ class ConfigModel(BaseModel): job_timeout: Optional[int] = Field(default=3) enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True) complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True) + step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False) diff --git a/scripts/spartan/shared.py b/scripts/spartan/shared.py index 26e7a2b..5400464 100644 --- a/scripts/spartan/shared.py +++ b/scripts/spartan/shared.py @@ -61,6 +61,7 @@ logger.addHandler(gui_handler) # end logging warmup_samples = 2 # number of samples to do before recording a valid benchmark sample +samples = 3 # number of times to benchmark worker after warmup benchmarks are completed class BenchmarkPayload(BaseModel): diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index 2d1cfbe..cd0b87a 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -14,8 +14,7 @@ worker_select_dropdown = None class UI: """extension user interface related things""" - def __init__(self, script, world): - self.script = script + def __init__(self, world): self.world = world self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange @@ -25,10 +24,17 @@ class UI: """executes a script placed by the user at /user/sync*""" user_scripts = Path(os.path.abspath(__file__)).parent.parent.joinpath('user') + user_script = None for file in user_scripts.iterdir(): logger.debug(f"found possible script {file.name}") if file.is_file() and file.name.startswith('sync'): user_script = file + if user_script is None: + logger.error( + "couldn't find user script\n" + "script must be placed under /user/ and filename must begin with sync" + ) + return False suffix = user_script.suffix[1:] @@ -74,15 +80,14 @@ class UI: return 'No active jobs!', worker_status, logs - def save_btn(self, thin_client_mode, job_timeout, complement_production): + def save_btn(self, thin_client_mode, job_timeout, complement_production, step_scaling): """updates the options visible on the settings tab""" self.world.thin_client_mode = thin_client_mode - logger.debug(f"thin client mode is now {thin_client_mode}") job_timeout = int(job_timeout) self.world.job_timeout = job_timeout - logger.debug(f"job timeout is now {job_timeout} seconds") self.world.complement_production = complement_production + self.world.step_scaling = step_scaling self.world.save_config() def save_worker_btn(self, label, address, port, tls, disabled): @@ -102,7 +107,7 @@ class UI: self.world.add_worker( label=label, address=address, - port=port, + port=port if len(port) > 0 else 7860, tls=tls, state=state ) @@ -208,7 +213,6 @@ class UI: interactive=True ) main_toggle.input(self.main_toggle_btn) - setattr(main_toggle, 'do_not_save_to_config', True) # ui_loadsave.py apply_field() components.append(main_toggle) with gradio.Tab('Status') as status_tab: @@ -240,7 +244,7 @@ class UI: reload_config_btn = gradio.Button(value='πŸ“œ Reload config') reload_config_btn.click(self.world.load_config) - redo_benchmarks_btn = gradio.Button(value='πŸ“Š Redo benchmarks', variant='stop') + redo_benchmarks_btn = gradio.Button(value='πŸ“Š Redo benchmarks') redo_benchmarks_btn.click(self.benchmark_btn, inputs=[], outputs=[]) run_usr_btn = gradio.Button(value='βš™οΈ Run script') @@ -252,10 +256,10 @@ class UI: reconnect_lost_workers_btn = gradio.Button(value='πŸ”Œ Reconnect workers') reconnect_lost_workers_btn.click(self.world.ping_remotes) - interrupt_all_btn = gradio.Button(value='⏸️ Interrupt all', variant='stop') + interrupt_all_btn = gradio.Button(value='⏸️ Interrupt all') interrupt_all_btn.click(self.world.interrupt_remotes) - restart_workers_btn = gradio.Button(value="πŸ” Restart All", variant='stop') + restart_workers_btn = gradio.Button(value="πŸ” Restart All") restart_workers_btn.click( _js="confirm_restart_workers", fn=lambda confirmed: self.world.restart_all() if confirmed else None, @@ -305,7 +309,7 @@ class UI: # API authentication worker_api_auth_cbx = gradio.Checkbox(label='API Authentication') worker_user_field = gradio.Textbox(label='Username') - worker_password_field = gradio.Textbox(label='Password') + worker_password_field = gradio.Textbox(label='Password', type='password') update_credentials_btn = gradio.Button(value='Update API Credentials') update_credentials_btn.click(self.update_credentials_btn, inputs=[ worker_api_auth_cbx, @@ -346,25 +350,33 @@ class UI: with gradio.Tab('Settings'): thin_client_cbx = gradio.Checkbox( - label='Thin-client mode (experimental)', - info="(BROKEN) Only generate images using remote workers. There will be no previews when enabled.", + label='Thin-client mode', + info="Only generate images remotely (no image previews yet)", value=self.world.thin_client_mode ) job_timeout = gradio.Number( label='Job timeout', value=self.world.job_timeout, info="Seconds until a worker is considered too slow to be assigned an" - " equal share of the total request. Longer than 2 seconds is recommended." + " equal share of the total request. Longer than 2 seconds is recommended" ) complement_production = gradio.Checkbox( label='Complement production', - info='Prevents under-utilization of hardware by requesting additional images', + info='Prevents under-utilization by requesting additional images when possible', value=self.world.complement_production ) + # reduces image quality the more the sample-count must be reduced + # good for mixed setups where each worker may not be around the same speed + step_scaling = gradio.Checkbox( + label='Step scaling', + info='Prevents under-utilization via sample reduction in order to meet time constraints', + value=self.world.step_scaling + ) + save_btn = gradio.Button(value='Update') - save_btn.click(fn=self.save_btn, inputs=[thin_client_cbx, job_timeout, complement_production]) - components += [thin_client_cbx, job_timeout, complement_production, save_btn] + save_btn.click(fn=self.save_btn, inputs=[thin_client_cbx, job_timeout, complement_production, step_scaling]) + components += [thin_client_cbx, job_timeout, complement_production, step_scaling, save_btn] with gradio.Tab('Help'): gradio.Markdown( @@ -374,4 +386,7 @@ class UI: """ ) + # prevent wui from overriding any values + for component in components: + setattr(component, 'do_not_save_to_config', True) # ui_loadsave.py apply_field() return components diff --git a/scripts/spartan/worker.py b/scripts/spartan/worker.py index 85316f4..39cf884 100644 --- a/scripts/spartan/worker.py +++ b/scripts/spartan/worker.py @@ -156,6 +156,11 @@ class Worker: def __repr__(self): return f"'{self.label}'@{self.address}:{self.port}, speed: {self.avg_ipm} ipm, state: {self.state}" + def __eq__(self, other): + if isinstance(other, Worker) and other.label == self.label: + return True + return False + @property def model(self) -> Worker_Model: return Worker_Model(**self.__dict__) @@ -189,14 +194,14 @@ class Worker: protocol = 'http' if not self.tls else 'https' return f"{protocol}://{self.__str__()}/sdapi/v1/{route}" - def batch_eta_hr(self, payload: dict) -> float: + def eta_hr(self, payload: dict) -> float: """ takes a normal payload and returns the eta of a pseudo payload which mirrors the hr-fix parameters This returns the eta of how long it would take to run hr-fix on the original image """ pseudo_payload = copy.copy(payload) - pseudo_payload['enable_hr'] = False # prevent overflow in self.batch_eta + pseudo_payload['enable_hr'] = False # prevent overflow in self.eta res_ratio = pseudo_payload['hr_scale'] original_steps = pseudo_payload['steps'] second_pass_steps = pseudo_payload['hr_second_pass_steps'] @@ -212,12 +217,11 @@ class Worker: pseudo_payload['width'] = pseudo_width pseudo_payload['height'] = pseudo_height - eta = self.batch_eta(payload=pseudo_payload, quiet=True) - return eta + return self.eta(payload=pseudo_payload, quiet=True) - def batch_eta(self, payload: dict, quiet: bool = False, batch_size: int = None) -> float: + def eta(self, payload: dict, quiet: bool = False, batch_size: int = None, samples: int = None) -> float: """ - estimate how long it will take to generate images on a worker in seconds + estimate how long it will take to generate image(s) on a worker in seconds Args: payload: Sdwui api formatted payload @@ -225,7 +229,7 @@ class Worker: batch_size: Overrides the batch_size parameter of the payload """ - steps = payload['steps'] + steps = payload['steps'] if samples is None else samples num_images = payload['batch_size'] if batch_size is None else batch_size # if worker has not yet been benchmarked then @@ -237,7 +241,7 @@ class Worker: # show effect of high-res fix hr = payload.get('enable_hr', False) if hr: - eta += self.batch_eta_hr(payload=payload) + eta += self.eta_hr(payload=payload) # show effect of image size real_pix_to_benched = (payload['width'] * payload['height']) \ @@ -331,7 +335,7 @@ class Worker: self.load_options(model=option_payload['sd_model_checkpoint'], vae=option_payload['sd_vae']) if self.benchmarked: - eta = self.batch_eta(payload=payload) * payload['n_iter'] + eta = self.eta(payload=payload) * payload['n_iter'] logger.debug(f"worker '{self.label}' predicts it will take {eta:.3f}s to generate " f"{payload['batch_size'] * payload['n_iter']} image(s) " f"at a speed of {self.avg_ipm:.2f} ipm\n") @@ -471,7 +475,7 @@ class Worker: self.response_time = time.time() - start variance = ((eta - self.response_time) / self.response_time) * 100 - logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%.\n" + logger.debug(f"Worker '{self.label}'s ETA was off by {variance:.2f}%\n" f"Predicted {eta:.2f}s. Actual: {self.response_time:.2f}s\n") # if the variance is greater than 500% then we ignore it to prevent variation inflation @@ -501,20 +505,20 @@ class Worker: self.jobs_requested += 1 return - def benchmark(self) -> float: + def benchmark(self, sample_function: callable = None) -> float: """ given a worker, run a small benchmark and return its performance in images/minute makes standard request(s) of 512x512 images and averages them to get the result """ t: Thread - samples = 2 # number of times to benchmark the remote / accuracy - if self.state == State.DISABLED or self.state == State.UNAVAILABLE: + if self.state in (State.DISABLED, State.UNAVAILABLE): logger.debug(f"worker '{self.label}' is unavailable or disabled, refusing to benchmark") return 0 - if self.master is True: + if self.master and sample_function is None: + logger.critical(f"no function provided for benchmarking master") return -1 def ipm(seconds: float) -> float: @@ -533,19 +537,24 @@ class Worker: results: List[float] = [] # it used to be lower for the first couple of generations # this was due to something torch does at startup according to auto and is now done at sdwui startup - self.state = State.WORKING - for i in range(0, samples + warmup_samples): # run some extra times so that the remote can "warm up" + for i in range(0, sh.samples + warmup_samples): # run some extra times so that the remote can "warm up" if self.state == State.UNAVAILABLE: self.response = None return 0 - t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,), - name=f"{self.label}_benchmark_request") try: # if the worker is unreachable/offline then handle that here - t.start() - start = time.time() - t.join() - elapsed = time.time() - start + elapsed = None + + if not callable(sample_function): + start = time.time() + t = Thread(target=self.request, args=(dict(sh.benchmark_payload), None, False,), + name=f"{self.label}_benchmark_request") + t.start() + t.join() + elapsed = time.time() - start + else: + elapsed = sample_function() + sample_ipm = ipm(elapsed) except InvalidWorkerResponse as e: raise e @@ -558,15 +567,13 @@ class Worker: logger.debug(f"{self.label} finished warming up\n") # average the sample results for accuracy - ipm_sum = 0 - for ipm_result in results: - ipm_sum += ipm_result - avg_ipm_result = ipm_sum / samples + avg_ipm_result = sum(results) / sh.samples logger.debug(f"Worker '{self.label}' average ipm: {avg_ipm_result:.2f}") self.avg_ipm = avg_ipm_result self.response = None self.benchmarked = True + self.eta_percent_error = [] # likely inaccurate after rebenching self.state = State.IDLE return avg_ipm_result @@ -610,6 +617,10 @@ class Worker: except requests.exceptions.ConnectionError as e: logger.error(e) return False + except requests.ReadTimeout as e: + logger.critical(f"worker '{self.label}' is online but not responding (crashed?)") + logger.error(e) + return False def mark_unreachable(self): if self.state == State.DISABLED: @@ -677,6 +688,10 @@ class Worker: if vae is not None: self.loaded_vae = vae + self.response = response + + return self + def restart(self) -> bool: err_msg = f"could not restart worker '{self.label}'" success_msg = f"worker '{self.label}' is restarting" diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 0646a40..f555ca1 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -4,7 +4,7 @@ This module facilitates the creation of a stable-diffusion-webui centered distri World: The main class which should be instantiated in order to create a new sdwui distributed system. """ - +import concurrent.futures import copy import json import os @@ -16,8 +16,10 @@ import modules.shared as shared from modules.processing import process_images, StableDiffusionProcessingTxt2Img from . import shared as sh from .pmodels import ConfigModel, Benchmark_Payload -from .shared import logger, warmup_samples, extension_path +from .shared import logger, extension_path from .worker import Worker, State +from modules.call_queue import wrap_queued_call +from modules import processing class NotBenchmarked(Exception): @@ -28,13 +30,6 @@ class NotBenchmarked(Exception): pass -class WorldAlreadyInitialized(Exception): - """ - Raised when attempting to initialize the World when it has already been initialized. - """ - pass - - class Job: """ Keeps track of how much work a given worker should handle. @@ -48,6 +43,7 @@ class Job: self.worker: Worker = worker self.batch_size: int = batch_size self.complementary: bool = False + self.step_override = None def __str__(self): prefix = '' @@ -75,7 +71,7 @@ class World: The frame or "world" which holds all workers (including the local machine). Args: - initial_payload: The original txt2img payload created by the user initiating the generation request on master. + p: The original processing state object created by the user initiating the generation request on master. verify_remotes (bool): Whether to validate remote worker certificates. """ @@ -83,19 +79,19 @@ class World: config_path = shared.cmd_opts.distributed_config old_config_path = worker_info_path = extension_path.joinpath('workers.json') - def __init__(self, initial_payload, verify_remotes: bool = True): + def __init__(self, verify_remotes: bool = True): + self.p = None self.master_worker = Worker(master=True) - self.total_batch_size: int = 0 self._workers: List[Worker] = [self.master_worker] self.jobs: List[Job] = [] self.job_timeout: int = 3 # seconds self.initialized: bool = False self.verify_remotes = verify_remotes - self.initial_payload = copy.copy(initial_payload) self.thin_client_mode = False self.enabled = True self.is_dropdown_handler_injected = False self.complement_production = True + self.step_scaling = False def __getitem__(self, label: str) -> Worker: for worker in self._workers: @@ -105,32 +101,12 @@ class World: def __repr__(self): return f"{len(self._workers)} workers" - def update_world(self, total_batch_size): - """ - Updates the world with information vital to handling the local generation request after - the world has already been initialized. - - Args: - total_batch_size (int): The total number of images requested by the local/master sdwui instance. - """ - - self.total_batch_size = total_batch_size - self.update_jobs() - - def initialize(self, total_batch_size): - """should be called before a world instance is used for anything""" - if self.initialized: - raise WorldAlreadyInitialized("This world instance was already initialized") - - self.benchmark() - self.update_world(total_batch_size=total_batch_size) - self.initialized = True def default_batch_size(self) -> int: """the amount of images/total images requested that a worker would compute if conditions were perfect and each worker generated at the same speed. assumes one batch only""" - return self.total_batch_size // self.size() + return self.p.batch_size // self.size() def size(self) -> int: """ @@ -169,6 +145,14 @@ class World: Worker: The worker object. """ + # protect against user trying to make cyclical setups and connections + is_master = kwargs.get('master') + if is_master is None or not is_master: + m = self.master() + if kwargs['address'] == m.address and kwargs['port'] == m.port: + logger.error(f"refusing to add worker {kwargs['label']} as its socket definition({m.address}:{m.port}) matches master") + return None + original = self[kwargs['label']] # if worker doesn't already exist then just make a new one if original is None: new = Worker(**kwargs) @@ -192,16 +176,33 @@ class World: if worker.master: continue - t = Thread(target=worker.interrupt, args=()) - t.start() + Thread(target=worker.interrupt, args=()).start() def refresh_checkpoints(self): for worker in self.get_workers(): if worker.master: continue - t = Thread(target=worker.refresh_checkpoints, args=()) - t.start() + Thread(target=worker.refresh_checkpoints, args=()).start() + + def sample_master(self) -> float: + # wrap our benchmark payload + master_bench_payload = StableDiffusionProcessingTxt2Img() + d = sh.benchmark_payload.dict() + for key in d: + setattr(master_bench_payload, key, d[key]) + + # Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to. + master_bench_payload.do_not_save_samples = True + # shared.state.begin(job='distributed_master_bench') + wrapped = (wrap_queued_call(process_images)) + start = time.time() + wrapped(master_bench_payload) + # wrap_gradio_gpu_call(process_images)(master_bench_payload) + # shared.state.end() + + return time.time() - start + def benchmark(self, rebenchmark: bool = False): """ @@ -209,14 +210,6 @@ class World: """ unbenched_workers = [] - benchmark_threads: List[Thread] = [] - sync_threads: List[Thread] = [] - - def benchmark_wrapped(worker): - bench_func = worker.benchmark if not worker.master else self.benchmark_master - worker.avg_ipm = bench_func() - worker.benchmarked = True - if rebenchmark: for worker in self._workers: worker.benchmarked = False @@ -231,28 +224,44 @@ class World: else: worker.benchmarked = True - # have every unbenched worker load the same weights before the benchmark - for worker in unbenched_workers: - if worker.master or worker.state == State.DISABLED: - continue + with concurrent.futures.ThreadPoolExecutor(thread_name_prefix='distributed_benchmark') as executor: + futures = [] - sync_thread = Thread(target=worker.load_options, args=(shared.opts.sd_model_checkpoint, shared.opts.sd_vae)) - sync_threads.append(sync_thread) - sync_thread.start() - for thread in sync_threads: - thread.join() + # have every unbenched worker load the same weights before the benchmark + for worker in unbenched_workers: + if worker.master or worker.state in (State.DISABLED, State.UNAVAILABLE): + continue - # benchmark those that haven't been - for worker in unbenched_workers: - t = Thread(target=benchmark_wrapped, args=(worker, ), name=f"{worker.label}_benchmark") - benchmark_threads.append(t) - t.start() - logger.info(f"benchmarking worker '{worker.label}'") + futures.append( + executor.submit(worker.load_options, model=shared.opts.sd_model_checkpoint, vae=shared.opts.sd_vae) + ) + for future in concurrent.futures.as_completed(futures): + worker = future.result() + if worker is None: + continue - # wait for all benchmarks to finish and update stats on newly benchmarked workers - if len(benchmark_threads) > 0: - for t in benchmark_threads: - t.join() + if worker.response.status_code != 200: + logger.error(f"refusing to benchmark worker '{worker.label}' as it failed to load the selected model '{shared.opts.sd_model_checkpoint}'\n" + f"*you may circumvent this by using the per-worker model override setting but this is not recommended as the same benchmark model should be used for all workers") + unbenched_workers = list(filter(lambda w: w != worker, unbenched_workers)) + futures.clear() + + # benchmark those that haven't been + for worker in unbenched_workers: + if worker.state in (State.DISABLED, State.UNAVAILABLE): + logger.debug(f"worker '{worker.label}' is {worker.state}, refusing to benchmark") + continue + + if worker.model_override is not None: + logger.warning(f"model override is enabled for worker '{worker.label}' which may result in poor optimization\n" + f"*all workers should be evaluated against the same model") + + chosen = worker.benchmark if not worker.master else worker.benchmark(sample_function=self.sample_master) + futures.append(executor.submit(chosen, worker)) + logger.info(f"benchmarking worker '{worker.label}'") + + # wait for all benchmarks to finish and update stats on newly benchmarked workers + concurrent.futures.wait(futures) logger.info("benchmarking finished") # save benchmark results to workers.json @@ -348,43 +357,11 @@ class World: if worker == fastest_worker: return 0 - lag = worker.batch_eta(payload=payload, quiet=True, batch_size=batch_size) - fastest_worker.batch_eta(payload=payload, quiet=True, batch_size=batch_size) + lag = worker.eta(payload=payload, quiet=True, batch_size=batch_size) - fastest_worker.eta(payload=payload, quiet=True, batch_size=batch_size) return lag - def benchmark_master(self) -> float: - """ - Benchmarks the local/master worker. - - Returns: - float: Local worker speed in ipm - """ - - # wrap our benchmark payload - master_bench_payload = StableDiffusionProcessingTxt2Img() - d = sh.benchmark_payload.dict() - for key in d: - setattr(master_bench_payload, key, d[key]) - - # Keeps from trying to save the images when we don't know the path. Also, there's not really any reason to. - master_bench_payload.do_not_save_samples = True - - # "warm up" due to initial generation lag - for _ in range(warmup_samples): - process_images(master_bench_payload) - - # get actual sample - start = time.time() - process_images(master_bench_payload) - elapsed = time.time() - start - - ipm = sh.benchmark_payload.batch_size / (elapsed / 60) - - logger.debug(f"Master benchmark took {elapsed:.2f}: {ipm:.2f} ipm") - self.master().benchmarked = True - return ipm - - def update_jobs(self): + def make_jobs(self): """creates initial jobs (before optimization) """ # clear jobs if this is not the first time running @@ -398,6 +375,19 @@ class World: worker.benchmark() self.jobs.append(Job(worker=worker, batch_size=batch_size)) + logger.debug(f"added job for worker {worker.label}") + + def update(self, p): + """preps world for another run""" + if not self.initialized: + self.benchmark() + self.initialized = True + logger.debug("world initialized!") + else: + logger.debug("world was already initialized") + + self.p = p + self.make_jobs() def get_workers(self): filtered: List[Worker] = [] @@ -434,7 +424,7 @@ class World: logger.debug(f"worker '{job.worker.label}' would stall the image gallery by ~{lag:.2f}s\n") job.complementary = True - if deferred_images + images_checked + payload['batch_size'] > self.total_batch_size: + if deferred_images + images_checked + payload['batch_size'] > self.p.batch_size: logger.debug(f"would go over actual requested size") else: deferred_images += payload['batch_size'] @@ -477,9 +467,9 @@ class World: ####################### # when total number of requested images was not cleanly divisible by world size then we tack the remainder on - remainder_images = self.total_batch_size - self.get_current_output_size() + remainder_images = self.p.batch_size - self.get_current_output_size() if remainder_images >= 1: - logger.debug(f"The requested number of images({self.total_batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed") + logger.debug(f"The requested number of images({self.p.batch_size}) was not cleanly divisible by the number of realtime nodes({len(self.realtime_jobs())}) resulting in {remainder_images} that will be redistributed") realtime_jobs = self.realtime_jobs() realtime_jobs.sort(key=lambda x: x.batch_size) @@ -521,16 +511,16 @@ class World: fastest_active = self.fastest_realtime_job().worker for j in self.jobs: if j.worker.label == fastest_active.label: - slack_time = fastest_active.batch_eta(payload=payload, batch_size=j.batch_size) + self.job_timeout + slack_time = fastest_active.eta(payload=payload, batch_size=j.batch_size) + self.job_timeout logger.debug(f"There's {slack_time:.2f}s of slack time available for worker '{job.worker.label}'") # see how long it would take to produce only 1 image on this complementary worker - secs_per_batch_image = job.worker.batch_eta(payload=payload, batch_size=1) + secs_per_batch_image = job.worker.eta(payload=payload, batch_size=1) num_images_compensate = int(slack_time / secs_per_batch_image) logger.debug( f"worker '{job.worker.label}':\n" f"{num_images_compensate} complementary image(s) = {slack_time:.2f}s slack" - f"/ {secs_per_batch_image:.2f}s per requested image" + f" Γ· {secs_per_batch_image:.2f}s per requested image" ) if not job.add_work(payload, batch_size=num_images_compensate): @@ -538,14 +528,29 @@ class World: request_img_size = payload['width'] * payload['height'] max_images = job.worker.pixel_cap // request_img_size job.add_work(payload, batch_size=max_images) + + # when not even a singular image can be squeezed out + # if step scaling is enabled, then find how many samples would be considered realtime and adjust + if num_images_compensate == 0 and self.step_scaling: + seconds_per_sample = job.worker.eta(payload=payload, batch_size=1, samples=1) + realtime_samples = slack_time // seconds_per_sample + logger.debug( + f"job for '{job.worker.label}' downscaled to {realtime_samples} samples to meet time constraints\n" + f"{realtime_samples:.0f} samples = {slack_time:.2f}s slack Γ· {seconds_per_sample:.2f}s/sample\n" + f" step reduction: {payload['steps']} -> {realtime_samples:.0f}" + ) + + job.add_work(payload=payload, batch_size=1) + job.step_override = realtime_samples + else: logger.debug("complementary image production is disabled") iterations = payload['n_iter'] num_returning = self.get_current_output_size() - num_complementary = num_returning - self.total_batch_size + num_complementary = num_returning - self.p.batch_size distro_summary = "Job distribution:\n" - distro_summary += f"{self.total_batch_size} * {iterations} iteration(s)" + distro_summary += f"{self.p.batch_size} * {iterations} iteration(s)" if num_complementary > 0: distro_summary += f" + {num_complementary} complementary" distro_summary += f": {num_returning} images total\n" @@ -553,6 +558,22 @@ class World: distro_summary += f"'{job.worker.label}' - {job.batch_size * iterations} image(s) @ {job.worker.avg_ipm:.2f} ipm\n" logger.info(distro_summary) + if self.thin_client_mode is True or self.master_job().batch_size == 0: + # save original process_images_inner for later so we can restore once we're done + logger.debug(f"bypassing local generation completely") + def process_images_inner_bypass(p) -> processing.Processed: + processed = processing.Processed(p, [], p.seed, info="") + processed.all_prompts = [] + processed.all_seeds = [] + processed.all_subseeds = [] + processed.all_negative_prompts = [] + processed.infotexts = [] + processed.prompt = None + + self.p.scripts.postprocess(p, processed) + return processed + processing.process_images_inner = process_images_inner_bypass + # delete any jobs that have no work last = len(self.jobs) - 1 while last > 0: @@ -631,6 +652,8 @@ class World: label = next(iter(w.keys())) fields = w[label].__dict__ fields['label'] = label + # TODO must be overridden everytime here or later converted to a config file variable at some point + fields['verify_remotes'] = self.verify_remotes self.add_worker(**fields) @@ -638,6 +661,7 @@ class World: self.job_timeout = config.job_timeout self.enabled = config.enabled self.complement_production = config.complement_production + self.step_scaling = config.step_scaling logger.debug("config loaded") @@ -651,7 +675,8 @@ class World: benchmark_payload=sh.benchmark_payload, job_timeout=self.job_timeout, enabled=self.enabled, - complement_production=self.complement_production + complement_production=self.complement_production, + step_scaling=self.step_scaling ) with open(self.config_path, 'w+') as config_file: @@ -679,22 +704,24 @@ class World: if worker.queried and worker.state == State.IDLE: # TODO worker.queried continue - # for now skip/remove scripts that are not "always on" since there is currently no way to run - # them at the same time as distributed supported_scripts = { 'txt2img': [], 'img2img': [] } - script_info = worker.session.get(url=worker.full_url('script-info')).json() - for key in script_info: - name = key.get('name', None) + response = worker.session.get(url=worker.full_url('script-info')) + if response.status_code == 200: + script_info = response.json() + for key in script_info: + name = key.get('name', None) - if name is not None: - is_alwayson = key.get('is_alwayson', False) - is_img2img = key.get('is_img2img', False) - if is_alwayson: - supported_scripts['img2img' if is_img2img else 'txt2img'].append(name) + if name is not None: + is_alwayson = key.get('is_alwayson', False) + is_img2img = key.get('is_img2img', False) + if is_alwayson: + supported_scripts['img2img' if is_img2img else 'txt2img'].append(name) + else: + logger.error(f"failed to query script-info for worker '{worker.label}': {response}") worker.supported_scripts = supported_scripts msg = f"worker '{worker.label}' is online"