Direct Extension Interaction

Major overhaul of this patch. I was unsatisfied with the performance and the constraints that working through the API imposed, so the API has been removed in favor of direct extension interaction.
pull/10/head
Smirking Kitsune 2024-06-28 06:45:39 -07:00 committed by GitHub
parent b7f1440367
commit 1f8d537350
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 158 additions and 231 deletions

View File

@ -4,9 +4,10 @@ import os
import requests
from io import BytesIO
import base64
from modules import scripts, deepbooru, script_callbacks
from modules import scripts, deepbooru, script_callbacks, shared
from modules.processing import process_images
import modules.shared as shared
import sys
import importlib.util
NAME = "Img2img batch interrogator"
@ -17,35 +18,59 @@ Thanks to Smirking Kitsune.
"""
def get_extensions_list():
    """Return name/enabled records for the web UI's extensions.

    Re-scans the extension list, refreshes every extension's repository
    info, and keeps only the extensions whose ``remote`` is not ``None``.

    Returns:
        list[dict]: one ``{"name": ..., "enabled": ...}`` entry per kept
        extension, in the order ``extensions.extensions`` yields them.
    """
    from modules import extensions

    extensions.list_extensions()

    records = []
    for extension in extensions.extensions:
        # Repo info is refreshed before `remote` is inspected
        # (presumably this is what populates it -- confirm in the web UI).
        extension.read_info_from_repo()
        if extension.remote is None:
            continue
        records.append({"name": extension.name, "enabled": extension.enabled})
    return records
def is_interrogator_enabled(interrogator):
    """Report whether the extension named *interrogator* is enabled.

    Looks the name up in the records from ``get_extensions_list()`` and
    returns the ``"enabled"`` value of the first match, or ``False`` when
    no record carries that name.
    """
    return next(
        (
            record["enabled"]
            for record in get_extensions_list()
            if record["name"] == interrogator
        ),
        False,
    )
def import_module(module_name, file_path):
    """Execute a source file as a module and register it in ``sys.modules``.

    Args:
        module_name: Name the module is registered under.
        file_path: Path of the source file to load.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for
            ``file_path`` (e.g. an unrecognised file suffix).
        Exception: Anything the module's own top-level code raises. In that
            case ``sys.modules`` is put back the way it was, so a
            half-initialised module is never left registered.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load module {module_name!r} from {file_path!r}"
        )

    module = importlib.util.module_from_spec(spec)

    # Register before executing so imports of `module_name` made while the
    # module body runs resolve to this object.
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Roll the registration back instead of leaving a broken module.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module
class Script(scripts.Script):
server_address = None
wd_ext_utils = None
clip_ext = None
@classmethod
def set_server_address(cls, demo, app, *args, **kwargs):
cls.server_address = demo.local_url
print(f"Server address set to: {cls.server_address}")
def load_clip_ext_module(cls):
if is_interrogator_enabled('clip-interrogator-ext'):
cls.clip_ext = import_module("clip-interrogator-ext", "extensions/clip-interrogator-ext/scripts/clip_interrogator_ext.py")
return cls.clip_ext
return None
@classmethod
def load_wd_ext_module(cls):
    """Load the WD14 tagger's ``tagger/utils.py`` and cache it on the class.

    Does nothing unless ``is_interrogator_enabled`` reports the
    ``stable-diffusion-webui-wd14-tagger`` extension as enabled.

    Returns:
        The loaded module (also stored in ``cls.wd_ext_utils``), or ``None``
        when the extension is not enabled.
    """
    if not is_interrogator_enabled('stable-diffusion-webui-wd14-tagger'):
        return None

    # The extension root goes on sys.path before utils.py is executed
    # (presumably so its own package imports resolve -- confirm). Guarded so
    # repeated calls do not keep appending the same entry.
    extension_root = 'extensions/stable-diffusion-webui-wd14-tagger'
    if extension_root not in sys.path:
        sys.path.append(extension_root)

    cls.wd_ext_utils = import_module("utils", "extensions/stable-diffusion-webui-wd14-tagger/tagger/utils.py")
    return cls.wd_ext_utils
@classmethod
def get_server_address(cls):
if cls.server_address:
return cls.server_address
# Fallback to the brute force method if server_address is not set
# Initial testing indicates that fallback method will never be used...
print("Server address not set. Falling back to brute force method.")
# Fallback is highly inefficient and in some cases slow (especially if expected port is far from default)
ports = range(7860, 7960) # Gradio will increment port 100 times if default and subsequent desired ports are unavailable.
for port in ports:
url = f"http://127.0.0.1:{port}/"
try:
response = requests.get(f"{url}internal/ping", timeout=5)
if response.status_code == 200:
return url
except requests.RequestException as error:
print(f"API not available on port {port}: {error}")
def load_clip_ext_module_wrapper(cls, *args, **kwargs):
return cls.load_clip_ext_module()
print("API not found")
return None
# Adapter with a catch-all signature: any positional/keyword arguments the
# caller supplies are ignored and the call is forwarded to
# load_wd_ext_module(), whose result is returned. It is registered with
# script_callbacks.on_app_started at the bottom of this file.
@classmethod
def load_wd_ext_module_wrapper(cls, *args, **kwargs):
return cls.load_wd_ext_module()
# Returns the module-level NAME constant ("Img2img batch interrogator").
def title(self):
return NAME
@ -57,101 +82,105 @@ class Script(scripts.Script):
# Returns a gr.Button.update that sets interactive=True; the argument `o`
# is not used. NOTE(review): the enclosing scope is not visible here.
def b_clicked(o):
return gr.Button.update(interactive=True)
def is_interrogator_enabled(self, interrogator):
    """Ask the web UI's extensions endpoint whether *interrogator* is enabled.

    Sends a GET to ``sdapi/v1/extensions`` under the address returned by
    ``get_server_address()`` and returns the ``'enabled'`` value of the
    first entry whose ``'name'`` equals *interrogator*.

    Returns:
        The entry's ``'enabled'`` value; ``False`` when no entry matches or
        when the request fails with ``requests.RequestException``.
    """
    endpoint = f"{self.get_server_address()}sdapi/v1/extensions"
    request_headers = {'accept': 'application/json'}
    try:
        reply = requests.get(endpoint, headers=request_headers)
        reply.raise_for_status()
        return next(
            (
                entry['enabled']
                for entry in reply.json()
                if entry['name'] == interrogator
            ),
            False,
        )
    except requests.RequestException:
        print(f"Error occurred while fetching extension: {interrogator}")
        return False
# Removes unsupported interrogators, support may vary depending on client
def update_model_choices(self, current_choices):
all_options = ["CLIP (API)", "CLIP (Native)", "Deepbooru (Native)", "WD (API)"]
if not self.is_interrogator_enabled('clip-interrogator-ext'):
all_options.remove("CLIP (API)")
if not self.is_interrogator_enabled('stable-diffusion-webui-wd14-tagger'):
all_options.remove("WD (API)")
# Keep the current selections if they're still valid
all_options = ["CLIP (EXT)", "CLIP (Native)", "Deepbooru (Native)", "WD (EXT)"]
if not is_interrogator_enabled('clip-interrogator-ext'):
all_options.remove("CLIP (EXT)")
if not is_interrogator_enabled('stable-diffusion-webui-wd14-tagger'):
all_options.remove("WD (EXT)")
updated_choices = [choice for choice in current_choices if choice in all_options]
return gr.Dropdown.update(choices=all_options, value=updated_choices)
# Function to load CLIP models
def load_clip_models(self):
models = self.get_clip_API_models()
return gr.Dropdown.update(choices=models if models else None)
if self.clip_ext is not None:
models = self.clip_ext.get_models()
return gr.Dropdown.update(choices=models if models else None)
return gr.Dropdown.update(choices=None)
# Function to load WD models
def load_wd_models(self):
models = self.get_WD_API_models()
return gr.Dropdown.update(choices=models if models else None)
if self.wd_ext_utils is not None:
models = self.get_WD_EXT_models()
return gr.Dropdown.update(choices=models if models else None)
return gr.Dropdown.update(choices=None)
def get_WD_EXT_models(self):
    """List the model names known to the loaded WD tagger utilities.

    Refreshes the tagger's interrogators and returns the keys of its
    ``interrogators`` mapping.

    Returns:
        list: the model names, or ``[]`` when the tagger module is not
        loaded, no models are found, or the lookup raises (the error is
        printed, not propagated).
    """
    tagger_utils = self.wd_ext_utils
    if tagger_utils is None:
        return []

    try:
        tagger_utils.refresh_interrogators()
        names = [name for name in tagger_utils.interrogators.keys()]
        if names:
            return names
        # Routed through the handler below so an empty result is reported
        # with the same message format as any other failure.
        raise Exception("No WD Tagger models found.")
    except Exception as failure:
        print(f"Error accessing WD Tagger: {failure}")
        return []
def unload_wd_models(self):
    """Call ``unload()`` on every interrogator of the loaded WD tagger.

    Does nothing when the tagger utilities module is not loaded.
    """
    tagger_utils = self.wd_ext_utils
    if tagger_utils is None:
        return
    for loaded_model in tagger_utils.interrogators.values():
        loaded_model.unload()
def unload_clip_models(self):
    """Call the CLIP extension module's ``unload()`` if the module is loaded."""
    clip_module = self.clip_ext
    if clip_module is None:
        return
    clip_module.unload()
def update_clip_ext_visibility(self, model_selection):
    """Show or hide the CLIP EXT options depending on the model selection.

    Returns:
        tuple: ``(accordion_update, dropdown_update)``. When "CLIP (EXT)" is
        selected the accordion is made visible and the dropdown update is
        whatever ``load_clip_models()`` returns; otherwise the accordion is
        hidden and the dropdown gets an empty update.
    """
    if "CLIP (EXT)" not in model_selection:
        return gr.Accordion.update(visible=False), gr.Dropdown.update()

    dropdown_update = self.load_clip_models()
    return gr.Accordion.update(visible=True), dropdown_update
def update_wd_ext_visibility(self, model_selection):
    """Show or hide the WD EXT options depending on the model selection.

    Returns:
        tuple: ``(accordion_update, dropdown_update)``. When "WD (EXT)" is
        selected the accordion is made visible and the dropdown update is
        whatever ``load_wd_models()`` returns; otherwise the accordion is
        hidden and the dropdown gets an empty update.
    """
    if "WD (EXT)" not in model_selection:
        return gr.Accordion.update(visible=False), gr.Dropdown.update()

    dropdown_update = self.load_wd_models()
    return gr.Accordion.update(visible=True), dropdown_update
# Returns a gr.Slider.update whose `visible` flag is the value of
# `use_weight`; wired to the use_weight checkbox's change event in ui().
def update_prompt_weight_visibility(self, use_weight):
return gr.Slider.update(visible=use_weight)
# Function to load custom filter from file
def load_custom_filter(self, custom_filter):
    """Return the saved custom filter text for the filter textbox.

    Args:
        custom_filter: Text currently held by the UI component; handed back
            unchanged when no saved filter file exists.

    Returns:
        str: Contents of
        ``extensions/sd-Img2img-batch-interrogator/custom_filter.txt``
        (resolved against the current working directory), or the incoming
        ``custom_filter`` when that file is missing.
    """
    filter_path = "extensions/sd-Img2img-batch-interrogator/custom_filter.txt"
    try:
        with open(filter_path, "r") as file:
            return file.read()
    except FileNotFoundError:
        # Nothing has been saved yet: keep what the user already typed
        # instead of raising out of the button handler.
        print(f"Custom filter file not found: {filter_path}")
        return custom_filter
def ui(self, is_img2img):
with gr.Group():
model_options = ["CLIP (API)", "CLIP (Native)", "Deepbooru (Native)", "WD (API)"]
model_options = ["CLIP (EXT)", "CLIP (Native)", "Deepbooru (Native)", "WD (EXT)"]
model_selection = gr.Dropdown(choices=model_options, label="Select Interrogation Model(s)", multiselect=True, value=None)
in_front = gr.Radio(
choices=["Prepend to prompt", "Append to prompt"],
value="Prepend to prompt",
label="Interrogator result position"
)
def update_prompt_weight_visibility(use_weight):
return gr.Slider.update(visible=use_weight)
in_front = gr.Radio(choices=["Prepend to prompt", "Append to prompt"], value="Prepend to prompt", label="Interrogator result position")
use_weight = gr.Checkbox(label="Use Interrogator Prompt Weight", value=True)
prompt_weight = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Interrogator Prompt Weight", visible=True)
# CLIP API Options
def update_clip_api_visibility(model_selection):
is_visible = "CLIP (API)" in model_selection
if is_visible:
clip_models = self.load_clip_models()
return gr.Accordion.update(visible=True), clip_models
else:
return gr.Accordion.update(visible=False), gr.Dropdown.update()
clip_api_accordion = gr.Accordion("CLIP API Options:", open=False, visible=False)
with clip_api_accordion:
clip_api_model = gr.Dropdown(choices=[], value='ViT-L-14/openai', label="CLIP API Model")
clip_api_mode = gr.Radio(choices=["best", "fast", "classic", "negative"], value='best', label="CLIP API Mode")
# CLIP EXT Options
clip_ext_accordion = gr.Accordion("CLIP EXT Options:", open=False, visible=False)
with clip_ext_accordion:
clip_ext_model = gr.Dropdown(choices=[], value='ViT-L-14/openai', label="CLIP EXT Model", multiselect=True)
clip_ext_mode = gr.Radio(choices=["best", "fast", "classic", "negative"], value='best', label="CLIP EXT Mode")
unload_clip_models_afterwords = gr.Checkbox(label="Unload CLIP Model After Use", value=True)
unload_clip_models_button = gr.Button(value="Unload CLIP Models")
# WD API Options
def update_wd_api_visibility(model_selection):
is_visible = "WD (API)" in model_selection
if is_visible:
wd_models = self.load_wd_models()
return gr.Accordion.update(visible=True), wd_models
else:
return gr.Accordion.update(visible=False), gr.Dropdown.update()
wd_api_accordion = gr.Accordion("WD API Options:", open=False, visible=False)
with wd_api_accordion:
wd_api_model = gr.Dropdown(choices=[], value='wd-v1-4-moat-tagger.v2', label="WD API Model")
# WD EXT Options
wd_ext_accordion = gr.Accordion("WD EXT Options:", open=False, visible=False)
with wd_ext_accordion:
wd_ext_model = gr.Dropdown(choices=[], value='wd-swinv2-tagger.v3', label="WD EXT Model", multiselect=True)
wd_threshold = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Threshold")
wd_underscore_fix = gr.Checkbox(label="Remove Underscores from Tags", value=True)
unload_wd_models_afterwords = gr.Checkbox(label="Unload WD Model After Use", value=True)
unload_wd_models_button = gr.Button(value="Unload WD Models")
# Function to load custom filter from file
def load_custom_filter(custom_filter):
with open("extensions/sd-Img2img-batch-interrogator/custom_filter.txt", "r") as file:
custom_filter = file.read()
return custom_filter
with gr.Accordion("Filtering tools:"):
no_duplicates = gr.Checkbox(label="Filter Duplicate Prompt Content from Interrogation", value=False)
@ -167,14 +196,15 @@ class Script(scripts.Script):
# Listeners
model_selection.select(fn=self.update_model_choices, inputs=[model_selection], outputs=[model_selection])
model_selection.change(fn=update_clip_api_visibility, inputs=[model_selection], outputs=[clip_api_accordion, clip_api_model])
model_selection.change(fn=update_wd_api_visibility, inputs=[model_selection], outputs=[wd_api_accordion, wd_api_model])
load_custom_filter_button.click(load_custom_filter, inputs=custom_filter, outputs=custom_filter)
unload_wd_models_button.click(self.post_wd_api_unload, inputs=None, outputs=None)
use_weight.change(fn=update_prompt_weight_visibility, inputs=[use_weight], outputs=[prompt_weight])
model_selection.change(fn=self.update_clip_ext_visibility, inputs=[model_selection], outputs=[clip_ext_accordion, clip_ext_model])
model_selection.change(fn=self.update_wd_ext_visibility, inputs=[model_selection], outputs=[wd_ext_accordion, wd_ext_model])
load_custom_filter_button.click(self.load_custom_filter, inputs=custom_filter, outputs=custom_filter)
unload_clip_models_button.click(self.unload_clip_models, inputs=None, outputs=None)
unload_wd_models_button.click(self.unload_wd_models, inputs=None, outputs=None)
use_weight.change(fn=self.update_prompt_weight_visibility, inputs=[use_weight], outputs=[prompt_weight])
return [in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_api_model, clip_api_mode, wd_api_model, wd_threshold, wd_underscore_fix, unload_wd_models_afterwords]
return [in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, unload_clip_models_afterwords, unload_wd_models_afterwords]
# Required to parse information from a string that is between () or has :##.## suffix
def remove_attention(self, words):
@ -215,28 +245,39 @@ class Script(scripts.Script):
return filtered_prompt
def run(self, p, in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_api_model, clip_api_mode, wd_api_model, wd_threshold, wd_underscore_fix, unload_wd_models_afterwords):
def run(self, p, in_front, prompt_weight, model_selection, use_weight, no_duplicates, use_negatives, use_custom_filter, custom_filter, clip_ext_model, clip_ext_mode, wd_ext_model, wd_threshold, wd_underscore_fix, unload_clip_models_afterwords, unload_wd_models_afterwords):
raw_prompt = p.prompt
interrogator = ""
# fix alpha channel
p.init_images[0] = p.init_images[0].convert("RGB")
first = True # Two interrogator concatenation correction boolean
for model in model_selection:
# This prevents two interrogators from being incorrectly concatenated
if first == False:
interrogator += ", "
first = False
# Should add the interrogators in the order determined by the model_selection list
if model == "Deepbooru (Native)":
interrogator += deepbooru.model.tag(p.init_images[0])
interrogator += deepbooru.model.tag(p.init_images[0]) + ", "
elif model == "CLIP (Native)":
interrogator += shared.interrogator.interrogate(p.init_images[0])
elif model == "CLIP (API)":
interrogator += self.post_clip_api_prompt(p.init_images[0], clip_api_model, clip_api_mode)
elif model == "WD (API)":
interrogator += self.post_wd_api_tagger(p.init_images[0], wd_api_model, wd_threshold, wd_underscore_fix)
interrogator += shared.interrogator.interrogate(p.init_images[0]) + ", "
elif model == "CLIP (EXT)":
if self.clip_ext is not None:
for clip_model in clip_ext_model:
interrogator += self.clip_ext.image_to_prompt(p.init_images[0], clip_ext_mode, clip_model) + ", "
if unload_clip_models_afterwords:
self.clip_ext.unload()
elif model == "WD (EXT)":
if self.wd_ext_utils is not None:
for wd_model in wd_ext_model:
interrogator = self.wd_ext_utils.interrogators[wd_model]
rating, tags = interrogator.interrogate(p.init_images[0])
tags_list = [tag for tag, conf in tags.items() if conf > wd_threshold]
if wd_underscore_fix:
tags_spaced = [tag.replace('_', ' ') for tag in tags_list]
interrogator += ", ".join(tags_spaced) + ", "
else:
interrogator += ", ".join(tags_list) + ", "
if unload_wd_models_afterwords:
self.wd_ext_utils.interrogators[wd_ext_model].unload()
# Remove duplicate prompt content from interrogator prompt
if no_duplicates:
@ -266,9 +307,6 @@ class Script(scripts.Script):
else:
p.prompt = f"{interrogator}, {p.prompt}"
if unload_wd_models_afterwords and "WD (API)" in model_selection:
self.post_wd_api_unload()
print(f"Prompt: {p.prompt}")
processed = process_images(p)
@ -278,116 +316,5 @@ class Script(scripts.Script):
return processed
# CLIP API Model Identification
def get_clip_API_models(self):
    """Fetch the model list from the CLIP Interrogator HTTP endpoint.

    Sends a GET to ``interrogator/models`` under the address returned by
    ``get_server_address()``.

    Returns:
        The decoded JSON response, or ``[]`` when the request fails or the
        response is empty (the reason is printed, not raised).
    """
    try:
        endpoint = f"{self.get_server_address()}interrogator/models"
        reply = requests.get(endpoint)
        reply.raise_for_status()
        available = reply.json()
        if not available:
            # Reported through the same handler as transport errors.
            raise Exception("No CLIP Interrogator models found.")
    except Exception as failure:
        print(f"Error accessing CLIP Interrogator API: {failure}")
        return []
    else:
        return available
# WD API Model Identification
def get_WD_API_models(self):
    """Fetch the model list from the WD Tagger HTTP endpoint.

    Sends a GET to ``tagger/v1/interrogators`` under the address returned
    by ``get_server_address()`` and reads the ``"models"`` key of the JSON
    response.

    Returns:
        The value of ``"models"``, or ``[]`` when the request fails, the key
        is missing, or the list is empty (the reason is printed, not raised).
    """
    try:
        endpoint = f"{self.get_server_address()}tagger/v1/interrogators"
        reply = requests.get(endpoint)
        reply.raise_for_status()
        available = reply.json()["models"]
        if not available:
            # Reported through the same handler as transport errors.
            raise Exception("No WD Tagger models found.")
    except Exception as failure:
        print(f"Error accessing WD Tagger API: {failure}")
        return []
    else:
        return available
# CLIP API Prompt Generator
def post_clip_api_prompt(self, image, model_name, mode):
    """Request a prompt for *image* from the CLIP Interrogator endpoint.

    Args:
        image: Object with a ``save(file, format=...)`` method; it is
            written out as PNG and sent base64-encoded.
        model_name: Sent as ``clip_model_name``; a falsy value aborts early.
        mode: Sent as ``mode``.

    Returns:
        str: The ``"prompt"`` field of the JSON response, or ``""`` when no
        model name is given, the field is absent, or the request fails.
    """
    print("Starting CLIP Interrogator API interaction...")

    if not model_name:
        print("CLIP API model is required.")
        return ""

    # PNG-encode in memory, then base64 so the image fits in a JSON body.
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded_image = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

    try:
        request_body = dict(
            image=encoded_image,
            mode=mode,
            clip_model_name=model_name,
        )
        endpoint = f"{self.get_server_address()}interrogator/prompt"
        reply = requests.post(endpoint, json=request_body)
        reply.raise_for_status()
        return reply.json().get("prompt", "")
    except Exception as failure:
        print(f"Error generating prompt with CLIP API: {failure}")
        return ""
# WD API Interrogation Tagger
def post_wd_api_tagger(self, image, model_name, threshold, underscore):
    """Request tags for *image* from the WD Tagger endpoint.

    Args:
        image: Object with a ``save(file, format=...)`` method; it is
            written out as PNG and sent base64-encoded.
        model_name: Sent as ``model``; a falsy value aborts early.
        threshold: Sent as ``threshold``.
        underscore: When truthy, ``_`` in each tag is replaced by a space.

    Returns:
        str: The tags from ``caption.tag`` of the JSON response joined with
        ``", "``, or ``""`` when no model name is given or the request fails.
    """
    print("Starting WD Tagger API interaction...")

    if not model_name:
        print("WD API model is required.")
        return ""

    # PNG-encode in memory, then base64 so the image fits in a JSON body.
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded_image = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

    try:
        request_body = dict(
            image=encoded_image,
            model=model_name,
            threshold=threshold,
            queue="",
            name_in_queue="",
        )
        endpoint = f"{self.get_server_address()}tagger/v1/interrogate"
        # WARNING: keep the `timeout`. Without it the client can freeze if
        # the tagger's queue_lock is held; raise the value if more time is
        # needed, but do not remove it or risk DEADLOCK.
        # Note: if WD Tagger did not load a model, its queue_lock (FIFOLock)
        # is likely refusing to run because of threading concerns.
        # Note: because of the timeout, download models in the WD tab first.
        reply = requests.post(endpoint, json=request_body, timeout=120)
        reply.raise_for_status()
        tags = reply.json().get("caption", {}).get("tag", [])
        if underscore:
            tags = [tag.replace('_', ' ') for tag in tags]
        return ", ".join(tags)
    except Exception as failure:
        print(f"Error generating prompt with WD API: {failure}")
        return ""
# WD API Model Unloader
def post_wd_api_unload(self):
    """POST to the WD Tagger ``unload-interrogators`` endpoint.

    Any failure (building the address, sending the request, or an HTTP
    error status) is printed rather than raised. Returns ``None``.
    """
    try:
        endpoint = f"{self.get_server_address()}tagger/v1/unload-interrogators"
        reply = requests.post(endpoint, json='')
        reply.raise_for_status()
    except Exception as failure:
        print(f"Error Unloading models with WD API: {failure}")
script_callbacks.on_app_started(Script.set_server_address)
script_callbacks.on_app_started(Script.load_clip_ext_module_wrapper)
script_callbacks.on_app_started(Script.load_wd_ext_module_wrapper)