Find & Replace Update

Adds a feature to find and replace words within the interrogation.
pull/10/head
Smirking Kitsune 2024-07-25 19:08:44 -06:00 committed by GitHub
parent 919f4cfb31
commit 60bc36d5d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 150 additions and 30 deletions

View File

@ -1,6 +1,7 @@
import gradio as gr import gradio as gr
import re import re
from modules import scripts, deepbooru, script_callbacks, shared from modules import scripts, deepbooru, script_callbacks, shared
from modules.ui_components import InputAccordion
from modules.processing import process_images from modules.processing import process_images
from modules.shared import state from modules.shared import state
import sys import sys
@ -118,6 +119,12 @@ class Script(scripts.ScriptBuiltinUI):
unique_items.append(item) unique_items.append(item)
# Join the cleaned, unique items back into a string # Join the cleaned, unique items back into a string
return ', '.join(unique_items) return ', '.join(unique_items)
# Custom replace function to replace phrases with associated pair
def custom_replace(self, text, replace_pairs):
for old, new in replace_pairs.items():
text = re.sub(r'\b' + re.escape(old) + r'\b', new, text)
return text
# Tag filtering, removes negative tags from prompt # Tag filtering, removes negative tags from prompt
def filter_words(self, prompt, negative): def filter_words(self, prompt, negative):
@ -174,7 +181,36 @@ class Script(scripts.ScriptBuiltinUI):
return custom_filter return custom_filter
except Exception as error: except Exception as error:
print(f"[{NAME} ERROR]: Error loading custom filter: {error}") print(f"[{NAME} ERROR]: Error loading custom filter: {error}")
# This should be resolved by generating a blank file
if error == "[Errno 2] No such file or directory: 'extensions/sd-Img2img-batch-interrogator/custom_filter.txt'":
self.save_custom_filter("")
return "" return ""
# Function used to prep custom filter environment with previously saved configuration
def load_custom_filter_on_start(self):
return self.load_custom_filter()
# Function to load custom replace from file
def load_custom_replace(self):
try:
with open("extensions/sd-Img2img-batch-interrogator/custom_replace.txt", "r", encoding="utf-8") as file:
content = file.read().strip().split('\n')
if len(content) >= 2:
return content[0], content[1]
else:
print(f"[{NAME} ERROR]: Invalid custom replace file format.")
return "", ""
except Exception as error:
print(f"[{NAME} ERROR]: Error loading custom replace: {error}")
# This should be resolved by generating a blank file
if error == "[Errno 2] No such file or directory: 'extensions/sd-Img2img-batch-interrogator/custom_replace.txt'":
self.save_custom_replace("", "")
return "", ""
# Function used to prep find and replace environment with previously saved configuration
def load_custom_replace_on_start(self):
old, new = self.load_custom_replace()
return old, new
# Function to load WD models list into WD model selector # Function to load WD models list into WD model selector
def load_wd_models(self): def load_wd_models(self):
@ -183,6 +219,15 @@ class Script(scripts.ScriptBuiltinUI):
return gr.Dropdown.update(choices=models if models else None) return gr.Dropdown.update(choices=models if models else None)
return gr.Dropdown.update(choices=None) return gr.Dropdown.update(choices=None)
# Parse two strings to display pairs
def parse_replace_pairs(self, custom_replace_find, custom_replace_replacements):
old_list = [phrase.strip() for phrase in custom_replace_find.split(',')]
new_list = [phrase.strip() for phrase in custom_replace_replacements.split(',')]
# Ensure both lists have the same length
min_length = min(len(old_list), len(new_list))
return {old_list[i]: new_list[i] for i in range(min_length)}
# Refresh the model_selection dropdown # Refresh the model_selection dropdown
def refresh_model_options(self): def refresh_model_options(self):
new_options = self.get_initial_model_options() new_options = self.get_initial_model_options()
@ -260,7 +305,7 @@ class Script(scripts.ScriptBuiltinUI):
self.debug_print(debug_mode, f"Reset was Called! The following prompt will be removed from the prompt_contamination cleaner: {self.prompt_contamination}") self.debug_print(debug_mode, f"Reset was Called! The following prompt will be removed from the prompt_contamination cleaner: {self.prompt_contamination}")
self.prompt_contamination = "" self.prompt_contamination = ""
# Function to load custom filter from file # Function to save custom filter from file
def save_custom_filter(self, custom_filter): def save_custom_filter(self, custom_filter):
try: try:
with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "w", encoding="utf-8") as file: with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "w", encoding="utf-8") as file:
@ -269,6 +314,16 @@ class Script(scripts.ScriptBuiltinUI):
except Exception as error: except Exception as error:
print(f"[{NAME} ERROR]: Error saving custom filter: {error}") print(f"[{NAME} ERROR]: Error saving custom filter: {error}")
return self.update_save_confirmation_row_false() return self.update_save_confirmation_row_false()
# Function to save custom replace from file
def save_custom_replace(self, custom_replace_find, custom_replace_replacements):
try:
with open("extensions/sd-Img2img-batch-interrogator/custom_replace.txt", "w", encoding="utf-8") as file:
file.write(f"{custom_replace_find}\n{custom_replace_replacements}")
print(f"[{NAME}]: Custom replace saved successfully.")
except Exception as error:
print(f"[{NAME} ERROR]: Error saving custom replace: {error}")
return self.update_save_confirmation_row_false()
# depending on if CLIP (EXT) is present, CLIP (EXT) could be removed from model selector # depending on if CLIP (EXT) is present, CLIP (EXT) could be removed from model selector
def update_clip_ext_visibility(self, model_selection): def update_clip_ext_visibility(self, model_selection):
@ -279,16 +334,33 @@ class Script(scripts.ScriptBuiltinUI):
else: else:
return gr.Accordion.update(visible=False), gr.Dropdown.update() return gr.Accordion.update(visible=False), gr.Dropdown.update()
# Depending on if slider visible is enabled making the slider will be dynamically visible # Updates the visibility of group with input bool making it dynamically visible
def update_group_visibility(self, user_defined_visibility):
return gr.Group.update(visible=user_defined_visibility)
# Updates the visibility of slider with input bool making it dynamically visible
def update_slider_visibility(self, user_defined_visibility): def update_slider_visibility(self, user_defined_visibility):
return gr.Slider.update(visible=user_defined_visibility) return gr.Slider.update(visible=user_defined_visibility)
# Makes save confirmation dialague invisible
def update_save_confirmation_row_false(self): def update_save_confirmation_row_false(self):
return gr.Accordion.update(visible=False) return gr.Accordion.update(visible=False)
# Makes save confirmation dialague visible
def update_save_confirmation_row_true(self): def update_save_confirmation_row_true(self):
return gr.Accordion.update(visible=True) return gr.Accordion.update(visible=True)
# Used for user visualization, (no longer used for parsing pairs)
def update_parsed_pairs(self, custom_replace_find, custom_replace_replacements):
old_list = [phrase.strip() for phrase in custom_replace_find.split(',')]
new_list = [phrase.strip() for phrase in custom_replace_replacements.split(',')]
# Ensure both lists have the same length
min_length = min(len(old_list), len(new_list))
pairs = [f"{old_list[i]}:{new_list[i]}" for i in range(min_length)]
return ", ".join(pairs)
# depending on if WD (EXT) is present, WD (EXT) could be removed from model selector # depending on if WD (EXT) is present, WD (EXT) could be removed from model selector
def update_wd_ext_visibility(self, model_selection): def update_wd_ext_visibility(self, model_selection):
is_visible = "WD (EXT)" in model_selection is_visible = "WD (EXT)" in model_selection
@ -315,9 +387,7 @@ class Script(scripts.ScriptBuiltinUI):
def ui(self, is_img2img): def ui(self, is_img2img):
if not is_img2img: if not is_img2img:
return return
#tag_batch_enabled = gr.Checkbox(label=NAME, value=False)
#tag_batch_ui = gr.Accordion(tag_batch_enabled, open=False)
#with tag_batch_ui:
with InputAccordion(False, label=NAME, elem_id="tag_batch_enabled") as tag_batch_enabled: with InputAccordion(False, label=NAME, elem_id="tag_batch_enabled") as tag_batch_enabled:
with gr.Row(): with gr.Row():
model_selection = gr.Dropdown( model_selection = gr.Dropdown(
@ -353,25 +423,61 @@ class Script(scripts.ScriptBuiltinUI):
filtering_tools = gr.Accordion("Filtering tools:") filtering_tools = gr.Accordion("Filtering tools:")
with filtering_tools: with filtering_tools:
use_positive_filter = gr.Checkbox(label="Filter Duplicate Positive Prompt Content from Interrogation") use_positive_filter = gr.Checkbox(label="Filter Duplicate Positive Prompt Entries from Interrogation")
use_negative_filter = gr.Checkbox(label="Filter Duplicate Negative Prompt Content from Interrogation") use_negative_filter = gr.Checkbox(label="Filter Duplicate Negative Prompt Entries from Interrogation")
use_custom_filter = gr.Checkbox(label="Filter Custom Prompt Content from Interrogation") use_custom_filter = gr.Checkbox(label="Filter Custom Prompt Entries from Interrogation")
custom_filter = gr.Textbox(value=self.load_custom_filter(), custom_filter_group = gr.Group(visible=False)
label="Custom Filter Prompt", with custom_filter_group:
placeholder="Prompt content separated by commas. Warning ignores attention syntax, parentheses '()' and colon suffix ':XX.XX' are discarded.", custom_filter = gr.Textbox(value=self.load_custom_filter_on_start(),
show_copy_button=True label="Custom Filter Prompt",
) placeholder="Prompt content separated by commas. Warning ignores attention syntax, parentheses '()' and colon suffix ':XX.XX' are discarded.",
# Button to remove duplicates and strip strange spacing show_copy_button=True
clean_custom_filter_button = gr.Button(value="Optimize Custom Filter") )
# Button to load/save custom filter from file # Button to remove duplicates and strip strange spacing
with gr.Row(): clean_custom_filter_button = gr.Button(value="Optimize Custom Filter")
load_custom_filter_button = gr.Button(value="Load Custom Filter") # Button to load/save custom filter from file
save_confirmation_button = gr.Button(value="Save Custom Filter")
save_confirmation_row = gr.Accordion("Are You Sure You Want to Save?", visible=False)
with save_confirmation_row:
with gr.Row(): with gr.Row():
cancel_save_button = gr.Button(value="Cancel") load_custom_filter_button = gr.Button(value="Load Custom Filter")
save_custom_filter_button = gr.Button(value="Save", variant="stop") save_custom_filter_button = gr.Button(value="Save Custom Filter")
save_confirmation_custom_filter = gr.Accordion("Are You Sure You Want to Save?", visible=False)
with save_confirmation_custom_filter:
with gr.Row():
cancel_save_custom_filter_button = gr.Button(value="Cancel")
confirm_save_custom_filter_button = gr.Button(value="Save", variant="stop")
# Find and Replace
use_custom_replace = gr.Checkbox(label="Find & Replace User Defined Pairs in the Interrogation")
custom_replace_group = gr.Group(visible=False)
with custom_replace_group:
with gr.Row():
custom_replace_find = gr.Textbox(
value=self.load_custom_replace_on_start()[0],
label="Find:",
placeholder="Enter phrases to replace, separated by commas",
show_copy_button=True
)
custom_replace_replacements = gr.Textbox(
value=self.load_custom_replace_on_start()[1],
label="Replace:",
placeholder="Enter replacement phrases, separated by commas",
show_copy_button=True
)
with gr.Row():
parsed_pairs = gr.Textbox(
label="Parsed Pairs",
placeholder="Parsed pairs will be shown here",
interactive=False
)
update_parsed_pairs_button = gr.Button("🔄", elem_classes="tool")
with gr.Row():
load_custom_replace_button = gr.Button("Load Custom Replace")
save_custom_replace_button = gr.Button("Save Custom Replace")
save_confirmation_custom_replace = gr.Accordion("Are You Sure You Want to Save?", visible=False)
with save_confirmation_custom_replace:
with gr.Row():
cancel_save_custom_replace_button = gr.Button(value="Cancel")
confirm_save_custom_replace_button = gr.Button(value="Save", variant="stop")
experimental_tools = gr.Accordion("Experamental tools:", open=False) experimental_tools = gr.Accordion("Experamental tools:", open=False)
with experimental_tools: with experimental_tools:
@ -392,22 +498,31 @@ class Script(scripts.ScriptBuiltinUI):
wd_append_ratings.change(fn=self.update_slider_visibility, inputs=[wd_append_ratings], outputs=[wd_ratings]) wd_append_ratings.change(fn=self.update_slider_visibility, inputs=[wd_append_ratings], outputs=[wd_ratings])
clean_custom_filter_button.click(self.clean_string, inputs=custom_filter, outputs=custom_filter) clean_custom_filter_button.click(self.clean_string, inputs=custom_filter, outputs=custom_filter)
load_custom_filter_button.click(self.load_custom_filter, inputs=None, outputs=custom_filter) load_custom_filter_button.click(self.load_custom_filter, inputs=None, outputs=custom_filter)
save_confirmation_button.click(self.update_save_confirmation_row_true, inputs=None, outputs=[save_confirmation_row]) custom_replace_find.change(fn=self.update_parsed_pairs, inputs=[custom_replace_find, custom_replace_replacements], outputs=[parsed_pairs])
cancel_save_button.click(self.update_save_confirmation_row_false, inputs=None, outputs=[save_confirmation_row]) custom_replace_replacements.change(fn=self.update_parsed_pairs, inputs=[custom_replace_find, custom_replace_replacements], outputs=[parsed_pairs])
save_custom_filter_button.click(self.save_custom_filter, inputs=custom_filter, outputs=[save_confirmation_row]) update_parsed_pairs_button.click(fn=self.update_parsed_pairs, inputs=[custom_replace_find, custom_replace_replacements], outputs=[parsed_pairs])
save_custom_filter_button.click(self.update_save_confirmation_row_true, inputs=None, outputs=[save_confirmation_custom_filter])
cancel_save_custom_filter_button.click(self.update_save_confirmation_row_false, inputs=None, outputs=[save_confirmation_custom_filter])
confirm_save_custom_filter_button.click(self.save_custom_filter, inputs=custom_filter, outputs=[save_confirmation_custom_filter])
load_custom_replace_button.click(fn=self.load_custom_replace, inputs=[],outputs=[custom_replace_find, custom_replace_replacements])
save_custom_replace_button.click(fn=self.update_save_confirmation_row_true, inputs=[], outputs=[save_confirmation_custom_replace])
cancel_save_custom_replace_button.click(fn=self.update_save_confirmation_row_false, inputs=[], outputs=[save_confirmation_custom_replace])
confirm_save_custom_replace_button.click(fn=self.save_custom_replace, inputs=[custom_replace_find, custom_replace_replacements], outputs=[save_confirmation_custom_replace])
refresh_models_button.click(fn=self.refresh_model_options, inputs=[], outputs=[model_selection]) refresh_models_button.click(fn=self.refresh_model_options, inputs=[], outputs=[model_selection])
use_custom_filter.change(fn=self.update_group_visibility, inputs=[use_custom_filter], outputs=[custom_filter_group])
use_custom_replace.change(fn=self.update_group_visibility, inputs=[use_custom_replace], outputs=[custom_replace_group])
ui = [ ui = [
tag_batch_enabled, model_selection, debug_mode, in_front, prompt_weight_mode, prompt_weight, reverse_mode, exaggeration_mode, prompt_output, use_positive_filter, use_negative_filter, tag_batch_enabled, model_selection, debug_mode, in_front, prompt_weight_mode, prompt_weight, reverse_mode, exaggeration_mode, prompt_output, use_positive_filter, use_negative_filter,
use_custom_filter, custom_filter, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, wd_append_ratings, wd_ratings, unload_clip_models_afterwords, unload_wd_models_afterwords, use_custom_filter, custom_filter, use_custom_replace, custom_replace_find, custom_replace_replacements, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, wd_append_ratings, wd_ratings,
no_puncuation_mode unload_clip_models_afterwords, unload_wd_models_afterwords, no_puncuation_mode
] ]
return ui return ui
def process_batch( def process_batch(
self, p, tag_batch_enabled, model_selection, debug_mode, in_front, prompt_weight_mode, prompt_weight, reverse_mode, exaggeration_mode, prompt_output, use_positive_filter, use_negative_filter, self, p, tag_batch_enabled, model_selection, debug_mode, in_front, prompt_weight_mode, prompt_weight, reverse_mode, exaggeration_mode, prompt_output, use_positive_filter, use_negative_filter,
use_custom_filter, custom_filter, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, wd_append_ratings, wd_ratings, unload_clip_models_afterwords, unload_wd_models_afterwords, use_custom_filter, custom_filter, use_custom_replace, custom_replace_find, custom_replace_replacements, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, wd_append_ratings, wd_ratings,
no_puncuation_mode, batch_number, prompts, seeds, subseeds): unload_clip_models_afterwords, unload_wd_models_afterwords, no_puncuation_mode, batch_number, prompts, seeds, subseeds):
if not tag_batch_enabled: if not tag_batch_enabled:
return return
@ -522,6 +637,11 @@ class Script(scripts.ScriptBuiltinUI):
if not exaggeration_mode: if not exaggeration_mode:
interrogation = self.clean_string(interrogation) interrogation = self.clean_string(interrogation)
# Find and Replace user defined words in the interrogation prompt
if use_custom_replace:
replace_pairs = self.parse_replace_pairs(custom_replace_find, custom_replace_replacements)
interrogation = self.custom_replace(interrogation, replace_pairs)
# Remove duplicate prompt content from interrogator prompt # Remove duplicate prompt content from interrogator prompt
if use_positive_filter: if use_positive_filter:
interrogation = self.filter_words(interrogation, p.prompt) interrogation = self.filter_words(interrogation, p.prompt)