277 lines
9.6 KiB
Python
277 lines
9.6 KiB
Python
import gradio as gr
|
|
|
|
from modules import scripts, script_callbacks
|
|
from modules.shared import opts, OptionInfo
|
|
|
|
from urllib.request import urlopen
|
|
from hashlib import md5
|
|
from base64 import b64encode
|
|
|
|
import re
|
|
import requests
|
|
|
|
# Display name of the extension: used as the script title, the accordion
# label in the UI and the label of the settings section.
NAME = "e621 Prompt"

# All tag categories offered in the UI "Categories" dropdown.
tags_categories_options = [
    "artist",
    "character",
    "species",
    "copyright",
    "general",
    "lore",
    "meta",
    "rating",
    "invalid",
]

# Categories preselected in the dropdown and restored by "Reset form".
default_tags_categories = [
    "artist",
    "character",
    "species",
    "general",
]
|
|
|
|
def replace_underscores(value):
    """Turn underscores in *value* into spaces when the option asks for it.

    Controlled by the ``e621_prompt_replace_underscores`` setting; when it is
    off, *value* is returned untouched.
    """
    should_replace = opts.e621_prompt_replace_underscores
    return value.replace("_", " ") if should_replace else value
|
|
|
|
def escape_special_characters(value):
    """Backslash-escape parentheses, which are special in WebUI prompts.

    Each ``(`` becomes ``\\(`` and each ``)`` becomes ``\\)`` so the prompt
    parser treats them as literal characters of the tag.

    Args:
        value: A single tag string.

    Returns:
        The tag with its parentheses escaped.
    """
    # Raw strings: "\(" is an invalid escape sequence (SyntaxWarning on 3.12+)
    # that only worked because Python kept the unknown escape verbatim.
    return value.replace("(", r"\(").replace(")", r"\)")
|
|
|
|
def comma_separated_string_to_list(string):
    """Split a comma-separated string into a list of its items.

    Each item is stripped of surrounding whitespace and empty items are
    dropped. Order and duplicates are preserved (the result is a list, not
    a set).
    """
    items = []
    for raw_item in string.split(","):
        item = raw_item.strip()
        if item:
            items.append(item)
    return items
|
|
|
|
def excluded_tags():
    """Return the tags listed in the ``e621_prompt_excluded_tags`` setting.

    The result is a list (in the order written in the setting) of raw tag
    names, with underscores left as-is.
    """
    raw_setting = opts.e621_prompt_excluded_tags
    return comma_separated_string_to_list(raw_setting)
|
|
|
|
def appended_tags():
    """Return the tags listed in the ``e621_prompt_appended_tags`` setting.

    The result is a list in setting order. Underscores are turned into spaces
    only when ``e621_prompt_replace_underscores_in_appended`` is on AND the
    general ``e621_prompt_replace_underscores`` option (checked inside
    ``replace_underscores``) is on.
    """
    tags = comma_separated_string_to_list(opts.e621_prompt_appended_tags)
    if not opts.e621_prompt_replace_underscores_in_appended:
        return tags
    return [replace_underscores(tag) for tag in tags]
|
|
|
|
class Script(scripts.Script):
|
|
def title(self):
|
|
return NAME
|
|
|
|
def show(self, _is_img2img):
|
|
return scripts.AlwaysVisible
|
|
|
|
# Wrapper around requests.request that sets required headers
|
|
def make_request(self, headers={}, **kwargs):
|
|
# Setup headers
|
|
req_headers = {
|
|
"User-Agent": opts.e621_prompt_user_agent
|
|
}
|
|
|
|
# Setup auth
|
|
if opts.e621_prompt_username and opts.e621_prompt_api_key:
|
|
req_headers["Authorization"] = b64encode(f"{opts.e621_prompt_username}:{opts.e621_prompt_api_key}".encode())
|
|
|
|
# Mix our headers with additional ones
|
|
req_headers.update(headers)
|
|
|
|
# Setup proxy
|
|
proxies = None
|
|
if opts.e621_prompt_use_proxy and opts.e621_prompt_proxy_url:
|
|
proxies = {
|
|
"https": opts.e621_prompt_proxy_url
|
|
}
|
|
|
|
response = requests.request(**kwargs, headers=req_headers, proxies=proxies)
|
|
|
|
response.raise_for_status()
|
|
|
|
return response.json()
|
|
|
|
# Parse source and extract md5 hash or id
|
|
def normalize_source(self, source):
|
|
if source is None:
|
|
return ("error", "Enter post info")
|
|
|
|
found_hash = re.search(r"([a-fA-F\d]{32})", source)
|
|
|
|
if found_hash:
|
|
return ("md5", found_hash.group(0))
|
|
|
|
found_post_url = re.search(r"e621.net\/posts\/(\d+)", source)
|
|
|
|
if found_post_url:
|
|
return ("id", found_post_url.group(1))
|
|
|
|
if source.isnumeric():
|
|
return ("id", source)
|
|
|
|
return ("error", "No valid post url, id, or md5 hash provided")
|
|
|
|
# Fetches post from e621 by md5 or id
|
|
def get_post(self, post_info):
|
|
try:
|
|
match post_info:
|
|
case ("md5", md5):
|
|
json = self.make_request(
|
|
method="GET",
|
|
url="https://e621.net/posts.json",
|
|
params={"tags": f"md5:{md5}", "limit": 1}
|
|
)
|
|
|
|
if json["posts"] is None or len(json["posts"]) == 0:
|
|
return ("error", f"No post found for md5 {md5}")
|
|
|
|
return ("post", json["posts"][0])
|
|
case ("id", id):
|
|
json = self.make_request(
|
|
method="GET",
|
|
url=f"https://e621.net/posts/{id}.json"
|
|
)
|
|
|
|
if json["post"] is None:
|
|
return ("error", f"No post found for id {id}")
|
|
|
|
return json["post"]
|
|
case _:
|
|
return post_info
|
|
except Exception as e:
|
|
return ("error", str(e))
|
|
|
|
# Formats rating, following the rules and prefixes
|
|
def format_rating(self, post):
|
|
full_map = {
|
|
"e": "explicit",
|
|
"q": "questionable",
|
|
"s": "safe"
|
|
}
|
|
|
|
value = post["rating"]
|
|
|
|
if opts.e621_prompt_rating_format == "full":
|
|
value = full_map[post["rating"]]
|
|
|
|
return f"{opts.e621_prompt_rating_prefix}{value}"
|
|
|
|
# Formats tags from category, excluding tags from the settings, adding prefix and replacing underscores if needed
|
|
def format_category(self, post, category):
|
|
prefix = getattr(opts, f"e621_prompt_{category}_prefix")
|
|
# God I "love" Python. There was a bunch of sets and "-" between them, but we can't use sets
|
|
# due to ordering reasons...
|
|
tags = [tag for tag in (post["tags"][category] or []) if tag not in excluded_tags()]
|
|
|
|
return [f"{prefix}{escape_special_characters(replace_underscores(tag))}" for tag in tags]
|
|
|
|
# Converts post data into tags
|
|
def process_post(self, post, categories):
|
|
match post:
|
|
case ("error", _):
|
|
return post
|
|
case ("post", p):
|
|
post = p
|
|
|
|
result = []
|
|
|
|
for category in categories:
|
|
match category:
|
|
case 'rating':
|
|
result.append(self.format_rating(post))
|
|
case _:
|
|
result = result + self.format_category(post, category)
|
|
|
|
tags_to_append = [tag for tag in appended_tags() if tag not in result]
|
|
result = result + tags_to_append
|
|
|
|
return ("result", ", ".join(result))
|
|
|
|
# Stitches everything together
|
|
def generate_callback(self, source, categories):
|
|
if not categories:
|
|
return "ERROR: No categories selected"
|
|
|
|
result = self.process_post(self.get_post(self.normalize_source(source)), categories)
|
|
|
|
match result:
|
|
case ("error", error):
|
|
return f"ERROR: {error}"
|
|
case ("result", prompt):
|
|
return prompt
|
|
|
|
# Clears form
|
|
def clear_callback(self):
|
|
return default_tags_categories, None, None, None
|
|
|
|
# Calculates hash of the uploaded image, and puts it inside of "source" field
|
|
def image_upload_callback(self, source, image):
|
|
if image is None:
|
|
return source
|
|
|
|
with urlopen(image) as response:
|
|
return md5(response.read()).hexdigest()
|
|
|
|
# Renders ui
|
|
def ui(self, _is_img2img):
|
|
with gr.Group():
|
|
with gr.Accordion(NAME, open=False):
|
|
source = gr.Textbox(label="Source", value="", placeholder="e621 post link / e621 post id / md5 hash of the image")
|
|
|
|
file_source = gr.Image(
|
|
source="upload",
|
|
label="or upload image for hash calculation",
|
|
type="numpy"
|
|
)
|
|
|
|
file_source.upload(self.image_upload_callback, inputs=[source, file_source], outputs=[source], preprocess=False)
|
|
|
|
categories = gr.Dropdown(
|
|
tags_categories_options,
|
|
multiselect=True,
|
|
value=default_tags_categories,
|
|
label="Categories"
|
|
)
|
|
|
|
result = gr.Textbox(value="", label="Result", lines=5, interactive=False, show_copy_button=True)
|
|
|
|
with gr.Row():
|
|
with gr.Column():
|
|
clear_btn = gr.Button("Reset form")
|
|
clear_btn.click(fn=self.clear_callback, inputs=None, outputs=[categories, result, source, file_source])
|
|
with gr.Column():
|
|
generate_btn = gr.Button("Generate", variant="primary")
|
|
generate_btn.click(fn=self.generate_callback, inputs=[source, categories], outputs=[result])
|
|
|
|
# This return is required, because otherwise "path" in ui_config.json and "Defaults" section of the settings
|
|
# would be really wrong
|
|
return [source, file_source, categories, result, clear_btn, generate_btn]
|
|
|
|
def on_ui_settings():
    """Register all of this extension's options in its own settings section.

    Options are added in a fixed order, which is the order they appear on the
    WebUI settings page.
    """
    section = ("e621-prompt", NAME)

    default_excluded_tags = ", ".join((
        "comic", "watermark", "text", "sign", "patreon_logo", "internal",
        "censored", "censored_genitalia", "censored_penis", "censored_pussy",
        "censored_text", "censored_anus", "multiple_poses", "multiple_images",
        "dialogue", "speech_bubble", "english_text", "dialogue_box",
        "subtitled", "thought_bubble", "cutaway", "conditional_dnp",
    ))

    def register(key, default, label, *component_info):
        # component_info is the optional (component, component_args) pair.
        opts.add_option(key, OptionInfo(default, label, *component_info, section=section))

    # Account and connection.
    register("e621_prompt_username", "", "e621 Username. Not required, but highly preferred")
    register("e621_prompt_api_key", "", "e621 API Key. Not required, but highly preferred")
    register(
        "e621_prompt_user_agent",
        "sd-webui-e621-prompt (nochnoe)",
        "User-Agent for API calls. DO NOT change this line if you don't know what you're doing",
    )
    register("e621_prompt_use_proxy", False, "Use proxy when accessing e621")
    register("e621_prompt_proxy_url", "", "HTTPS proxy")

    # Tag filtering and post-processing.
    register("e621_prompt_excluded_tags", default_excluded_tags, "Tags that always should be EXCLUDED from the final result. Comma-separated, with underscores")
    register("e621_prompt_appended_tags", "", "Tags that always should be APPENDED to the final result. Comma-separated, with underscores")
    register("e621_prompt_replace_underscores", True, "Replace underscores with spaces")
    register("e621_prompt_replace_underscores_in_appended", True, "Replace underscores with spaces in the appended tags")

    # Per-category prefixes.
    register("e621_prompt_artist_prefix", "", "Prefix for artists (for example, artist:)")
    register("e621_prompt_meta_prefix", "", "Prefix for meta tags (for example, meta:)")
    register("e621_prompt_species_prefix", "", "Prefix for species tags (for example, species:)")
    register("e621_prompt_character_prefix", "", "Prefix for characters tags (for example, character:)")
    register("e621_prompt_lore_prefix", "", "Prefix for lore tags (for example, lore:)")
    register("e621_prompt_general_prefix", "", "Prefix for general tags (for example, general:)")
    register("e621_prompt_copyright_prefix", "", "Prefix for copyright tags (for example, copyright:)")
    register("e621_prompt_invalid_prefix", "", "Prefix for invalid tags (for example, invalid:)")

    # Rating.
    register("e621_prompt_rating_prefix", "rating:", "Prefix for rating (for example, rating:)")
    register(
        "e621_prompt_rating_format",
        "short",
        "Rating format (short: e/s/q, full: explicit/safe/questionable)",
        gr.Dropdown,
        lambda: {"choices": ["short", "full"]},
    )


script_callbacks.on_ui_settings(on_ui_settings)
|