diff --git a/scripts/distributed.py b/scripts/distributed.py index 492d4d4..524e73b 100644 --- a/scripts/distributed.py +++ b/scripts/distributed.py @@ -61,7 +61,7 @@ class DistributedScript(scripts.Script): return scripts.AlwaysVisible def ui(self, is_img2img): - extension_ui = UI(world=self.world) + extension_ui = UI(world=self.world, is_img2img=is_img2img) # root, api_exposed = extension_ui.create_ui() components = extension_ui.create_ui() @@ -218,8 +218,12 @@ class DistributedScript(scripts.Script): # p's type is # "modules.processing.StableDiffusionProcessing*" def before_process(self, p, *args): - if not self.world.enabled: - logger.debug("extension is disabled") + is_img2img = getattr(p, 'init_images', False) + if is_img2img and self.world.enabled_i2i is False: + logger.debug("extension is disabled for i2i") + return + elif not is_img2img and self.world.enabled is False: + logger.debug("extension is disabled for t2i") return self.world.update(p) @@ -352,7 +356,10 @@ class DistributedScript(scripts.Script): return def postprocess(self, p, processed, *args): - if not self.world.enabled: + is_img2img = getattr(p, 'init_images', False) + if is_img2img and self.world.enabled_i2i is False: + return + elif not is_img2img and self.world.enabled is False: return if self.master_start is not None: diff --git a/scripts/spartan/pmodels.py b/scripts/spartan/pmodels.py index e349dc8..10a5e0a 100644 --- a/scripts/spartan/pmodels.py +++ b/scripts/spartan/pmodels.py @@ -41,5 +41,6 @@ class ConfigModel(BaseModel): ) job_timeout: Optional[int] = Field(default=3) enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True) + enabled_i2i: Optional[bool] = Field(description="Same as above but for image to image", default=True) complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True) step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False) diff --git a/scripts/spartan/ui.py b/scripts/spartan/ui.py index 2535107..82ff504 100644 --- a/scripts/spartan/ui.py +++ b/scripts/spartan/ui.py @@ -17,9 +17,10 @@ worker_select_dropdown = None class UI: """extension user interface related things""" - def __init__(self, world): + def __init__(self, world, is_img2img): self.world = world self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange + self.is_img2img = is_img2img # handlers @staticmethod @@ -186,11 +187,15 @@ class UI: self.world.save_config() def main_toggle_btn(self): - self.world.enabled = not self.world.enabled + if self.is_img2img: + self.world.enabled_i2i = not self.world.enabled_i2i + else: + self.world.enabled = not self.world.enabled + self.world.save_config() # restore vanilla sdwui handler for model dropdown if extension is disabled or inject if otherwise - if not self.world.enabled: + if not self.world.enabled and not self.world.enabled_i2i: model_dropdown = opts.data_labels.get('sd_model_checkpoint') if self.original_model_dropdown_handler is not None: model_dropdown.onchange = self.original_model_dropdown_handler @@ -209,9 +214,12 @@ class UI: def create_ui(self): """creates the extension UI and returns relevant components""" components = [] + elem_id = 'enabled' + if self.is_img2img: + elem_id += '_i2i' with gradio.Blocks(variant='compact'): # Group() and Box() remove spacing - with InputAccordion(label='Distributed', open=False, value=self.world.config().get('enabled', False), elem_id='enable') as main_toggle: + with InputAccordion(label='Distributed', open=False, value=self.world.config().get(elem_id), elem_id=elem_id) as main_toggle: main_toggle.input(self.main_toggle_btn) setattr(main_toggle.accordion, 'do_not_save_to_config', True) # InputAccordion is really a CheckBox components.append(main_toggle) diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index bae3605..73cc742 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -91,6 +91,7 @@ class World: self.verify_remotes = verify_remotes self.thin_client_mode = False self.enabled = True + self.enabled_i2i = True self.is_dropdown_handler_injected = False self.complement_production = True self.step_scaling = False @@ -671,6 +672,7 @@ class World: sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload) self.job_timeout = config.job_timeout self.enabled = config.enabled + self.enabled_i2i = config.enabled_i2i self.complement_production = config.complement_production self.step_scaling = config.step_scaling @@ -686,6 +688,7 @@ class World: benchmark_payload=sh.benchmark_payload, job_timeout=self.job_timeout, enabled=self.enabled, + enabled_i2i=self.enabled_i2i, complement_production=self.complement_production, step_scaling=self.step_scaling )