diff --git a/scripts/sd_tag_batch.py b/scripts/sd_tag_batch.py index f12db06..b1d9ab2 100644 --- a/scripts/sd_tag_batch.py +++ b/scripts/sd_tag_batch.py @@ -84,19 +84,15 @@ class Script(scripts.Script): def b_clicked(o): return gr.Button.update(interactive=True) - - # Removes unsupported interrogators, support may vary depending on client - def update_model_choices(self, current_choices): - all_options = ["CLIP (EXT)", "CLIP (Native)", "Deepbooru (Native)", "WD (EXT)"] - if not is_interrogator_enabled('clip-interrogator-ext'): - all_options.remove("CLIP (EXT)") - if not is_interrogator_enabled('stable-diffusion-webui-wd14-tagger'): - all_options.remove("WD (EXT)") - - updated_choices = [choice for choice in current_choices if choice in all_options] - - return gr.Dropdown.update(choices=all_options, value=updated_choices) + # Initial Model Options generator, only add supported interrogators, support may vary depending on client + def get_initial_model_options(self): + options = ["CLIP (Native)", "Deepbooru (Native)"] + if is_interrogator_enabled('clip-interrogator-ext'): + options.insert(0, "CLIP (EXT)") + if is_interrogator_enabled('stable-diffusion-webui-wd14-tagger'): + options.append("WD (EXT)") + return options # Function to load CLIP models list into CLIP model selector def load_clip_models(self): @@ -139,8 +135,6 @@ class Script(scripts.Script): unloaded_models = unloaded_models + 1 print(f"Unloaded {unloaded_models} Tagger Model(s).") - - # depending on if CLIP (EXT) is present, CLIP (EXT) could be removed from model selector def update_clip_ext_visibility(self, model_selection): is_visible = "CLIP (EXT)" in model_selection @@ -171,8 +165,7 @@ class Script(scripts.Script): def ui(self, is_img2img): with gr.Group(): - model_options = ["CLIP (EXT)", "CLIP (Native)", "Deepbooru (Native)", "WD (EXT)"] - model_selection = gr.Dropdown(choices=model_options, + model_selection = gr.Dropdown(choices=self.get_initial_model_options(), label="Interrogation Model(s):", multiselect=True, value=None, @@ -216,7 +209,6 @@ class Script(scripts.Script): load_custom_filter_button = gr.Button(value="Load Last Custom Filter") # Listeners - model_selection.select(fn=self.update_model_choices, inputs=[model_selection], outputs=[model_selection]) model_selection.change(fn=self.update_clip_ext_visibility, inputs=[model_selection], outputs=[clip_ext_accordion, clip_ext_model]) model_selection.change(fn=self.update_wd_ext_visibility, inputs=[model_selection], outputs=[wd_ext_accordion, wd_ext_model]) load_custom_filter_button.click(self.load_custom_filter, inputs=custom_filter, outputs=custom_filter)