diff --git a/scripts/sd_tag_batch.py b/scripts/sd_tag_batch.py index 3ca6a1c..c1835e8 100644 --- a/scripts/sd_tag_batch.py +++ b/scripts/sd_tag_batch.py @@ -51,7 +51,8 @@ class Script(scripts.Script): return NAME def show(self, is_img2img): - return scripts.AlwaysVisible if is_img2img else False + # return scripts.AlwaysVisible if is_img2img else False + return is_img2img def b_clicked(o): return gr.Button.update(interactive=True) @@ -100,82 +101,81 @@ class Script(scripts.Script): def ui(self, is_img2img): with gr.Group(): - with gr.Accordion(NAME, open=False): - model_options = ["CLIP (API)", "CLIP (Native)", "Deepbooru (Native)", "WD (API)"] - model_selection = gr.Dropdown(choices=model_options, label="Select Interrogation Model(s)", multiselect=True, value=None) + model_options = ["CLIP (API)", "CLIP (Native)", "Deepbooru (Native)", "WD (API)"] + model_selection = gr.Dropdown(choices=model_options, label="Select Interrogation Model(s)", multiselect=True, value=None) + + in_front = gr.Radio( + choices=["Prepend to prompt", "Append to prompt"], + value="Prepend to prompt", + label="Interrogator result position" + ) + + def update_prompt_weight_visibility(use_weight): + return gr.Slider.update(visible=use_weight) - in_front = gr.Radio( - choices=["Prepend to prompt", "Append to prompt"], - value="Prepend to prompt", - label="Interrogator result position" - ) - - def update_prompt_weight_visibility(use_weight): - return gr.Slider.update(visible=use_weight) + use_weight = gr.Checkbox(label="Use Interrogator Prompt Weight", value=True) + prompt_weight = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Interrogator Prompt Weight", visible=True) + + # CLIP API Options + def update_clip_api_visibility(model_selection): + is_visible = "CLIP (API)" in model_selection + if is_visible: + clip_models = self.load_clip_models() + return gr.Accordion.update(visible=True), clip_models + else: + return gr.Accordion.update(visible=False), gr.Dropdown.update() - use_weight = gr.Checkbox(label="Use Interrogator Prompt Weight", value=True) - prompt_weight = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Interrogator Prompt Weight", visible=True) + clip_api_accordion = gr.Accordion("CLIP API Options:", open=False, visible=False) + with clip_api_accordion: + clip_api_model = gr.Dropdown(choices=[], value='ViT-L-14/openai', label="CLIP API Model") + clip_api_mode = gr.Radio(choices=["best", "fast", "classic", "negative"], value='best', label="CLIP API Mode") - # CLIP API Options - def update_clip_api_visibility(model_selection): - is_visible = "CLIP (API)" in model_selection - if is_visible: - clip_models = self.load_clip_models() - return gr.Accordion.update(visible=True), clip_models - else: - return gr.Accordion.update(visible=False), gr.Dropdown.update() + # WD API Options + def update_wd_api_visibility(model_selection): + is_visible = "WD (API)" in model_selection + if is_visible: + wd_models = self.load_wd_models() + return gr.Accordion.update(visible=True), wd_models + else: + return gr.Accordion.update(visible=False), gr.Dropdown.update() + + wd_api_accordion = gr.Accordion("WD API Options:", open=False, visible=False) + with wd_api_accordion: + wd_api_model = gr.Dropdown(choices=[], value='wd-v1-4-moat-tagger.v2', label="WD API Model") + wd_threshold = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Threshold") + wd_underscore_fix = gr.Checkbox(label="Remove Underscores from Tags", value=True) + unload_wd_models_afterwords = gr.Checkbox(label="Unload WD Model After Use", value=True) + unload_wd_models_button = gr.Button(value="Unload WD Models") - clip_api_accordion = gr.Accordion("CLIP API Options:", open=False, visible=False) - with clip_api_accordion: - clip_api_model = gr.Dropdown(choices=[], value='ViT-L-14/openai', label="CLIP API Model") - clip_api_mode = gr.Radio(choices=["best", "fast", "classic", "negative"], label="CLIP API Mode", value="best") - - # WD API Options - def update_wd_api_visibility(model_selection): - is_visible = "WD (API)" in model_selection - if is_visible: - wd_models = self.load_wd_models() - return gr.Accordion.update(visible=True), wd_models - else: - return gr.Accordion.update(visible=False), gr.Dropdown.update() - - wd_api_accordion = gr.Accordion("WD API Options:", open=False, visible=False) - with wd_api_accordion: - wd_api_model = gr.Dropdown(choices=[], value='wd-v1-4-moat-tagger.v2', label="WD API Model") - wd_threshold = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Threshold") - wd_underscore_fix = gr.Checkbox(label="Remove Underscores from Tags", value=True) - unload_wd_models_afterwords = gr.Checkbox(label="Unload WD Model After Use", value=True) - unload_wd_models_button = gr.Button(value="Unload WD Models") - - # Function to load custom filter from file - def load_custom_filter(custom_filter): - with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "r") as file: - custom_filter = file.read() - return custom_filter - - with gr.Accordion("Filtering tools:"): - no_duplicates = gr.Checkbox(label="Filter Duplicate Prompt Content from Interrogation", value=False) - use_negatives = gr.Checkbox(label="Filter Negative Prompt Content from Interrogation", value=False) - use_custom_filter = gr.Checkbox(label="Filter Custom Prompt Content from Interrogation", value=False) - custom_filter = gr.Textbox( - label="Custom Filter Prompt", - placeholder="Prompt content separated by commas. Warning ignores attention syntax, parentheses '()' and colon suffix ':XX.XX' are discarded.", - show_copy_button=True - ) - # Button to load custom filter from file - load_custom_filter_button = gr.Button(value="Load Last Custom Filter") - - # Listeners - model_selection.select(fn=self.update_model_choices, inputs=[model_selection], outputs=[model_selection]) - model_selection.change(fn=update_clip_api_visibility, inputs=[model_selection], outputs=[clip_api_accordion, clip_api_model]) - model_selection.change(fn=update_wd_api_visibility, inputs=[model_selection], outputs=[wd_api_accordion, wd_api_model]) - load_custom_filter_button.click(load_custom_filter, inputs=custom_filter, outputs=custom_filter) - unload_wd_models_button.click(self.post_wd_api_unload, inputs=None, outputs=None) - use_weight.change(fn=update_prompt_weight_visibility, inputs=[use_weight], outputs=[prompt_weight]) - - + # Function to load custom filter from file + def load_custom_filter(custom_filter): + with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "r") as file: + custom_filter = file.read() + return custom_filter + + with gr.Accordion("Filtering tools:"): + no_duplicates = gr.Checkbox(label="Filter Duplicate Prompt Content from Interrogation", value=False) + use_negatives = gr.Checkbox(label="Filter Negative Prompt Content from Interrogation", value=False) + use_custom_filter = gr.Checkbox(label="Filter Custom Prompt Content from Interrogation", value=False) + custom_filter = gr.Textbox( + label="Custom Filter Prompt", + placeholder="Prompt content separated by commas. Warning ignores attention syntax, parentheses '()' and colon suffix ':XX.XX' are discarded.", + show_copy_button=True + ) + # Button to load custom filter from file + load_custom_filter_button = gr.Button(value="Load Last Custom Filter") + + # Listeners + model_selection.select(fn=self.update_model_choices, inputs=[model_selection], outputs=[model_selection]) + model_selection.change(fn=update_clip_api_visibility, inputs=[model_selection], outputs=[clip_api_accordion, clip_api_model]) + model_selection.change(fn=update_wd_api_visibility, inputs=[model_selection], outputs=[wd_api_accordion, wd_api_model]) + load_custom_filter_button.click(load_custom_filter, inputs=custom_filter, outputs=custom_filter) + unload_wd_models_button.click(self.post_wd_api_unload, inputs=None, outputs=None) + use_weight.change(fn=update_prompt_weight_visibility, inputs=[use_weight], outputs=[prompt_weight]) + + return [in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_api_model, clip_api_mode, wd_api_model, wd_threshold, wd_underscore_fix, unload_wd_models_afterwords] - + # Required to parse information from a string that is between () or has :##.## suffix def remove_attention(self, words): # Define a regular expression pattern to match attention-related suffixes @@ -219,68 +219,65 @@ class Script(scripts.Script): raw_prompt = p.prompt interrogator = "" - # If no model selected or no image, interrogation should not run - if model_selection and not p.init_images[0]: - # fix alpha channel - p.init_images[0] = p.init_images[0].convert("RGB") - - first = True # Two interrogator concatenation correction boolean - for model in model_selection: - # This prevents two interrogators from being incorrectly concatenated - if first == False: - interrogator += ", " - first = False - # Should add the interrogators in the order determined by the model_selection list - if model == "Deepbooru (Native)": - interrogator += deepbooru.model.tag(p.init_images[0]) - elif model == "CLIP (Native)": - interrogator += shared.interrogator.interrogate(p.init_images[0]) - elif model == "CLIP (API)": - interrogator += self.post_clip_api_prompt(p.init_images[0], clip_api_model, clip_api_mode) - elif model == "WD (API)": - interrogator += self.post_wd_api_tagger(p.init_images[0], wd_api_model, wd_threshold, wd_underscore_fix) - - # Remove duplicate prompt content from interrogator prompt - if no_duplicates: - interrogator = self.filter_words(interrogator, raw_prompt) - # Remove negative prompt content from interrogator prompt - if use_negatives: - interrogator = self.filter_words(interrogator, p.negative_prompt) - # Remove custom prompt content from interrogator prompt - if use_custom_filter: - interrogator = self.filter_words(interrogator, custom_filter) - # Save custom filter to text file - with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "w") as file: - file.write(custom_filter) - - if use_weight: - if p.prompt == "": - p.prompt = interrogator - elif in_front == "Append to prompt": - p.prompt = f"{p.prompt}, ({interrogator}:{prompt_weight})" - else: - p.prompt = f"({interrogator}:{prompt_weight}), {p.prompt}" + # fix alpha channel + p.init_images[0] = p.init_images[0].convert("RGB") + + first = True # Two interrogator concatenation correction boolean + for model in model_selection: + # This prevents two interrogators from being incorrectly concatenated + if first == False: + interrogator += ", " + first = False + # Should add the interrogators in the order determined by the model_selection list + if model == "Deepbooru (Native)": + interrogator += deepbooru.model.tag(p.init_images[0]) + elif model == "CLIP (Native)": + interrogator += shared.interrogator.interrogate(p.init_images[0]) + elif model == "CLIP (API)": + interrogator += self.post_clip_api_prompt(p.init_images[0], clip_api_model, clip_api_mode) + elif model == "WD (API)": + interrogator += self.post_wd_api_tagger(p.init_images[0], wd_api_model, wd_threshold, wd_underscore_fix) + + # Remove duplicate prompt content from interrogator prompt + if no_duplicates: + interrogator = self.filter_words(interrogator, raw_prompt) + # Remove negative prompt content from interrogator prompt + if use_negatives: + interrogator = self.filter_words(interrogator, p.negative_prompt) + # Remove custom prompt content from interrogator prompt + if use_custom_filter: + interrogator = self.filter_words(interrogator, custom_filter) + # Save custom filter to text file + with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "w") as file: + file.write(custom_filter) + + if use_weight: + if p.prompt == "": + p.prompt = interrogator + elif in_front == "Append to prompt": + p.prompt = f"{p.prompt}, ({interrogator}:{prompt_weight})" else: - if p.prompt == "": - p.prompt = interrogator - elif in_front == "Append to prompt": - p.prompt = f"{p.prompt}, {interrogator}" - else: - p.prompt = f"{interrogator}, {p.prompt}" - - if unload_wd_models_afterwords and "WD (API)" in model_selection: - self.post_wd_api_unload() - - print(f"Prompt: {p.prompt}") - + p.prompt = f"({interrogator}:{prompt_weight}), {p.prompt}" + else: + if p.prompt == "": + p.prompt = interrogator + elif in_front == "Append to prompt": + p.prompt = f"{p.prompt}, {interrogator}" + else: + p.prompt = f"{interrogator}, {p.prompt}" + + if unload_wd_models_afterwords and "WD (API)" in model_selection: + self.post_wd_api_unload() + + print(f"Prompt: {p.prompt}") + processed = process_images(p) - + # Restore the UI elements we modified p.prompt = raw_prompt - + return processed - # CLIP API Model Identification def get_clip_API_models(self): # Ensure CLIP Interrogator is present and accessible @@ -366,7 +363,7 @@ class Script(scripts.Script): # WARNING: Removing `timeout` could result in a frozen client if the queue_lock is locked. If you need more time add more time, do not remove timeout or risk DEADLOCK. # Note: If WD Tagger did not load a model, it is likely that WD Tagger specifically queue_lock (FIFOLock) is concerned with your system's threading and thinks running could cause processes starvation... # Note: It would be advisable to download models in the WD tab due to the timeout - response = requests.post(api_address, json=payload, timeout=120) + response = requests.post(api_address, json=payload, timeout=120) response.raise_for_status() result = response.json() @@ -381,16 +378,16 @@ class Script(scripts.Script): return tags_string except Exception as error: print(f"Error generating prompt with WD API: {error}") - return "" + return "" # WD API Model Unloader def post_wd_api_unload(self): try: api_address = f"{self.get_server_address()}tagger/v1/unload-interrogators" - response = requests.post(api_address, json='') + response = requests.post(api_address, json='') response.raise_for_status() except Exception as error: - print(f"Error Unloading models with WD API: {error}") + print(f"Error Unloading models with WD API: {error}") script_callbacks.on_app_started(Script.set_server_address)