diff --git a/scripts/sd_tag_batch.py b/scripts/sd_tag_batch.py index c1835e8..958ebc7 100644 --- a/scripts/sd_tag_batch.py +++ b/scripts/sd_tag_batch.py @@ -4,9 +4,10 @@ import os import requests from io import BytesIO import base64 -from modules import scripts, deepbooru, script_callbacks +from modules import scripts, deepbooru, script_callbacks, shared from modules.processing import process_images -import modules.shared as shared +import sys +import importlib.util NAME = "Img2img batch interrogator" @@ -17,35 +18,59 @@ Thanks to Smirking Kitsune. """ +def get_extensions_list(): + from modules import extensions + extensions.list_extensions() + ext_list = [] + for ext in extensions.extensions: + ext: extensions.Extension + ext.read_info_from_repo() + if ext.remote is not None: + ext_list.append({ + "name": ext.name, + "enabled":ext.enabled + }) + return ext_list + +def is_interrogator_enabled(interrogator): + for ext in get_extensions_list(): + if ext["name"] == interrogator: + return ext["enabled"] + return False + +def import_module(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + class Script(scripts.Script): - server_address = None + wd_ext_utils = None + clip_ext = None @classmethod - def set_server_address(cls, demo, app, *args, **kwargs): - cls.server_address = demo.local_url - print(f"Server address set to: {cls.server_address}") + def load_clip_ext_module(cls): + if is_interrogator_enabled('clip-interrogator-ext'): + cls.clip_ext = import_module("clip-interrogator-ext", "extensions/clip-interrogator-ext/scripts/clip_interrogator_ext.py") + return cls.clip_ext + return None + + @classmethod + def load_wd_ext_module(cls): + if is_interrogator_enabled('stable-diffusion-webui-wd14-tagger'): + sys.path.append('extensions/stable-diffusion-webui-wd14-tagger') + cls.wd_ext_utils = import_module("utils", "extensions/stable-diffusion-webui-wd14-tagger/tagger/utils.py") + return cls.wd_ext_utils + return None @classmethod - def get_server_address(cls): - if cls.server_address: - return cls.server_address - - # Fallback to the brute force method if server_address is not set - # Initial testing indicates that fallback method will never be used... - print("Server address not set. Falling back to brute force method.") - # Fallback is highly inefficient and in some cases slow (especially if expected port is far from default) - ports = range(7860, 7960) # Gradio will increment port 100 times if default and subsequent desired ports are unavailable. - for port in ports: - url = f"http://127.0.0.1:{port}/" - try: - response = requests.get(f"{url}internal/ping", timeout=5) - if response.status_code == 200: - return url - except requests.RequestException as error: - print(f"API not available on port {port}: {error}") + def load_clip_ext_module_wrapper(cls, *args, **kwargs): + return cls.load_clip_ext_module() - print("API not found") - return None + @classmethod + def load_wd_ext_module_wrapper(cls, *args, **kwargs): + return cls.load_wd_ext_module() def title(self): return NAME @@ -57,101 +82,105 @@ class Script(scripts.Script): def b_clicked(o): return gr.Button.update(interactive=True) - def is_interrogator_enabled(self, interrogator): - api_address = f"{self.get_server_address()}sdapi/v1/extensions" - headers = {'accept': 'application/json'} - - try: - response = requests.get(api_address, headers=headers) - response.raise_for_status() - extensions = response.json() - - for extension in extensions: - if extension['name'] == interrogator: - return extension['enabled'] - - return False - except requests.RequestException: - print(f"Error occurred while fetching extension: {interrogator}") - return False - # Removes unsupported interrogators, support may vary depending on client def update_model_choices(self, current_choices): - all_options = ["CLIP (API)", "CLIP (Native)", "Deepbooru (Native)", "WD (API)"] - - if not self.is_interrogator_enabled('clip-interrogator-ext'): - all_options.remove("CLIP (API)") - if not self.is_interrogator_enabled('stable-diffusion-webui-wd14-tagger'): - all_options.remove("WD (API)") - - # Keep the current selections if they're still valid + all_options = ["CLIP (EXT)", "CLIP (Native)", "Deepbooru (Native)", "WD (EXT)"] + + if not is_interrogator_enabled('clip-interrogator-ext'): + all_options.remove("CLIP (EXT)") + if not is_interrogator_enabled('stable-diffusion-webui-wd14-tagger'): + all_options.remove("WD (EXT)") + updated_choices = [choice for choice in current_choices if choice in all_options] - + return gr.Dropdown.update(choices=all_options, value=updated_choices) # Function to load CLIP models def load_clip_models(self): - models = self.get_clip_API_models() - return gr.Dropdown.update(choices=models if models else None) + if self.clip_ext is not None: + models = self.clip_ext.get_models() + return gr.Dropdown.update(choices=models if models else None) + return gr.Dropdown.update(choices=None) # Function to load WD models def load_wd_models(self): - models = self.get_WD_API_models() - return gr.Dropdown.update(choices=models if models else None) + if self.wd_ext_utils is not None: + models = self.get_WD_EXT_models() + return gr.Dropdown.update(choices=models if models else None) + return gr.Dropdown.update(choices=None) + + def get_WD_EXT_models(self): + if self.wd_ext_utils is not None: + try: + self.wd_ext_utils.refresh_interrogators() + models = list(self.wd_ext_utils.interrogators.keys()) + if not models: + raise Exception("No WD Tagger models found.") + return models + except Exception as error: + print(f"Error accessing WD Tagger: {error}") + return [] + + def unload_wd_models(self): + if self.wd_ext_utils is not None: + for interrogator in self.wd_ext_utils.interrogators.values(): + interrogator.unload() + + def unload_clip_models(self): + if self.clip_ext is not None: + self.clip_ext.unload() + + def update_clip_ext_visibility(self, model_selection): + is_visible = "CLIP (EXT)" in model_selection + if is_visible: + clip_models = self.load_clip_models() + return gr.Accordion.update(visible=True), clip_models + else: + return gr.Accordion.update(visible=False), gr.Dropdown.update() + + def update_wd_ext_visibility(self, model_selection): + is_visible = "WD (EXT)" in model_selection + if is_visible: + wd_models = self.load_wd_models() + return gr.Accordion.update(visible=True), wd_models + else: + return gr.Accordion.update(visible=False), gr.Dropdown.update() + + def update_prompt_weight_visibility(self, use_weight): + return gr.Slider.update(visible=use_weight) + + # Function to load custom filter from file + def load_custom_filter(self, custom_filter): + with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "r") as file: + custom_filter = file.read() + return custom_filter def ui(self, is_img2img): with gr.Group(): - model_options = ["CLIP (API)", "CLIP (Native)", "Deepbooru (Native)", "WD (API)"] + model_options = ["CLIP (EXT)", "CLIP (Native)", "Deepbooru (Native)", "WD (EXT)"] model_selection = gr.Dropdown(choices=model_options, label="Select Interrogation Model(s)", multiselect=True, value=None) - in_front = gr.Radio( - choices=["Prepend to prompt", "Append to prompt"], - value="Prepend to prompt", - label="Interrogator result position" - ) - - def update_prompt_weight_visibility(use_weight): - return gr.Slider.update(visible=use_weight) - + in_front = gr.Radio(choices=["Prepend to prompt", "Append to prompt"], value="Prepend to prompt", label="Interrogator result position") + use_weight = gr.Checkbox(label="Use Interrogator Prompt Weight", value=True) prompt_weight = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Interrogator Prompt Weight", visible=True) - # CLIP API Options - def update_clip_api_visibility(model_selection): - is_visible = "CLIP (API)" in model_selection - if is_visible: - clip_models = self.load_clip_models() - return gr.Accordion.update(visible=True), clip_models - else: - return gr.Accordion.update(visible=False), gr.Dropdown.update() - - clip_api_accordion = gr.Accordion("CLIP API Options:", open=False, visible=False) - with clip_api_accordion: - clip_api_model = gr.Dropdown(choices=[], value='ViT-L-14/openai', label="CLIP API Model") - clip_api_mode = gr.Radio(choices=["best", "fast", "classic", "negative"], value='best', label="CLIP API Mode") + # CLIP EXT Options + clip_ext_accordion = gr.Accordion("CLIP EXT Options:", open=False, visible=False) + with clip_ext_accordion: + clip_ext_model = gr.Dropdown(choices=[], value='ViT-L-14/openai', label="CLIP EXT Model", multiselect=True) + clip_ext_mode = gr.Radio(choices=["best", "fast", "classic", "negative"], value='best', label="CLIP EXT Mode") + unload_clip_models_afterwords = gr.Checkbox(label="Unload CLIP Model After Use", value=True) + unload_clip_models_button = gr.Button(value="Unload CLIP Models") - # WD API Options - def update_wd_api_visibility(model_selection): - is_visible = "WD (API)" in model_selection - if is_visible: - wd_models = self.load_wd_models() - return gr.Accordion.update(visible=True), wd_models - else: - return gr.Accordion.update(visible=False), gr.Dropdown.update() - - wd_api_accordion = gr.Accordion("WD API Options:", open=False, visible=False) - with wd_api_accordion: - wd_api_model = gr.Dropdown(choices=[], value='wd-v1-4-moat-tagger.v2', label="WD API Model") + # WD EXT Options + wd_ext_accordion = gr.Accordion("WD EXT Options:", open=False, visible=False) + with wd_ext_accordion: + wd_ext_model = gr.Dropdown(choices=[], value='wd-swinv2-tagger.v3', label="WD EXT Model", multiselect=True) wd_threshold = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Threshold") wd_underscore_fix = gr.Checkbox(label="Remove Underscores from Tags", value=True) unload_wd_models_afterwords = gr.Checkbox(label="Unload WD Model After Use", value=True) unload_wd_models_button = gr.Button(value="Unload WD Models") - - # Function to load custom filter from file - def load_custom_filter(custom_filter): - with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "r") as file: - custom_filter = file.read() - return custom_filter with gr.Accordion("Filtering tools:"): no_duplicates = gr.Checkbox(label="Filter Duplicate Prompt Content from Interrogation", value=False) @@ -167,14 +196,15 @@ class Script(scripts.Script): # Listeners model_selection.select(fn=self.update_model_choices, inputs=[model_selection], outputs=[model_selection]) - model_selection.change(fn=update_clip_api_visibility, inputs=[model_selection], outputs=[clip_api_accordion, clip_api_model]) - model_selection.change(fn=update_wd_api_visibility, inputs=[model_selection], outputs=[wd_api_accordion, wd_api_model]) - load_custom_filter_button.click(load_custom_filter, inputs=custom_filter, outputs=custom_filter) - unload_wd_models_button.click(self.post_wd_api_unload, inputs=None, outputs=None) - use_weight.change(fn=update_prompt_weight_visibility, inputs=[use_weight], outputs=[prompt_weight]) + model_selection.change(fn=self.update_clip_ext_visibility, inputs=[model_selection], outputs=[clip_ext_accordion, clip_ext_model]) + model_selection.change(fn=self.update_wd_ext_visibility, inputs=[model_selection], outputs=[wd_ext_accordion, wd_ext_model]) + load_custom_filter_button.click(self.load_custom_filter, inputs=custom_filter, outputs=custom_filter) + unload_clip_models_button.click(self.unload_clip_models, inputs=None, outputs=None) + unload_wd_models_button.click(self.unload_wd_models, inputs=None, outputs=None) + use_weight.change(fn=self.update_prompt_weight_visibility, inputs=[use_weight], outputs=[prompt_weight]) - return [in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_api_model, clip_api_mode, wd_api_model, wd_threshold, wd_underscore_fix, unload_wd_models_afterwords] + return [in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, unload_clip_models_afterwords, unload_wd_models_afterwords] # Required to parse information from a string that is between () or has :##.## suffix def remove_attention(self, words): @@ -215,28 +245,39 @@ class Script(scripts.Script): return filtered_prompt - def run(self, p, in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_api_model, clip_api_mode, wd_api_model, wd_threshold, wd_underscore_fix, unload_wd_models_afterwords): + def run(self, p, in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, unload_clip_models_afterwords, unload_wd_models_afterwords): raw_prompt = p.prompt interrogator = "" # fix alpha channel p.init_images[0] = p.init_images[0].convert("RGB") - first = True # Two interrogator concatenation correction boolean for model in model_selection: - # This prevents two interrogators from being incorrectly concatenated - if first == False: - interrogator += ", " - first = False # Should add the interrogators in the order determined by the model_selection list if model == "Deepbooru (Native)": - interrogator += deepbooru.model.tag(p.init_images[0]) + interrogator += deepbooru.model.tag(p.init_images[0]) + ", " elif model == "CLIP (Native)": - interrogator += shared.interrogator.interrogate(p.init_images[0]) - elif model == "CLIP (API)": - interrogator += self.post_clip_api_prompt(p.init_images[0], clip_api_model, clip_api_mode) - elif model == "WD (API)": - interrogator += self.post_wd_api_tagger(p.init_images[0], wd_api_model, wd_threshold, wd_underscore_fix) + interrogator += shared.interrogator.interrogate(p.init_images[0]) + ", " + elif model == "CLIP (EXT)": + if self.clip_ext is not None: + for clip_model in clip_ext_model: + interrogator += self.clip_ext.image_to_prompt(p.init_images[0], clip_ext_mode, clip_model) + ", " + if unload_clip_models_afterwords: + self.clip_ext.unload() + elif model == "WD (EXT)": + if self.wd_ext_utils is not None: + for wd_model in wd_ext_model: + interrogator = self.wd_ext_utils.interrogators[wd_model] + rating, tags = interrogator.interrogate(p.init_images[0]) + tags_list = [tag for tag, conf in tags.items() if conf > wd_threshold] + if wd_underscore_fix: + tags_spaced = [tag.replace('_', ' ') for tag in tags_list] + interrogator += ", ".join(tags_spaced) + ", " + else: + interrogator += ", ".join(tags_list) + ", " + if unload_wd_models_afterwords: + self.wd_ext_utils.interrogators[wd_ext_model].unload() + # Remove duplicate prompt content from interrogator prompt if no_duplicates: @@ -266,9 +307,6 @@ class Script(scripts.Script): else: p.prompt = f"{interrogator}, {p.prompt}" - if unload_wd_models_afterwords and "WD (API)" in model_selection: - self.post_wd_api_unload() - print(f"Prompt: {p.prompt}") processed = process_images(p) @@ -278,116 +316,5 @@ class Script(scripts.Script): return processed - # CLIP API Model Identification - def get_clip_API_models(self): - # Ensure CLIP Interrogator is present and accessible - try: - api_address = f"{self.get_server_address()}interrogator/models" - response = requests.get(api_address) - response.raise_for_status() - models = response.json() - if not models: - raise Exception("No CLIP Interrogator models found.") - except Exception as error: - print(f"Error accessing CLIP Interrogator API: {error}") - return [] - return models - - # WD API Model Identification - def get_WD_API_models(self): - # Ensure WD Interrogator is present and accessible - try: - api_address = f"{self.get_server_address()}tagger/v1/interrogators" - response = requests.get(api_address) - response.raise_for_status() - models = response.json()["models"] - if not models: - raise Exception("No WD Tagger models found.") - except Exception as error: - print(f"Error accessing WD Tagger API: {error}") - return [] - return models - - # CLIP API Prompt Generator - def post_clip_api_prompt(self, image, model_name, mode): - print("Starting CLIP Interrogator API interaction...") - # Ensure the model and mode are provided - if not model_name: - print("CLIP API model is required.") - return "" - - # Encode the image to base64 - buffered = BytesIO() - image.save(buffered, format="PNG") - img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") - - # Get the prompt from the CLIP API - try: - payload = { - "image": img_str, - "mode": mode, - "clip_model_name": model_name - } - api_address = f"{self.get_server_address()}interrogator/prompt" - response = requests.post(api_address, json=payload) - response.raise_for_status() - result = response.json() - return result.get("prompt", "") - except Exception as error: - print(f"Error generating prompt with CLIP API: {error}") - return "" - - # WD API Interrogation Tagger - def post_wd_api_tagger(self, image, model_name, threshold, underscore): - print("Starting WD Tagger API interaction...") - # Ensure the model and mode are provided - if not model_name: - print("WD API model is required.") - return "" - - # Encode the image to base64 - buffered = BytesIO() - image.save(buffered, format="PNG") - img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") - - # Get the prompt from the WD API - try: - payload = { - "image": img_str, - "model": model_name, - "threshold": threshold, - "queue": "", - "name_in_queue": "" - } - api_address = f"{self.get_server_address()}tagger/v1/interrogate" - # WARNING: Removing `timeout` could result in a frozen client if the queue_lock is locked. If you need more time add more time, do not remove timeout or risk DEADLOCK. - # Note: If WD Tagger did not load a model, it is likely that WD Tagger specifically queue_lock (FIFOLock) is concerned with your system's threading and thinks running could cause processes starvation... - # Note: It would be advisable to download models in the WD tab due to the timeout - response = requests.post(api_address, json=payload, timeout=120) - response.raise_for_status() - result = response.json() - - tags_list = result.get("caption", {}).get("tag", []) - - if underscore: - tags_spaced = [tag.replace('_', ' ') for tag in tags_list] - tags_string = ", ".join(tags_spaced) - else: - tags_string = ", ".join(tags_list) - - return tags_string - except Exception as error: - print(f"Error generating prompt with WD API: {error}") - return "" - - # WD API Model Unloader - def post_wd_api_unload(self): - try: - api_address = f"{self.get_server_address()}tagger/v1/unload-interrogators" - response = requests.post(api_address, json='') - response.raise_for_status() - - except Exception as error: - print(f"Error Unloading models with WD API: {error}") - -script_callbacks.on_app_started(Script.set_server_address) +script_callbacks.on_app_started(Script.load_clip_ext_module_wrapper) +script_callbacks.on_app_started(Script.load_wd_ext_module_wrapper)