Overhauled technique for cleaning up the prompt and formatting it.

More or less a complete rewrite from a proof of concept to something more consistent and stable. Now takes into account things like prompt editing, alternating words, and composable diffusion. In the meantime, however, it no longer separates weighted terms. Until I can figure out how to deal with infinitely nested weights and other weird combinations (recursion stuff), it's on hold.

The separation of isolated weights will come later as a togglable option, perhaps with other settings.
pull/18/head v0.2
blongty 2023-04-26 23:30:58 -07:00
parent 43c47c15fa
commit 1695f72821
2 changed files with 252 additions and 66 deletions

5
install.py Normal file
View File

@ -0,0 +1,5 @@
import launch

# One-time dependency bootstrap for the WebUI extension: the third-party
# `regex` package provides variable-length lookbehind, which the stdlib
# `re` module does not support (see re_colon_spacing_comp_end in the script).
if not launch.is_installed("regex"):
    print("Installing requirements for Prompt Formatter")
    launch.run_pip("install regex", "support for variable lookbehind")

View File

@ -1,58 +1,54 @@
import gradio as gr
import re
import regex as re
import unicodedata
from modules import script_callbacks
# --- bracket tables -----------------------------------------------------------
# NOTE(review): this span is a rendered diff, so several names are bound twice
# (old-revision line, then new-revision line); at runtime the later binding wins.
# `bracketing` lists the four openers followed by their four closers.
bracketing = '([{<)]}>'
brackets_opening = '([{<'
brackets_closing = ')]}>'
# Token sub-patterns: `none` = a run containing no unescaped bracket or comma.
none = r'(?:\\[()\[\]{}<>]|[^,(){}\[\]{}<>])+'
paren = r'\(+' + none + r'\)+'
square = r'\[+' + none + r'\]+'
curly = r'{+' + none + r'}+'
angle = r'<+' + none + r'>+'
brackets_opening = '([{<'
brackets_closing = ')]}>'
re_tokenize = re.compile('|'.join([none, paren, square, curly, angle]))
re_brackets_fix_whitespace = re.compile(r'\s*[^\\\S]([\(<{\)>}:]+)\s*')
# base = r'(?:\\[()\[\]{}<>]|[^,(){}\[\]{}<>])+'
# paren = r'\(+' + base + r'\)+'
# square = r'\[+' + base + r'\]+'
# curly = r'{+' + base + r'}+'
# angle = r'<+' + base + r'>+'
# re_tokenize = re.compile('|'.join([base, paren, square, curly, angle]))
# New revision: tokenizing is now just a comma split (replaces the binding above).
re_tokenize = re.compile(r',')
# Opening bracket followed by spaces, OR spaces followed by a closing bracket.
# NOTE(review): the closing class [\)\]}>}] repeats '}' — harmless, likely a typo.
re_brackets_fix_whitespace = re.compile(r'([\(\[{<])\s*|\s*([\)\]}>}])')
re_opposing_brackets = re.compile(r'([)\]}>])([([{<])')
re_networks = re.compile(r'<.+?>')
re_brackets_open = re.compile(r'[(\[{]+')
# re_colon_spacing_composite = re.compile(r'(?P<A>[^:]*?)\s*?(?P<COLON>:)\s*?(?P<B>\S*)(?P<S>\s*)(?(S)\s*?)(?P<AND>AND)')
re_colon_spacing_composite = re.compile(r'\s*(:)\s*(?=\d*?\.?\d*?\s*?AND)')
re_colon_spacing = re.compile(r'\s*(:)\s*')
# re_colon_spacing = re.compile(r'(?P<A>[^:]*?)\s*?(?P<COLON>:)\s*?(?P<B>\S+)(?P<S>\s*)(?(S)\s*?)')
# The variable-length lookbehind below is why install.py installs the `regex`
# package: stdlib `re` rejects non-fixed-width lookbehind.
re_colon_spacing_comp_end = re.compile(r'(?<=AND[^:]*?)(:)(?=[^:]*$)')
re_comma_spacing = re.compile(r',+')
# NOTE(review): in '\d.?\d*' the '.' is unescaped and matches any character;
# probably meant '\d\.?\d*'.
re_paren_weights_exist = re.compile(r'\(.*(?<!:):\d.?\d*\)+')
re_is_prompt_editing = re.compile(r'\[.*:.*\]')
# NOTE(review): the '|' here is unescaped, so the pattern is an alternation of
# '\[.*' and '.*\]' and matches nearly any string; probably meant '\|'.
re_is_prompt_alternating = re.compile(r'\[.*|.*\]')
re_is_wildcard = re.compile(r'{.*}')
# NOTE(review): the trailing '(.*?)' is non-greedy at end-of-pattern and will
# always match the empty string.
re_AND = re.compile(r'(.*?)\s*(AND)\s*(.*?)')
re_alternating = re.compile(r'\s*(\|)\s*')
# Prompt components gathered from the UI — presumably populated by the
# script_callbacks hooks; not visible in this chunk (verify against caller).
ui_prompts = []
def get_closing(c: str):
    """Return the closing counterpart of opening bracket *c*.

    `bracketing` lists the four openers followed by the four closers, so the
    partner character sits exactly half the table length further along.
    """
    offset = len(bracketing) // 2
    return bracketing[bracketing.find(c) + offset]
def get_bracket_closing(c: str):
    """Look up the closing bracket paired with opener *c* via the parallel tables."""
    index = brackets_opening.find(c)
    return brackets_closing[index]
def fix_bracketing(token: str):
    """Balance the token's outermost bracket type by deleting strays.

    Only the bracket character the token starts with (and its partner) is
    considered; tokenizer() is assumed to guarantee at least one matched pair.
    Unmatched closers are blanked out in place; unmatched openers are popped
    afterwards, highest index first, so earlier indices stay valid.
    """
    if not re.match(r'[\(\[{<]', token):
        return token
    chars = list(token)
    open_ch = chars[0]
    close_ch = get_closing(open_ch)
    pending = []  # indices of openers not yet matched
    for idx, ch in enumerate(token):
        if ch == open_ch:
            pending.append(idx)
        elif ch == close_ch:
            if pending:
                pending.pop()
            else:
                chars[idx] = ''  # stray closer — blank it out
    while pending:  # leftover unmatched openers
        chars.pop(pending.pop())
    return ''.join(chars)
def get_bracket_opening(c: str):
    """Look up the opening bracket paired with closer *c* via the parallel tables."""
    index = brackets_closing.find(c)
    return brackets_opening[index]
def normalize_characters(data: str):
    """Return *data* NFKC-normalized so visually-equivalent Unicode forms
    (fullwidth letters, ligatures, compatibility characters) compare equal.

    Note: the captured span also contained the dangling header of the removed
    old-revision `normalize()` with no body, which made it invalid Python;
    only the new-revision function is kept.
    """
    return unicodedata.normalize("NFKC", data)
@ -60,34 +56,110 @@ def tokenize(data: str):
return re_tokenize.findall(data)
def remove_whitespace(tokens: list):
    """Squeeze runs of whitespace in each token, drop empty tokens, then
    tighten spacing around brackets/colons via re_brackets_fix_whitespace."""
    squeezed = (' '.join(tok.split()) for tok in tokens)
    kept = [tok for tok in squeezed if tok]
    return [re_brackets_fix_whitespace.sub(r"\1", tok) for tok in kept]
def remove_whitespace_excessive(prompt: str):
    """Collapse every run of whitespace in *prompt* to one space and trim the ends."""
    words = prompt.split()
    return ' '.join(words)
def min_normalized_brackets(tokens: list):
    """Run fix_bracketing over every token, dropping stray bracket characters."""
    return [fix_bracketing(tok) for tok in tokens]
def align_brackets(prompt: str):
    """Remove whitespace hugging a bracket on its inner side.

    re_brackets_fix_whitespace captures the bracket in group 1 (opener) or
    group 2 (closer); exactly one group is non-empty per match.
    """
    def keep_bracket(match: re.Match):
        return match.group(1) or match.group(2)

    return re_brackets_fix_whitespace.sub(keep_bracket, prompt)
def brackets_to_weights(tokens: list):
    """Translate each token's surrounding brackets into an explicit ':weight'."""
    return [token_bracket_to_weight(tok) for tok in tokens]
def space_AND(prompt: str):
    """Normalize to exactly one space on each side of the composable-diffusion
    AND keyword (required before colon alignment can run)."""
    def respace(match: re.Match):
        return ' '.join(match.groups())

    return re_AND.sub(respace, prompt)
def align_colons(prompt: str):
    """Normalize whitespace around ':' characters.

    Plain weight colons lose all surrounding whitespace; colons that introduce
    a composable-diffusion subprompt weight (a number followed by AND, or the
    trailing weight after the final AND) instead keep one leading space so the
    weight binds to the whole subprompt.

    Fixes applied: the captured span interleaved the removed old-revision
    token_bracket_to_weight() with this function (diff residue, dropped), and
    composite_end() contained a leftover debug print (removed).
    """
    def tighten(match: re.Match):
        # ordinary colon: no surrounding whitespace
        return match.group(1)

    def composite(match: re.Match):
        # colon whose weight is followed by AND: single leading space
        return ' ' + match.group(1)

    def composite_end(match: re.Match):
        # final subprompt weight after the last AND: same treatment
        return ' ' + match.group(1)

    ret = re_colon_spacing.sub(tighten, prompt)
    ret = re_colon_spacing_composite.sub(composite, ret)
    ret = re_colon_spacing_comp_end.sub(composite_end, ret)
    return ret
# def helper(match: re.Match):
# if match.group('AND'):
# return f"{match.group('A')} :{match.group('B')} AND"
# return f"{match.group('A')}:{match.group('B')}"
# def edge_case(match: re.Match):
# # edge case where if composite with weight at end, fix alignment
# if match.group('AND'):
# return ' '.join(match.group('AND', 2)) + ' ' + ''.join(match.group(3, 4))
# ret = re_colon_spacing.sub(helper, prompt)
# return re_colon_spacing_comp_end.sub(edge_case, ret)
# def fix_ending_compositing(s: str):
# # edge case where if composite, weight isn't followed by AND, need to
# # check backwards and check if needs to fix alignment
# match = re_colon_spacing_comp_end.match(s)
# if match.group('AND'):
# return ' '.join(match.group(1, 2)) + ' ' + ''.join(match.group(3, 4))
# ret = re_colon_spacing.sub(helper, prompt)
# return fix_ending_compositing(ret)
def align_commas(prompt: str):
    """Rewrite comma runs as a single ', ' separator and drop empty segments."""
    segments = (seg.strip() for seg in re_comma_spacing.split(prompt))
    return ', '.join(seg for seg in segments if seg)
def brackets_to_weights(tokens: list, power: int = 0):
    """Convert bracket nesting into explicit ':weight' syntax — not implemented.

    Deliberately a no-op for now: handling arbitrarily nested weights needs a
    recursive approach that is still unresolved (see commit message), so this
    returns None and the pipeline call site keeps it commented out.

    Args:
        tokens: prompt fragments, possibly wrapped in ()/[]/{}/<>.
        power:  reserved for the eventual recursive weight exponent.

    Fixes applied: removed a leftover debug `print(tokens)`, the dead loop
    that only executed `pass`, and two regexes compiled from non-raw strings
    (r-less '\(' / '\[' escapes raise SyntaxWarning on modern Python).
    """
    return None
def extract_networks(tokens: list):
@ -98,18 +170,127 @@ def remove_networks(tokens: list):
return list(filter(lambda token : not re_networks.match(token), tokens))
def remove_mismatched_brackets(prompt: str):
    """Remove bracket characters that have no matching partner.

    Scans left-to-right keeping a stack of open brackets: a closer is kept
    only when it closes the most recent opener (stray/mismatched closers are
    dropped), and any openers still unmatched at the end are spliced out,
    highest index first.

    Bug fixed: the old code recorded opener positions as indices into
    *prompt*, but deleted them from *ret* — which is shorter whenever a stray
    closer was dropped earlier in the scan (e.g. ")(" returned "(" instead of
    ""). Positions are now recorded as indices into *ret*.
    """
    openers, closers = '([{<', ')]}>'
    stack = []  # opener chars awaiting their partner
    pos = []    # index of each pending opener *within ret*
    ret = ''
    for c in prompt:
        if c in openers:
            stack.append(c)
            pos.append(len(ret))  # position inside the output string
            ret += c
        elif c in closers:
            if stack and stack[-1] == openers[closers.index(c)]:
                stack.pop()
                pos.pop()
                ret += c
            # otherwise: stray or mismatched closer — drop it
        else:
            ret += c
    # strip leftover unmatched openers, right-to-left so indices stay valid
    while pos:
        p = pos.pop()
        ret = ret[:p] + ret[p + 1:]
    return ret
# Tokenizing is extremely tedious and perhaps unecessary...
# def tokenize_nested(prompt: str):
# """
# Tokenizes the prompt based on commas, brackets, and parenthesis.
# """
# result = []
# re_dividers = re.compile(r'(?<!\\)([\(\)\[\],<>{}])')
# pos = 0
# while pos < len(prompt):
# match = re_dividers.search(prompt, pos)
# # we know we're at the end of the string when we can't match
# if match is None:
# substring = prompt[pos:].strip()
# if substring:
# result.append(substring)
# break
# # add up to the previous token up to excluding our matched position
# substring = prompt[pos:match.start()].strip()
# if substring:
# result.append(prompt[pos:match.start()].strip())
# if prompt[match.start()] in '}>': # brackets don't get added, so this corrects for it
# result[-1] = get_bracket_opening(prompt[match.start()]) + result[-1] + prompt[match.start()]
# # if comma, move pos past it
# if prompt[match.start()] in ',<>{}':
# pos = match.end()
# # finally deal with real nested stuff
# elif prompt[match.start()] in '[(':
# nested_result, length = tokenize_nested(prompt[match.end():]) # recurses with s, the end of match onwards
# nested_result[0] = prompt[match.start()] + nested_result[0]
# result.append(nested_result)
# pos = match.end() + length
# elif prompt[match.start()] in '])':
# result[-1] = result[-1] + prompt[match.start()] # return from recurse, including the
# return result, match.end() # end of the match to correct position
# return result
# def flatten_tokens(tokens: list):
# ret = []
# for token in tokens:
# if isinstance(token, list):
# ret.extend(flatten_tokens(token))
# else:
# ret.append(token)
# return ret
def space_bracekts(prompt: str):
    """Insert a space between an adjacent closing and opening bracket,
    e.g. ')(' -> ') ('.

    (Name typo "bracekts" is kept intentionally — format_prompt() calls it
    by this name.)

    Fix applied: removed two leftover debug print() calls that dumped the
    prompt and every substitution to stdout.
    """
    def separate(match: re.Match):
        return ' '.join(match.groups())

    return re_opposing_brackets.sub(separate, prompt)
def align_alternating(prompt: str):
    """Tighten spacing around '|' in alternating-words syntax: 'a | b' -> 'a|b'."""
    def tighten(match: re.Match):
        return match.group(1)

    return re_alternating.sub(tighten, prompt)
def format_prompt(*prompts: list):
    """Run every cleanup pass over each prompt and return the formatted list.

    The passes are order-sensitive: mismatched brackets are removed before
    whitespace/bracket alignment, and AND-spacing runs before colon alignment
    so composable-diffusion weights line up correctly.

    Fix applied: the captured span interleaved the removed old-revision
    pipeline (tokenize / min_normalized_brackets / network extraction) with
    the new one; only the new-revision pipeline is kept, and the commented-out
    "further processing" draft is summarized in the TODO below.
    """
    ret = []
    for prompt in prompts:
        # Clean up the string
        prompt = normalize_characters(prompt)
        prompt = remove_mismatched_brackets(prompt)
        # Clean up whitespace
        prompt = remove_whitespace_excessive(prompt)
        prompt = align_brackets(prompt)
        prompt = space_AND(prompt)  # for proper compositing alignment on colons
        prompt = space_bracekts(prompt)
        prompt = align_colons(prompt)
        prompt = align_commas(prompt)
        prompt = align_alternating(prompt)
        # TODO: further processing for usability — brackets_to_weights and
        # network extraction/reordering — is on hold pending nested-weight
        # handling (see brackets_to_weights).
        ret.append(prompt)
    return ret