diff --git a/scripts/sd_tag_batch.py b/scripts/sd_tag_batch.py index 6762454..f12db06 100644 --- a/scripts/sd_tag_batch.py +++ b/scripts/sd_tag_batch.py @@ -14,6 +14,7 @@ Thanks to Smirking Kitsune. """ +# Extension List Crawler def get_extensions_list(): from modules import extensions extensions.list_extensions() @@ -28,12 +29,14 @@ def get_extensions_list(): }) return ext_list +# Extension Checker def is_interrogator_enabled(interrogator): for ext in get_extensions_list(): if ext["name"] == interrogator: return ext["enabled"] return False +# EXT Importer def import_module(module_name, file_path): spec = importlib.util.spec_from_file_location(module_name, file_path) module = importlib.util.module_from_spec(spec) @@ -45,6 +48,7 @@ class Script(scripts.Script): wd_ext_utils = None clip_ext = None + # Checks for CLIP EXT to see if it is installed and enabled @classmethod def load_clip_ext_module(cls): if is_interrogator_enabled('clip-interrogator-ext'): @@ -52,6 +56,7 @@ class Script(scripts.Script): return cls.clip_ext return None + # Checks for WD EXT to see if it is installed and enabled @classmethod def load_wd_ext_module(cls): if is_interrogator_enabled('stable-diffusion-webui-wd14-tagger'): @@ -60,10 +65,12 @@ class Script(scripts.Script): return cls.wd_ext_utils return None + # Initiates extension check at startup for CLIP EXT @classmethod def load_clip_ext_module_wrapper(cls, *args, **kwargs): return cls.load_clip_ext_module() + # Initiates extension check at startup for WD EXT @classmethod def load_wd_ext_module_wrapper(cls, *args, **kwargs): return cls.load_wd_ext_module() @@ -91,20 +98,21 @@ class Script(scripts.Script): return gr.Dropdown.update(choices=all_options, value=updated_choices) - # Function to load CLIP models + # Function to load CLIP models list into CLIP model selector def load_clip_models(self): if self.clip_ext is not None: models = self.clip_ext.get_models() return gr.Dropdown.update(choices=models if models else None) return 
gr.Dropdown.update(choices=None) - # Function to load WD models + # Function to load WD models list into WD model selector def load_wd_models(self): if self.wd_ext_utils is not None: models = self.get_WD_EXT_models() return gr.Dropdown.update(choices=models if models else None) return gr.Dropdown.update(choices=None) + # Gets a list of WD models from WD EXT def get_WD_EXT_models(self): if self.wd_ext_utils is not None: try: @@ -117,18 +125,23 @@ class Script(scripts.Script): print(f"Error accessing WD Tagger: {error}") return [] + #Unloads CLIP Models + def unload_clip_models(self): + if self.clip_ext is not None: + self.clip_ext.unload() + + #Unloads WD Models def unload_wd_models(self): unloaded_models = 0 if self.wd_ext_utils is not None: for interrogator in self.wd_ext_utils.interrogators.values(): if interrogator.unload(): unloaded_models = unloaded_models + 1 - print(f"Unloaded {unloaded_models} of WD Extension's Model(s).") + print(f"Unloaded {unloaded_models} Tagger Model(s).") + - def unload_clip_models(self): - if self.clip_ext is not None: - self.clip_ext.unload() + # depending on if CLIP (EXT) is present, CLIP (EXT) could be removed from model selector def update_clip_ext_visibility(self, model_selection): is_visible = "CLIP (EXT)" in model_selection if is_visible: @@ -137,6 +150,7 @@ class Script(scripts.Script): else: return gr.Accordion.update(visible=False), gr.Dropdown.update() + # depending on if WD (EXT) is present, WD (EXT) could be removed from model selector def update_wd_ext_visibility(self, model_selection): is_visible = "WD (EXT)" in model_selection if is_visible: @@ -145,6 +159,7 @@ class Script(scripts.Script): else: return gr.Accordion.update(visible=False), gr.Dropdown.update() + # Depending on if prompt weight is enabled the slider will be dynamically visible def update_prompt_weight_visibility(self, use_weight): return gr.Slider.update(visible=use_weight) @@ -157,9 +172,16 @@ class Script(scripts.Script): def ui(self, is_img2img): 
with gr.Group(): model_options = ["CLIP (EXT)", "CLIP (Native)", "Deepbooru (Native)", "WD (EXT)"] - model_selection = gr.Dropdown(choices=model_options, label="Select Interrogation Model(s)", multiselect=True, value=None) + model_selection = gr.Dropdown(choices=model_options, + label="Interrogation Model(s):", + multiselect=True, + value=None, + ) - in_front = gr.Radio(choices=["Prepend to prompt", "Append to prompt"], value="Prepend to prompt", label="Interrogator result position") + in_front = gr.Radio( + choices=["Prepend to prompt", "Append to prompt"], + value="Prepend to prompt", + label="Interrogator result position") use_weight = gr.Checkbox(label="Use Interrogator Prompt Weight", value=True) prompt_weight = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Interrogator Prompt Weight", visible=True) @@ -167,19 +189,19 @@ class Script(scripts.Script): # CLIP EXT Options clip_ext_accordion = gr.Accordion("CLIP EXT Options:", open=False, visible=False) with clip_ext_accordion: - clip_ext_model = gr.Dropdown(choices=[], value='ViT-L-14/openai', label="CLIP EXT Model", multiselect=True) - clip_ext_mode = gr.Radio(choices=["best", "fast", "classic", "negative"], value='best', label="CLIP EXT Mode") - unload_clip_models_afterwords = gr.Checkbox(label="Unload CLIP Model After Use", value=True) - unload_clip_models_button = gr.Button(value="Unload CLIP Models") + clip_ext_model = gr.Dropdown(choices=[], value='ViT-L-14/openai', label="CLIP Extension Model(s):", multiselect=True) + clip_ext_mode = gr.Radio(choices=["best", "fast", "classic", "negative"], value='best', label="CLIP Extension Mode") + unload_clip_models_afterwords = gr.Checkbox(label="Unload CLIP Interrogator After Use", value=True) + unload_clip_models_button = gr.Button(value="Unload All CLIP Interrogators") # WD EXT Options wd_ext_accordion = gr.Accordion("WD EXT Options:", open=False, visible=False) with wd_ext_accordion: - wd_ext_model = gr.Dropdown(choices=[], value='wd-v1-4-moat-tagger.v2', 
label="WD EXT Model", multiselect=True) + wd_ext_model = gr.Dropdown(choices=[], value='wd-v1-4-moat-tagger.v2', label="WD Extension Model(s):", multiselect=True) wd_threshold = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Threshold") wd_underscore_fix = gr.Checkbox(label="Remove Underscores from Tags", value=True) - unload_wd_models_afterwords = gr.Checkbox(label="Unload WD Model After Use", value=True) - unload_wd_models_button = gr.Button(value="Unload WD Models") + unload_wd_models_afterwords = gr.Checkbox(label="Unload Tagger After Use", value=True) + unload_wd_models_button = gr.Button(value="Unload All Tagger Models") with gr.Accordion("Filtering tools:"): no_duplicates = gr.Checkbox(label="Filter Duplicate Prompt Content from Interrogation", value=False) @@ -227,6 +249,7 @@ class Script(scripts.Script): return words.strip() + # Tag filtering, removes negative tags from prompt def filter_words(self, prompt, negative): # Corrects a potential error where negative is nonetype if negative is None: @@ -244,9 +267,20 @@ class Script(scripts.Script): return filtered_prompt + # For WD Tagger, removes underscores from tags that should have spaces + def replace_underscores(self, tag): + preserve_patterns = [ + "0_0", "(o)_(o)", "+_+", "+_-", "._.", "_", "<|>_<|>", "=_=", ">_<", + "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||" + ] + if tag in preserve_patterns: + return tag + return tag.replace('_', ' ') + def run(self, p, in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, unload_clip_models_afterwords, unload_wd_models_afterwords): raw_prompt = p.prompt interrogation = "" + preliminary_interrogation = "" # fix alpha channel p.init_images[0] = p.init_images[0].convert("RGB") @@ -254,13 +288,13 @@ class Script(scripts.Script): for model in model_selection: # Should add the interrogators in the order 
determined by the model_selection list if model == "Deepbooru (Native)": - interrogation += deepbooru.model.tag(p.init_images[0]) + ", " + preliminary_interrogation = deepbooru.model.tag(p.init_images[0]) + ", " elif model == "CLIP (Native)": - interrogation += shared.interrogator.interrogate(p.init_images[0]) + ", " + preliminary_interrogation = shared.interrogator.interrogate(p.init_images[0]) + ", " elif model == "CLIP (EXT)": if self.clip_ext is not None: for clip_model in clip_ext_model: - interrogation += self.clip_ext.image_to_prompt(p.init_images[0], clip_ext_mode, clip_model) + ", " + preliminary_interrogation = self.clip_ext.image_to_prompt(p.init_images[0], clip_ext_mode, clip_model) + ", " if unload_clip_models_afterwords: self.clip_ext.unload() elif model == "WD (EXT)": @@ -269,14 +303,15 @@ class Script(scripts.Script): rating, tags = self.wd_ext_utils.interrogators[wd_model].interrogate(p.init_images[0]) tags_list = [tag for tag, conf in tags.items() if conf > wd_threshold] if wd_underscore_fix: - tags_spaced = [tag.replace('_', ' ') for tag in tags_list] - interrogation += ", ".join(tags_spaced) + tags_spaced = [self.replace_underscores(tag) for tag in tags_list] + preliminary_interrogation = ", ".join(tags_spaced) + ", " else: - interrogation += ", ".join(tags_list) - interrogation += ", " + preliminary_interrogation += ", ".join(tags_list) + ", " if unload_wd_models_afterwords: self.wd_ext_utils.interrogators[wd_model].unload() - + + # Filter prevents overexaggeration of tags due to interrogation models having similar results + interrogation += self.filter_words(preliminary_interrogation, interrogation) # Remove duplicate prompt content from interrogator prompt if no_duplicates: @@ -315,5 +350,6 @@ class Script(scripts.Script): return processed +#Startup Callbacks script_callbacks.on_app_started(Script.load_clip_ext_module_wrapper) script_callbacks.on_app_started(Script.load_wd_ext_module_wrapper)