From 60bc36d5d9525f447a98bb0f4fc6587e547322fb Mon Sep 17 00:00:00 2001 From: Smirking Kitsune <36494751+SmirkingKitsune@users.noreply.github.com> Date: Thu, 25 Jul 2024 19:08:44 -0600 Subject: [PATCH] Find & Replace Update Adds a feature to find and replace words within the interrogation. --- scripts/sd_tag_batch.py | 180 +++++++++++++++++++++++++++++++++------- 1 file changed, 150 insertions(+), 30 deletions(-) diff --git a/scripts/sd_tag_batch.py b/scripts/sd_tag_batch.py index bce4b4a..b6864a9 100644 --- a/scripts/sd_tag_batch.py +++ b/scripts/sd_tag_batch.py @@ -1,6 +1,7 @@ import gradio as gr import re from modules import scripts, deepbooru, script_callbacks, shared +from modules.ui_components import InputAccordion from modules.processing import process_images from modules.shared import state import sys @@ -118,6 +119,12 @@ class Script(scripts.ScriptBuiltinUI): unique_items.append(item) # Join the cleaned, unique items back into a string return ', '.join(unique_items) + + # Custom replace function to replace phrases with associated pair + def custom_replace(self, text, replace_pairs): + for old, new in replace_pairs.items(): + text = re.sub(r'\b' + re.escape(old) + r'\b', new, text) + return text # Tag filtering, removes negative tags from prompt def filter_words(self, prompt, negative): @@ -174,7 +181,36 @@ class Script(scripts.ScriptBuiltinUI): return custom_filter except Exception as error: print(f"[{NAME} ERROR]: Error loading custom filter: {error}") + # This should be resolved by generating a blank file + if error == "[Errno 2] No such file or directory: 'extensions/sd-Img2img-batch-interrogator/custom_filter.txt'": + self.save_custom_filter("") return "" + + # Function used to prep custom filter environment with previously saved configuration + def load_custom_filter_on_start(self): + return self.load_custom_filter() + + # Function to load custom replace from file + def load_custom_replace(self): + try: + with open("extensions/sd-Img2img-batch-interrogator/custom_replace.txt", "r", encoding="utf-8") as file: + content = file.read().strip().split('\n') + if len(content) >= 2: + return content[0], content[1] + else: + print(f"[{NAME} ERROR]: Invalid custom replace file format.") + return "", "" + except Exception as error: + print(f"[{NAME} ERROR]: Error loading custom replace: {error}") + # This should be resolved by generating a blank file + if error == "[Errno 2] No such file or directory: 'extensions/sd-Img2img-batch-interrogator/custom_replace.txt'": + self.save_custom_replace("", "") + return "", "" + + # Function used to prep find and replace environment with previously saved configuration + def load_custom_replace_on_start(self): + old, new = self.load_custom_replace() + return old, new # Function to load WD models list into WD model selector def load_wd_models(self): @@ -183,6 +219,15 @@ class Script(scripts.ScriptBuiltinUI): return gr.Dropdown.update(choices=models if models else None) return gr.Dropdown.update(choices=None) + # Parse two strings to display pairs + def parse_replace_pairs(self, custom_replace_find, custom_replace_replacements): + old_list = [phrase.strip() for phrase in custom_replace_find.split(',')] + new_list = [phrase.strip() for phrase in custom_replace_replacements.split(',')] + + # Ensure both lists have the same length + min_length = min(len(old_list), len(new_list)) + return {old_list[i]: new_list[i] for i in range(min_length)} + # Refresh the model_selection dropdown def refresh_model_options(self): new_options = self.get_initial_model_options() @@ -260,7 +305,7 @@ class Script(scripts.ScriptBuiltinUI): self.debug_print(debug_mode, f"Reset was Called! The following prompt will be removed from the prompt_contamination cleaner: {self.prompt_contamination}") self.prompt_contamination = "" - # Function to load custom filter from file + # Function to save custom filter from file def save_custom_filter(self, custom_filter): try: with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "w", encoding="utf-8") as file: @@ -269,6 +314,16 @@ class Script(scripts.ScriptBuiltinUI): except Exception as error: print(f"[{NAME} ERROR]: Error saving custom filter: {error}") return self.update_save_confirmation_row_false() + + # Function to save custom replace from file + def save_custom_replace(self, custom_replace_find, custom_replace_replacements): + try: + with open("extensions/sd-Img2img-batch-interrogator/custom_replace.txt", "w", encoding="utf-8") as file: + file.write(f"{custom_replace_find}\n{custom_replace_replacements}") + print(f"[{NAME}]: Custom replace saved successfully.") + except Exception as error: + print(f"[{NAME} ERROR]: Error saving custom replace: {error}") + return self.update_save_confirmation_row_false() # depending on if CLIP (EXT) is present, CLIP (EXT) could be removed from model selector def update_clip_ext_visibility(self, model_selection): @@ -279,16 +334,33 @@ class Script(scripts.ScriptBuiltinUI): else: return gr.Accordion.update(visible=False), gr.Dropdown.update() - # Depending on if slider visible is enabled making the slider will be dynamically visible + # Updates the visibility of group with input bool making it dynamically visible + def update_group_visibility(self, user_defined_visibility): + return gr.Group.update(visible=user_defined_visibility) + + # Updates the visibility of slider with input bool making it dynamically visible def update_slider_visibility(self, user_defined_visibility): return gr.Slider.update(visible=user_defined_visibility) + # Makes save confirmation dialague invisible def update_save_confirmation_row_false(self): return gr.Accordion.update(visible=False) + # Makes save confirmation dialague visible def update_save_confirmation_row_true(self): return gr.Accordion.update(visible=True) + # Used for user visualization, (no longer used for parsing pairs) + def update_parsed_pairs(self, custom_replace_find, custom_replace_replacements): + old_list = [phrase.strip() for phrase in custom_replace_find.split(',')] + new_list = [phrase.strip() for phrase in custom_replace_replacements.split(',')] + + # Ensure both lists have the same length + min_length = min(len(old_list), len(new_list)) + pairs = [f"{old_list[i]}:{new_list[i]}" for i in range(min_length)] + + return ", ".join(pairs) + # depending on if WD (EXT) is present, WD (EXT) could be removed from model selector def update_wd_ext_visibility(self, model_selection): is_visible = "WD (EXT)" in model_selection @@ -315,9 +387,7 @@ class Script(scripts.ScriptBuiltinUI): def ui(self, is_img2img): if not is_img2img: return - #tag_batch_enabled = gr.Checkbox(label=NAME, value=False) - #tag_batch_ui = gr.Accordion(tag_batch_enabled, open=False) - #with tag_batch_ui: + with InputAccordion(False, label=NAME, elem_id="tag_batch_enabled") as tag_batch_enabled: with gr.Row(): model_selection = gr.Dropdown( @@ -353,25 +423,61 @@ class Script(scripts.ScriptBuiltinUI): filtering_tools = gr.Accordion("Filtering tools:") with filtering_tools: - use_positive_filter = gr.Checkbox(label="Filter Duplicate Positive Prompt Content from Interrogation") - use_negative_filter = gr.Checkbox(label="Filter Duplicate Negative Prompt Content from Interrogation") - use_custom_filter = gr.Checkbox(label="Filter Custom Prompt Content from Interrogation") - custom_filter = gr.Textbox(value=self.load_custom_filter(), - label="Custom Filter Prompt", - placeholder="Prompt content separated by commas. Warning ignores attention syntax, parentheses '()' and colon suffix ':XX.XX' are discarded.", - show_copy_button=True - ) - # Button to remove duplicates and strip strange spacing - clean_custom_filter_button = gr.Button(value="Optimize Custom Filter") - # Button to load/save custom filter from file - with gr.Row(): - load_custom_filter_button = gr.Button(value="Load Custom Filter") - save_confirmation_button = gr.Button(value="Save Custom Filter") - save_confirmation_row = gr.Accordion("Are You Sure You Want to Save?", visible=False) - with save_confirmation_row: + use_positive_filter = gr.Checkbox(label="Filter Duplicate Positive Prompt Entries from Interrogation") + use_negative_filter = gr.Checkbox(label="Filter Duplicate Negative Prompt Entries from Interrogation") + use_custom_filter = gr.Checkbox(label="Filter Custom Prompt Entries from Interrogation") + custom_filter_group = gr.Group(visible=False) + with custom_filter_group: + custom_filter = gr.Textbox(value=self.load_custom_filter_on_start(), + label="Custom Filter Prompt", + placeholder="Prompt content separated by commas. Warning ignores attention syntax, parentheses '()' and colon suffix ':XX.XX' are discarded.", + show_copy_button=True + ) + # Button to remove duplicates and strip strange spacing + clean_custom_filter_button = gr.Button(value="Optimize Custom Filter") + # Button to load/save custom filter from file with gr.Row(): - cancel_save_button = gr.Button(value="Cancel") - save_custom_filter_button = gr.Button(value="Save", variant="stop") + load_custom_filter_button = gr.Button(value="Load Custom Filter") + save_custom_filter_button = gr.Button(value="Save Custom Filter") + save_confirmation_custom_filter = gr.Accordion("Are You Sure You Want to Save?", visible=False) + with save_confirmation_custom_filter: + with gr.Row(): + cancel_save_custom_filter_button = gr.Button(value="Cancel") + confirm_save_custom_filter_button = gr.Button(value="Save", variant="stop") + + # Find and Replace + use_custom_replace = gr.Checkbox(label="Find & Replace User Defined Pairs in the Interrogation") + custom_replace_group = gr.Group(visible=False) + with custom_replace_group: + with gr.Row(): + custom_replace_find = gr.Textbox( + value=self.load_custom_replace_on_start()[0], + label="Find:", + placeholder="Enter phrases to replace, separated by commas", + show_copy_button=True + ) + custom_replace_replacements = gr.Textbox( + value=self.load_custom_replace_on_start()[1], + label="Replace:", + placeholder="Enter replacement phrases, separated by commas", + show_copy_button=True + ) + with gr.Row(): + parsed_pairs = gr.Textbox( + label="Parsed Pairs", + placeholder="Parsed pairs will be shown here", + interactive=False + ) + update_parsed_pairs_button = gr.Button("🔄", elem_classes="tool") + with gr.Row(): + load_custom_replace_button = gr.Button("Load Custom Replace") + save_custom_replace_button = gr.Button("Save Custom Replace") + save_confirmation_custom_replace = gr.Accordion("Are You Sure You Want to Save?", visible=False) + with save_confirmation_custom_replace: + with gr.Row(): + cancel_save_custom_replace_button = gr.Button(value="Cancel") + confirm_save_custom_replace_button = gr.Button(value="Save", variant="stop") + experimental_tools = gr.Accordion("Experamental tools:", open=False) with experimental_tools: @@ -392,22 +498,31 @@ class Script(scripts.ScriptBuiltinUI): wd_append_ratings.change(fn=self.update_slider_visibility, inputs=[wd_append_ratings], outputs=[wd_ratings]) clean_custom_filter_button.click(self.clean_string, inputs=custom_filter, outputs=custom_filter) load_custom_filter_button.click(self.load_custom_filter, inputs=None, outputs=custom_filter) - save_confirmation_button.click(self.update_save_confirmation_row_true, inputs=None, outputs=[save_confirmation_row]) - cancel_save_button.click(self.update_save_confirmation_row_false, inputs=None, outputs=[save_confirmation_row]) - save_custom_filter_button.click(self.save_custom_filter, inputs=custom_filter, outputs=[save_confirmation_row]) + custom_replace_find.change(fn=self.update_parsed_pairs, inputs=[custom_replace_find, custom_replace_replacements], outputs=[parsed_pairs]) + custom_replace_replacements.change(fn=self.update_parsed_pairs, inputs=[custom_replace_find, custom_replace_replacements], outputs=[parsed_pairs]) + update_parsed_pairs_button.click(fn=self.update_parsed_pairs, inputs=[custom_replace_find, custom_replace_replacements], outputs=[parsed_pairs]) + save_custom_filter_button.click(self.update_save_confirmation_row_true, inputs=None, outputs=[save_confirmation_custom_filter]) + cancel_save_custom_filter_button.click(self.update_save_confirmation_row_false, inputs=None, outputs=[save_confirmation_custom_filter]) + confirm_save_custom_filter_button.click(self.save_custom_filter, inputs=custom_filter, outputs=[save_confirmation_custom_filter]) + load_custom_replace_button.click(fn=self.load_custom_replace, inputs=[],outputs=[custom_replace_find, custom_replace_replacements]) + save_custom_replace_button.click(fn=self.update_save_confirmation_row_true, inputs=[], outputs=[save_confirmation_custom_replace]) + cancel_save_custom_replace_button.click(fn=self.update_save_confirmation_row_false, inputs=[], outputs=[save_confirmation_custom_replace]) + confirm_save_custom_replace_button.click(fn=self.save_custom_replace, inputs=[custom_replace_find, custom_replace_replacements], outputs=[save_confirmation_custom_replace]) refresh_models_button.click(fn=self.refresh_model_options, inputs=[], outputs=[model_selection]) + use_custom_filter.change(fn=self.update_group_visibility, inputs=[use_custom_filter], outputs=[custom_filter_group]) + use_custom_replace.change(fn=self.update_group_visibility, inputs=[use_custom_replace], outputs=[custom_replace_group]) ui = [ tag_batch_enabled, model_selection, debug_mode, in_front, prompt_weight_mode, prompt_weight, reverse_mode, exaggeration_mode, prompt_output, use_positive_filter, use_negative_filter, - use_custom_filter, custom_filter, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, wd_append_ratings, wd_ratings, unload_clip_models_afterwords, unload_wd_models_afterwords, - no_puncuation_mode + use_custom_filter, custom_filter, use_custom_replace, custom_replace_find, custom_replace_replacements, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, wd_append_ratings, wd_ratings, + unload_clip_models_afterwords, unload_wd_models_afterwords, no_puncuation_mode ] return ui def process_batch( self, p, tag_batch_enabled, model_selection, debug_mode, in_front, prompt_weight_mode, prompt_weight, reverse_mode, exaggeration_mode, prompt_output, use_positive_filter, use_negative_filter, - use_custom_filter, custom_filter, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, wd_append_ratings, wd_ratings, unload_clip_models_afterwords, unload_wd_models_afterwords, - no_puncuation_mode, batch_number, prompts, seeds, subseeds): + use_custom_filter, custom_filter, use_custom_replace, custom_replace_find, custom_replace_replacements, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, wd_append_ratings, wd_ratings, + unload_clip_models_afterwords, unload_wd_models_afterwords, no_puncuation_mode, batch_number, prompts, seeds, subseeds): if not tag_batch_enabled: return @@ -522,6 +637,11 @@ class Script(scripts.ScriptBuiltinUI): if not exaggeration_mode: interrogation = self.clean_string(interrogation) + # Find and Replace user defined words in the interrogation prompt + if use_custom_replace: + replace_pairs = self.parse_replace_pairs(custom_replace_find, custom_replace_replacements) + interrogation = self.custom_replace(interrogation, replace_pairs) + # Remove duplicate prompt content from interrogator prompt if use_positive_filter: interrogation = self.filter_words(interrogation, p.prompt)