allow t2i and i2i to be enabled/disabled separately from one another

master^2
papuSpartan 2024-09-22 15:22:32 -05:00
parent ce2d3254d4
commit 429adb2659
No known key found for this signature in database
GPG Key ID: CA376082283AF69A
4 changed files with 27 additions and 8 deletions

View File

@ -61,7 +61,7 @@ class DistributedScript(scripts.Script):
return scripts.AlwaysVisible return scripts.AlwaysVisible
def ui(self, is_img2img): def ui(self, is_img2img):
extension_ui = UI(world=self.world) extension_ui = UI(world=self.world, is_img2img=is_img2img)
# root, api_exposed = extension_ui.create_ui() # root, api_exposed = extension_ui.create_ui()
components = extension_ui.create_ui() components = extension_ui.create_ui()
@ -218,8 +218,12 @@ class DistributedScript(scripts.Script):
# p's type is # p's type is
# "modules.processing.StableDiffusionProcessing*" # "modules.processing.StableDiffusionProcessing*"
def before_process(self, p, *args): def before_process(self, p, *args):
if not self.world.enabled: is_img2img = getattr(p, 'init_images', False)
logger.debug("extension is disabled") if is_img2img and self.world.enabled_i2i is False:
logger.debug("extension is disabled for i2i")
return
elif not is_img2img and self.world.enabled is False:
logger.debug("extension is disabled for t2i")
return return
self.world.update(p) self.world.update(p)
@ -352,7 +356,10 @@ class DistributedScript(scripts.Script):
return return
def postprocess(self, p, processed, *args): def postprocess(self, p, processed, *args):
if not self.world.enabled: is_img2img = getattr(p, 'init_images', False)
if is_img2img and self.world.enabled_i2i is False:
return
elif not is_img2img and self.world.enabled is False:
return return
if self.master_start is not None: if self.master_start is not None:

View File

@ -41,5 +41,6 @@ class ConfigModel(BaseModel):
) )
job_timeout: Optional[int] = Field(default=3) job_timeout: Optional[int] = Field(default=3)
enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True) enabled: Optional[bool] = Field(description="Whether the extension as a whole should be active or disabled", default=True)
enabled_i2i: Optional[bool] = Field(description="Same as above but for image to image", default=True)
complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True) complement_production: Optional[bool] = Field(description="Whether to generate complementary images to prevent under-utilizing hardware", default=True)
step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False) step_scaling: Optional[bool] = Field(description="Whether to downscale requested steps in order to meet time constraints", default=False)

View File

@ -17,9 +17,10 @@ worker_select_dropdown = None
class UI: class UI:
"""extension user interface related things""" """extension user interface related things"""
def __init__(self, world): def __init__(self, world, is_img2img):
self.world = world self.world = world
self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange self.original_model_dropdown_handler = opts.data_labels.get('sd_model_checkpoint').onchange
self.is_img2img = is_img2img
# handlers # handlers
@staticmethod @staticmethod
@ -186,11 +187,15 @@ class UI:
self.world.save_config() self.world.save_config()
def main_toggle_btn(self): def main_toggle_btn(self):
self.world.enabled = not self.world.enabled if self.is_img2img:
self.world.enabled_i2i = not self.world.enabled_i2i
else:
self.world.enabled = not self.world.enabled
self.world.save_config() self.world.save_config()
# restore vanilla sdwui handler for model dropdown if extension is disabled or inject if otherwise # restore vanilla sdwui handler for model dropdown if extension is disabled or inject if otherwise
if not self.world.enabled: if not self.world.enabled and not self.world.enabled_i2i:
model_dropdown = opts.data_labels.get('sd_model_checkpoint') model_dropdown = opts.data_labels.get('sd_model_checkpoint')
if self.original_model_dropdown_handler is not None: if self.original_model_dropdown_handler is not None:
model_dropdown.onchange = self.original_model_dropdown_handler model_dropdown.onchange = self.original_model_dropdown_handler
@ -209,9 +214,12 @@ class UI:
def create_ui(self): def create_ui(self):
"""creates the extension UI and returns relevant components""" """creates the extension UI and returns relevant components"""
components = [] components = []
elem_id = 'enabled'
if self.is_img2img:
elem_id += '_i2i'
with gradio.Blocks(variant='compact'): # Group() and Box() remove spacing with gradio.Blocks(variant='compact'): # Group() and Box() remove spacing
with InputAccordion(label='Distributed', open=False, value=self.world.config().get('enabled', False), elem_id='enable') as main_toggle: with InputAccordion(label='Distributed', open=False, value=self.world.config().get(elem_id), elem_id=elem_id) as main_toggle:
main_toggle.input(self.main_toggle_btn) main_toggle.input(self.main_toggle_btn)
setattr(main_toggle.accordion, 'do_not_save_to_config', True) # InputAccordion is really a CheckBox setattr(main_toggle.accordion, 'do_not_save_to_config', True) # InputAccordion is really a CheckBox
components.append(main_toggle) components.append(main_toggle)

View File

@ -91,6 +91,7 @@ class World:
self.verify_remotes = verify_remotes self.verify_remotes = verify_remotes
self.thin_client_mode = False self.thin_client_mode = False
self.enabled = True self.enabled = True
self.enabled_i2i = True
self.is_dropdown_handler_injected = False self.is_dropdown_handler_injected = False
self.complement_production = True self.complement_production = True
self.step_scaling = False self.step_scaling = False
@ -671,6 +672,7 @@ class World:
sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload) sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload)
self.job_timeout = config.job_timeout self.job_timeout = config.job_timeout
self.enabled = config.enabled self.enabled = config.enabled
self.enabled_i2i = config.enabled_i2i
self.complement_production = config.complement_production self.complement_production = config.complement_production
self.step_scaling = config.step_scaling self.step_scaling = config.step_scaling
@ -686,6 +688,7 @@ class World:
benchmark_payload=sh.benchmark_payload, benchmark_payload=sh.benchmark_payload,
job_timeout=self.job_timeout, job_timeout=self.job_timeout,
enabled=self.enabled, enabled=self.enabled,
enabled_i2i=self.enabled_i2i,
complement_production=self.complement_production, complement_production=self.complement_production,
step_scaling=self.step_scaling step_scaling=self.step_scaling
) )