From 6b1b47f3929bfebeffecf2f81dcf9d494f6e726d Mon Sep 17 00:00:00 2001 From: Smirking Kitsune <36494751+SmirkingKitsune@users.noreply.github.com> Date: Thu, 27 Jun 2024 17:44:48 -0700 Subject: [PATCH] Always Visible UI Overhaul UI Changes: - Made the extension always visible on the img2img tab. This allows other scripts to be run alongside the extension. - Added triggers to determine if the script should run: it runs only when a model is selected and an image is present. - Removed "Deepbooru (Native)" as the default model, so that the script is not active by default. - Set `clip_api_mode` to "best" to better align with the 'clip-interrogator-ext' defaults. Optimizations: - Library (import) declarations were organized and cleaned of repetition. - The script will no longer run if there is no image to interrogate (img2img probably should not run if there are no images anyway). --- scripts/sd_tag_batch.py | 259 +++++++++++++++++++++------------------- 1 file changed, 133 insertions(+), 126 deletions(-) diff --git a/scripts/sd_tag_batch.py b/scripts/sd_tag_batch.py index 5b37898..3ca6a1c 100644 --- a/scripts/sd_tag_batch.py +++ b/scripts/sd_tag_batch.py @@ -1,13 +1,14 @@ import gradio as gr import re -from modules import scripts, deepbooru -from modules.processing import process_images -import modules.shared as shared import os import requests from io import BytesIO import base64 -from modules import script_callbacks +from modules import scripts, deepbooru, script_callbacks +from modules.processing import process_images +import modules.shared as shared + +NAME = "Img2img batch interrogator" """ @@ -30,8 +31,9 @@ class Script(scripts.Script): return cls.server_address # Fallback to the brute force method if server_address is not set - # Initial testing indicates that fallback method might not be needed... + # Initial testing indicates that fallback method will never be used... print("Server address not set. 
Falling back to brute force method.") + # Fallback is highly inefficient and in some cases slow (especially if expected port is far from default) ports = range(7860, 7960) # Gradio will increment port 100 times if default and subsequent desired ports are unavailable. for port in ports: url = f"http://127.0.0.1:{port}/" @@ -42,14 +44,14 @@ class Script(scripts.Script): except requests.RequestException as error: print(f"API not available on port {port}: {error}") - print("API not found on any port") + print("API not found") return None def title(self): - return "Img2img batch interrogator" + return NAME def show(self, is_img2img): - return is_img2img + return scripts.AlwaysVisible if is_img2img else False def b_clicked(o): return gr.Button.update(interactive=True) @@ -97,77 +99,79 @@ class Script(scripts.Script): return gr.Dropdown.update(choices=models if models else None) def ui(self, is_img2img): - model_options = ["CLIP (API)", "CLIP (Native)", "Deepbooru (Native)", "WD (API)"] - model_selection = gr.Dropdown(choices=model_options, label="Select Interrogation Model(s)", multiselect=True, value="Deepbooru (Native)") - - in_front = gr.Radio( - choices=["Prepend to prompt", "Append to prompt"], - value="Prepend to prompt", - label="Interrogator result position" - ) - - def update_prompt_weight_visibility(use_weight): - return gr.Slider.update(visible=use_weight) - - use_weight = gr.Checkbox(label="Use Interrogator Prompt Weight", value=True) - prompt_weight = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Interrogator Prompt Weight", visible=True) - - # CLIP API Options - def update_clip_api_visibility(model_selection): - is_visible = "CLIP (API)" in model_selection - if is_visible: - clip_models = self.load_clip_models() - return gr.Accordion.update(visible=True), clip_models - else: - return gr.Accordion.update(visible=False), gr.Dropdown.update() - - clip_api_accordion = gr.Accordion("CLIP API Options:", open=False, visible=False) - with clip_api_accordion: - 
clip_api_model = gr.Dropdown(choices=[], value='ViT-L-14/openai', label="CLIP API Model") - clip_api_mode = gr.Radio(choices=["fast", "best", "classic", "negative"], label="CLIP API Mode", value="fast") + with gr.Group(): + with gr.Accordion(NAME, open=False): + model_options = ["CLIP (API)", "CLIP (Native)", "Deepbooru (Native)", "WD (API)"] + model_selection = gr.Dropdown(choices=model_options, label="Select Interrogation Model(s)", multiselect=True, value=None) + + in_front = gr.Radio( + choices=["Prepend to prompt", "Append to prompt"], + value="Prepend to prompt", + label="Interrogator result position" + ) + + def update_prompt_weight_visibility(use_weight): + return gr.Slider.update(visible=use_weight) + + use_weight = gr.Checkbox(label="Use Interrogator Prompt Weight", value=True) + prompt_weight = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Interrogator Prompt Weight", visible=True) + + # CLIP API Options + def update_clip_api_visibility(model_selection): + is_visible = "CLIP (API)" in model_selection + if is_visible: + clip_models = self.load_clip_models() + return gr.Accordion.update(visible=True), clip_models + else: + return gr.Accordion.update(visible=False), gr.Dropdown.update() + + clip_api_accordion = gr.Accordion("CLIP API Options:", open=False, visible=False) + with clip_api_accordion: + clip_api_model = gr.Dropdown(choices=[], value='ViT-L-14/openai', label="CLIP API Model") + clip_api_mode = gr.Radio(choices=["best", "fast", "classic", "negative"], label="CLIP API Mode", value="best") - # WD API Options - def update_wd_api_visibility(model_selection): - is_visible = "WD (API)" in model_selection - if is_visible: - wd_models = self.load_wd_models() - return gr.Accordion.update(visible=True), wd_models - else: - return gr.Accordion.update(visible=False), gr.Dropdown.update() + # WD API Options + def update_wd_api_visibility(model_selection): + is_visible = "WD (API)" in model_selection + if is_visible: + wd_models = self.load_wd_models() + 
return gr.Accordion.update(visible=True), wd_models + else: + return gr.Accordion.update(visible=False), gr.Dropdown.update() - wd_api_accordion = gr.Accordion("WD API Options:", open=False, visible=False) - with wd_api_accordion: - wd_api_model = gr.Dropdown(choices=[], value='wd-v1-4-moat-tagger.v2', label="WD API Model") - wd_underscore_fix = gr.Checkbox(label="Remove Underscores from Tags", value=True) - wd_threshold = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Threshold") - unload_wd_models_afterwords = gr.Checkbox(label="Unload WD Model After Use", value=True) - unload_wd_models_button = gr.Button(value="Unload WD Models") + wd_api_accordion = gr.Accordion("WD API Options:", open=False, visible=False) + with wd_api_accordion: + wd_api_model = gr.Dropdown(choices=[], value='wd-v1-4-moat-tagger.v2', label="WD API Model") + wd_threshold = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Threshold") + wd_underscore_fix = gr.Checkbox(label="Remove Underscores from Tags", value=True) + unload_wd_models_afterwords = gr.Checkbox(label="Unload WD Model After Use", value=True) + unload_wd_models_button = gr.Button(value="Unload WD Models") - # Function to load custom filter from file - def load_custom_filter(custom_filter): - with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "r") as file: - custom_filter = file.read() - return custom_filter + # Function to load custom filter from file + def load_custom_filter(custom_filter): + with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "r") as file: + custom_filter = file.read() + return custom_filter - with gr.Accordion("Filtering tools:"): - no_duplicates = gr.Checkbox(label="Filter Duplicate Prompt Content from Interrogation", value=False) - use_negatives = gr.Checkbox(label="Filter Negative Prompt Content from Interrogation", value=False) - use_custom_filter = gr.Checkbox(label="Filter Custom Prompt Content from Interrogation", value=False) - custom_filter = gr.Textbox( - 
label="Custom Filter Prompt", - placeholder="Prompt content separated by commas. Warning ignores attention syntax, parentheses '()' and colon suffix ':XX.XX' are discarded.", - show_copy_button=True - ) - # Button to load custom filter from file - load_custom_filter_button = gr.Button(value="Load Last Custom Filter") + with gr.Accordion("Filtering tools:"): + no_duplicates = gr.Checkbox(label="Filter Duplicate Prompt Content from Interrogation", value=False) + use_negatives = gr.Checkbox(label="Filter Negative Prompt Content from Interrogation", value=False) + use_custom_filter = gr.Checkbox(label="Filter Custom Prompt Content from Interrogation", value=False) + custom_filter = gr.Textbox( + label="Custom Filter Prompt", + placeholder="Prompt content separated by commas. Warning ignores attention syntax, parentheses '()' and colon suffix ':XX.XX' are discarded.", + show_copy_button=True + ) + # Button to load custom filter from file + load_custom_filter_button = gr.Button(value="Load Last Custom Filter") - # Listeners - model_selection.select(fn=self.update_model_choices, inputs=[model_selection], outputs=[model_selection]) - model_selection.change(fn=update_clip_api_visibility, inputs=[model_selection], outputs=[clip_api_accordion, clip_api_model]) - model_selection.change(fn=update_wd_api_visibility, inputs=[model_selection], outputs=[wd_api_accordion, wd_api_model]) - load_custom_filter_button.click(load_custom_filter, inputs=custom_filter, outputs=custom_filter) - unload_wd_models_button.click(self.post_wd_api_unload, inputs=None, outputs=None) - use_weight.change(fn=update_prompt_weight_visibility, inputs=[use_weight], outputs=[prompt_weight]) + # Listeners + model_selection.select(fn=self.update_model_choices, inputs=[model_selection], outputs=[model_selection]) + model_selection.change(fn=update_clip_api_visibility, inputs=[model_selection], outputs=[clip_api_accordion, clip_api_model]) + model_selection.change(fn=update_wd_api_visibility, 
inputs=[model_selection], outputs=[wd_api_accordion, wd_api_model]) + load_custom_filter_button.click(load_custom_filter, inputs=custom_filter, outputs=custom_filter) + unload_wd_models_button.click(self.post_wd_api_unload, inputs=None, outputs=None) + use_weight.change(fn=update_prompt_weight_visibility, inputs=[use_weight], outputs=[prompt_weight]) return [in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_api_model, clip_api_mode, wd_api_model, wd_threshold, wd_underscore_fix, unload_wd_models_afterwords] @@ -214,58 +218,60 @@ class Script(scripts.Script): def run(self, p, in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_api_model, clip_api_mode, wd_api_model, wd_threshold, wd_underscore_fix, unload_wd_models_afterwords): raw_prompt = p.prompt interrogator = "" + + # If no model selected or no image, interrogation should not run + if model_selection and not p.init_images[0]: + # fix alpha channel + p.init_images[0] = p.init_images[0].convert("RGB") + + first = True # Two interrogator concatenation correction boolean + for model in model_selection: + # This prevents two interrogators from being incorrectly concatenated + if first == False: + interrogator += ", " + first = False + # Should add the interrogators in the order determined by the model_selection list + if model == "Deepbooru (Native)": + interrogator += deepbooru.model.tag(p.init_images[0]) + elif model == "CLIP (Native)": + interrogator += shared.interrogator.interrogate(p.init_images[0]) + elif model == "CLIP (API)": + interrogator += self.post_clip_api_prompt(p.init_images[0], clip_api_model, clip_api_mode) + elif model == "WD (API)": + interrogator += self.post_wd_api_tagger(p.init_images[0], wd_api_model, wd_threshold, wd_underscore_fix) + + # Remove duplicate prompt content from interrogator prompt + if no_duplicates: + interrogator = 
self.filter_words(interrogator, raw_prompt) + # Remove negative prompt content from interrogator prompt + if use_negatives: + interrogator = self.filter_words(interrogator, p.negative_prompt) + # Remove custom prompt content from interrogator prompt + if use_custom_filter: + interrogator = self.filter_words(interrogator, custom_filter) + # Save custom filter to text file + with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "w") as file: + file.write(custom_filter) - # fix alpha channel - p.init_images[0] = p.init_images[0].convert("RGB") - - first = True # Two interrogator concatenation correction boolean - for model in model_selection: - # This prevents two interrogators from being incorrectly concatenated - if first == False: - interrogator += ", " - first = False - # Should add the interrogators in the order determined by the model_selection list - if model == "Deepbooru (Native)": - interrogator += deepbooru.model.tag(p.init_images[0]) - elif model == "CLIP (Native)": - interrogator += shared.interrogator.interrogate(p.init_images[0]) - elif model == "CLIP (API)": - interrogator += self.post_clip_api_prompt(p.init_images[0], clip_api_model, clip_api_mode) - elif model == "WD (API)": - interrogator += self.post_wd_api_tagger(p.init_images[0], wd_api_model, wd_threshold, wd_underscore_fix) - - # Remove duplicate prompt content from interrogator prompt - if no_duplicates: - interrogator = self.filter_words(interrogator, raw_prompt) - # Remove negative prompt content from interrogator prompt - if use_negatives: - interrogator = self.filter_words(interrogator, p.negative_prompt) - # Remove custom prompt content from interrogator prompt - if use_custom_filter: - interrogator = self.filter_words(interrogator, custom_filter) - # Save custom filter to text file - with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "w") as file: - file.write(custom_filter) - - if use_weight: - if p.prompt == "": - p.prompt = interrogator - elif 
in_front == "Append to prompt": - p.prompt = f"{p.prompt}, ({interrogator}:{prompt_weight})" + if use_weight: + if p.prompt == "": + p.prompt = interrogator + elif in_front == "Append to prompt": + p.prompt = f"{p.prompt}, ({interrogator}:{prompt_weight})" + else: + p.prompt = f"({interrogator}:{prompt_weight}), {p.prompt}" else: - p.prompt = f"({interrogator}:{prompt_weight}), {p.prompt}" - else: - if p.prompt == "": - p.prompt = interrogator - elif in_front == "Append to prompt": - p.prompt = f"{p.prompt}, {interrogator}" - else: - p.prompt = f"{interrogator}, {p.prompt}" - - if unload_wd_models_afterwords and "WD (API)" in model_selection: - self.post_wd_api_unload() - - print(f"Prompt: {p.prompt}") + if p.prompt == "": + p.prompt = interrogator + elif in_front == "Append to prompt": + p.prompt = f"{p.prompt}, {interrogator}" + else: + p.prompt = f"{interrogator}, {p.prompt}" + + if unload_wd_models_afterwords and "WD (API)" in model_selection: + self.post_wd_api_unload() + + print(f"Prompt: {p.prompt}") processed = process_images(p) @@ -357,8 +363,9 @@ class Script(scripts.Script): "name_in_queue": "" } api_address = f"{self.get_server_address()}tagger/v1/interrogate" - # WARNING: Removing `timeout` could result in a frozen client if the queue_lock is locked. If you need more time add more time, do not remove or risk DEADLOCK. - # Note: If WD Tagger did not load a model, it is likely that WD Tagger specifically queue_lock (FIFOLock) is concerned with your system's threading and thinks running could cause process starvation... + # WARNING: Removing `timeout` could result in a frozen client if the queue_lock is locked. If you need more time add more time, do not remove timeout or risk DEADLOCK. + # Note: If WD Tagger did not load a model, it is likely that WD Tagger specifically queue_lock (FIFOLock) is concerned with your system's threading and thinks running could cause processes starvation... 
+ # Note: It would be advisable to download models in the WD tab due to the timeout response = requests.post(api_address, json=payload, timeout=120) response.raise_for_status() result = response.json()