refactor: refactor pipeline steps, add tests, fixes, doc

This should hopefully make it easier to approach when applying fixes and
adding new features.
main
uwidev 2024-09-26 10:37:42 -07:00
parent 207677984c
commit 15711f71b9
4 changed files with 299 additions and 95 deletions

View File

@ -34,7 +34,7 @@ def format_prompt(*prompts: tuple[dict]):
prompt = pipeline.remove_whitespace_excessive(prompt)
# Replace Spaces and/or underscores, unless disabled
prompt = pipeline.space_to_underscore(prompt, IGNOREUNDERSCORES)
prompt = pipeline.space_to_underscore(prompt, opposite=IGNOREUNDERSCORES)
prompt = pipeline.align_brackets(prompt)
prompt = pipeline.space_and(
prompt

View File

@ -9,6 +9,7 @@ brackets_closing = ")]}>"
re_whitespace = re.compile(r"[^\S\r\n]+") # excludes new lines
re_tokenize = re.compile(r",")
re_tokenize_strip = re.compile(r"\s*,\s*")
re_comma_spacing = re.compile(r",+")
re_brackets_fix_whitespace = re.compile(r"([\(\[{<])\s*|\s*([\)\]}>}])")
re_opposing_brackets = re.compile(r"([)\]}>])([([{<])")
@ -28,6 +29,77 @@ re_pipe = re.compile(r"\s*(\|)\s*")
re_existing_weight = re.compile(r"(?<=:)(\d+.?\d*|\d*.?\d+)(?=[)\]]$)")
def escape_bracket_index(token, symbols, start_index=0):
"""Find the index that supposedly closes this bracket.
Given a token and a set of open bracket symbols, find the index in which
that character escapes the given bracketing such that depth = 0.
"""
token_length = len(token)
open = symbols
close = ""
for s in symbols:
close += brackets_closing[brackets_opening.index(s)]
i = start_index
d = 0
while i < token_length - 1:
if token[i] in open:
d += 1
if token[i] in close:
d -= 1
if d == 0:
return i
i += 1
return i
def get_weight(
prompt: str,
map_gradient: list,
map_depth: list,
map_brackets: list,
pos: int,
ctv: int,
gradient_search: str,
is_square_brackets: bool = False,
):
"""Determine a weight given the start of its brackets.
TODO: I'm pretty sure this disallows mixing of square brackets and
parenthesis and will evaluate incorrectly. Take this into account...
Return a tuple containing:
- where to insert the weight
- the weight itself
- how many consecutive brackets there are(?)
Return weight=0 if bracket was recognized as prompt editing, alternation,
or composable.
"""
# CURRENTLY DOES NOT TAKE INTO ACCOUNT COMPOSABLE?? DO WE EVEN NEED TO?
# E.G. [a AND B :1.2] == (a AND B:1.1) != (a AND B:1.1) ????
while pos + ctv <= len(prompt):
if ctv == 0:
return prompt, 0, 1
a, b = pos, pos + ctv
if prompt[a] in ":|" and is_square_brackets:
if map_depth[-2] == map_depth[a]:
return prompt, 0, 1
if map_depth[a] in gradient_search:
gradient_search = gradient_search.replace(map_depth[a], "")
ctv -= 1
elif map_gradient[a:b] == "v" * ctv and map_depth[a - 1 : b] == gradient_search:
return a, calculate_weight(ctv, is_square_brackets=is_square_brackets), ctv
elif "v" == map_gradient[a] and map_depth[a - 1 : b - 1] in gradient_search:
narrowing = map_gradient[a:b].count("v")
gradient_search = gradient_search[narrowing:]
ctv -= 1
pos += 1
msg = f"Somehow weight index searching has gone outside of prompt length with prompt: {prompt}"
raise Exception(msg)
def get_bracket_closing(c: str):
return brackets_closing[brackets_opening.find(c)]
@ -40,8 +112,8 @@ def normalize_characters(data: str):
return unicodedata.normalize("NFKC", data)
def tokenize(data: str) -> list:
return re_tokenize.split(data)
def tokenize(data: str, *, strip:bool = False) -> list:
return re_tokenize_strip.split(data) if strip else re_tokenize.split(data)
def remove_whitespace_excessive(prompt: str):
@ -49,6 +121,11 @@ def remove_whitespace_excessive(prompt: str):
def align_brackets(prompt: str):
"""Push opening of brackets to a character.
e.g.
'( foo)' -> '(foo)'
"""
def helper(match: re.Match):
return match.group(1) or match.group(2)
@ -56,6 +133,12 @@ def align_brackets(prompt: str):
def space_and(prompt: str):
"""Put proper spacing around AND for composable diffusion.
Also known as prompt composition.
e.g.
'a ANDb' -> 'a AND b'
"""
def helper(match: re.Match):
return " ".join(match.groups())
@ -63,21 +146,38 @@ def space_and(prompt: str):
def align_colons(prompt: str):
"""Push characters into colons from both sides.
Interestingly, these two generate the same image.
'a :1.2' == 'a:1.2' == 'a: 1.2'
There does not appear to be any special interactions with AND for
composable diffusion...
e.g.
'a : b' -> 'a:b'
"""
def normalize(match: re.Match):
return match.group(1)
def composite(match: re.Match):
return " " + match.group(1)
# def composite(match: re.Match):
# return " " + match.group(1)
#
# def composite_end(match: re.Match):
# return " " + match.group(1)
def composite_end(match: re.Match):
return " " + match.group(1)
ret = re_colon_spacing.sub(normalize, prompt)
ret = re_colon_spacing_composite.sub(composite, ret)
return re_colon_spacing_comp_end.sub(composite_end, ret)
return re_colon_spacing.sub(normalize, prompt)
# ret = re_colon_spacing.sub(normalize, prompt)
# ret = re_colon_spacing_composite.sub(composite, ret)
# return re_colon_spacing_comp_end.sub(composite_end, ret)
def align_commas(prompt: str, *, do_it: bool = True):
"""Align commas like natural language.
TODO: Tokenizer automatically strips whitespace when splitting at comma.
Take that into account and verify the functionality of this step.
"""
if not do_it:
return prompt
@ -103,43 +203,67 @@ def remove_networks(tokens: list):
def remove_mismatched_brackets(prompt: str):
stack = []
pos = []
ret = ""
"""Remove unmatched brackets.
A closing bracket should be able to find an matching unclosed bracket.
If it finds a nonmatching unclosed bracket, that bracket and this
bracket are invalid.
"""
invalid_brackets = []
invalid_at = []
# Find invalid brackets
for i, c in enumerate(prompt):
if c in brackets_opening:
stack.append(c)
pos.append(i)
ret += c
elif c in brackets_closing:
if not stack:
continue
if stack[-1] == brackets_opening[brackets_closing.index(c)]:
stack.pop()
pos.pop()
ret += c
else:
ret += c
invalid_brackets.append(c)
invalid_at.append(i)
while stack:
bracket = stack.pop()
p = pos.pop()
ret = ret[:p] + ret[p + 1 :]
elif c in brackets_closing:
if not invalid_brackets:
invalid_brackets.append(c)
invalid_at.append(i)
# Look for the immediate unmatched opening bracket
if invalid_brackets[-1] == brackets_opening[brackets_closing.index(c)]:
invalid_brackets.pop()
invalid_at.pop()
else:
invalid_brackets.append(c)
invalid_at.append(i)
if not invalid_brackets:
return prompt
# Remove invalid brackets
ret = ""
last_p = 0
while invalid_brackets:
bracket = invalid_brackets.pop(0)
p = invalid_at.pop(0)
ret += prompt[last_p:p]
last_p = p+1
ret += prompt[last_p:]
return ret
def space_bracekts(prompt: str):
"""Space adjacent closing-opening brackets.
e.g. ')(' -> '()'
"""
def helper(match: re.Match):
# print(' '.join(match.groups()))
return " ".join(match.groups())
# print(prompt)
return re_opposing_brackets.sub(helper, prompt)
def align_alternating(prompt: str):
"""Push alternating symbol | together with words.
e.g.
'a |b' -> 'a|b'
"""
def helper(match: re.Match):
return match.group(1)
@ -333,76 +457,28 @@ def calculate_weight(d: str, *, is_square_brackets: bool):
return 1 / 1.1 ** int(d) if is_square_brackets else 1 * 1.1 ** int(d)
def get_weight(
prompt: str,
map_gradient: list,
map_depth: list,
map_brackets: list,
pos: int,
ctv: int,
gradient_search: str,
is_square_brackets: bool = False,
):
"""Returns 0 if bracket was recognized as prompt editing, alternation, or composable."""
# CURRENTLY DOES NOT TAKE INTO ACCOUNT COMPOSABLE?? DO WE EVEN NEED TO?
# E.G. [a AND B :1.2] == (a AND B:1.1) != (a AND B:1.1) ????
while pos + ctv <= len(prompt):
if ctv == 0:
return prompt, 0, 1
a, b = pos, pos + ctv
if prompt[a] in ":|" and is_square_brackets:
if map_depth[-2] == map_depth[a]:
return prompt, 0, 1
if map_depth[a] in gradient_search:
gradient_search = gradient_search.replace(map_depth[a], "")
ctv -= 1
elif map_gradient[a:b] == "v" * ctv and map_depth[a - 1 : b] == gradient_search:
return a, calculate_weight(ctv, is_square_brackets=is_square_brackets), ctv
elif "v" == map_gradient[a] and map_depth[a - 1 : b - 1] in gradient_search:
narrowing = map_gradient[a:b].count("v")
gradient_search = gradient_search[narrowing:]
ctv -= 1
pos += 1
msg = f"Somehow weight index searching has gone outside of prompt length with prompt: {prompt}"
raise Exception(msg)
def space_to_underscore(prompt: str, conv_space_to_underscore = True):
# We need to look ahead and ignore any spaces/underscores within network tokens
# INPUT <lora:chicken butt>, multiple subjects
# OUTPUT <lora:chicken butt>, multiple_subjects
def space_to_underscore(prompt: str, *, opposite = True):
"""Replace space with underscore or vice versa.
It's a but funky right now because it uses the tokenizer to chunk for sub.
Currently(ish), the tokenizer does not strip whitespace, so any existing
'foo, bar' is split into ('foo', ' bar'), and will result in 'foo,_bar'.
This has been patched by requiring the match to be surrounded with a
character, but I'm sure there's better solutions. It will work for now.
"""
match = (
r"(?<!BREAK) +(?!BREAK|[^<]*>)"
if conv_space_to_underscore
else r"(?<!BREAK|_)_(?!_|BREAK|[^<]*>)"
r"(?<!BREAK)(?<=\w) +(?=\w)(?!BREAK)(?![^<]*>)"
if opposite
else r"(?<!BREAK)(?<=\w)_+(?=\w)(?!BREAK)(?![^<]*>)"
)
replace = "_" if conv_space_to_underscore else " "
replace = "_" if opposite else " "
tokens: str = tokenize(prompt)
print(tokens)
return ",".join(map(lambda t: re.sub(match, replace, t), tokens))
def escape_bracket_index(token, symbols, start_index=0):
# Given a token and a set of open bracket symbols, find the index in which that character
# escapes the given bracketing such that depth = 0.
token_length = len(token)
open = symbols
close = ""
for s in symbols:
close += brackets_closing[brackets_opening.index(s)]
i = start_index
d = 0
while i < token_length - 1:
if token[i] in open:
d += 1
if token[i] in close:
d -= 1
if d == 0:
return i
i += 1
return i

0
tests/__init__.py Normal file
View File

128
tests/test_pipeline.py Normal file
View File

@ -0,0 +1,128 @@
"""Unit testing for prompt transformation pipeline."""
import sys
from pathlib import Path
from typing import TYPE_CHECKING
import pytest
from scripts import prompt_formatting_pipeline as pipeline
def test_get_bracket_closing():
# Test for each opening bracket
assert pipeline.get_bracket_closing('(') == ')'
assert pipeline.get_bracket_closing('[') == ']'
assert pipeline.get_bracket_closing('{') == '}'
assert pipeline.get_bracket_closing('<') == '>'
# # Test for invalid input (not an opening bracket)
# with pytest.raises(ValueError):
# pipeline.get_bracket_closing('a')
#
# # Test for empty string (should raise an error)
# with pytest.raises(ValueError):
# pipeline.get_bracket_closing('')
#
# # Test for input that's not a single character
# with pytest.raises(TypeError):
# pipeline.get_bracket_closing('(())')
def test_get_bracket_opening():
# Test for each closing bracket
assert pipeline.get_bracket_opening(')') == '('
assert pipeline.get_bracket_opening(']') == '['
assert pipeline.get_bracket_opening('}') == '{'
assert pipeline.get_bracket_opening('>') == '<'
# # Test for invalid input (not a closing bracket)
# with pytest.raises(ValueError):
# pipeline.get_bracket_opening('a')
#
# # Test for empty string (should raise an error)
# with pytest.raises(ValueError):
# pipeline.get_bracket_opening('')
#
# # Test for input that's not a single character
# with pytest.raises(TypeError):
# pipeline.get_bracket_opening('()')
def test_normalize_characters():
assert pipeline.normalize_characters('アイウエオ') == 'アイウエオ' # Full-width to half-width
assert pipeline.normalize_characters('𝓣𝓮𝓼𝓽') == 'Test' # Fraktur to regular
assert pipeline.normalize_characters('abc') == 'abc' # No change
assert pipeline.normalize_characters('Hello, 世界!') == 'Hello, 世界!' # Mixed characters
def test_tokenize():
assert pipeline.tokenize('a,b,c') == ['a', 'b', 'c']
assert pipeline.tokenize('1,2,3,4') == ['1', '2', '3', '4']
assert pipeline.tokenize('apple,,banana') == ['apple', '', 'banana']
assert pipeline.tokenize('hello') == ['hello']
assert pipeline.tokenize('a,,b,,c') == ['a', '', 'b', '', 'c']
def test_align_brackets():
assert pipeline.align_brackets('( foo)') == '(foo)'
assert pipeline.align_brackets('[ bar ]') == '[bar]'
assert pipeline.align_brackets('{ test }') == '{test}'
assert pipeline.align_brackets('< example >') == '<example>'
assert pipeline.align_brackets('( [ { < content > } ] )') == '([{<content>}])'
def test_space_and():
assert pipeline.space_and('a AND b') == 'a AND b'
assert pipeline.space_and('foo ANDbar') == 'foo AND bar'
assert pipeline.space_and('hello AND world') == 'hello AND world'
assert pipeline.space_and('test ANDexample') == 'test AND example'
assert pipeline.space_and('a AND b AND c') == 'a AND b AND c'
assert pipeline.space_and('aANDbANDc') == 'a AND b AND c'
def test_align_colons():
assert pipeline.align_colons('key: value') == 'key:value'
assert pipeline.align_colons('foo : bar') == 'foo:bar'
assert pipeline.align_colons('test: example') == 'test:example'
assert pipeline.align_colons('name: John AND age: 30') == 'name:John AND age:30'
assert pipeline.align_colons('foo bar:1.0 AND zee') == 'foo bar:1.0 AND zee'
def test_align_commas():
assert pipeline.align_commas('a, b, c') == 'a, b, c'
assert pipeline.align_commas(' foo , bar , baz ') == 'foo, bar, baz'
assert pipeline.align_commas('test,example') == 'test, example'
assert pipeline.align_commas(' item1, item2 , item3 ') == 'item1, item2, item3'
assert pipeline.align_commas(' , a , b , c , ') == 'a, b, c'
def test_remove_mismatched_brackets():
assert pipeline.remove_mismatched_brackets('(a[b]c)') == '(a[b]c)'
assert pipeline.remove_mismatched_brackets('a(b)c') == 'a(b)c'
assert pipeline.remove_mismatched_brackets('a(b]c') == 'abc'
assert pipeline.remove_mismatched_brackets('[(a+b)]') == '[(a+b)]'
assert pipeline.remove_mismatched_brackets('a{b[c}d]') == 'abcd'
def test_space_bracekts():
assert pipeline.space_bracekts(')(') == ') ('
assert pipeline.space_bracekts('][}{') == '] [} {'
assert pipeline.space_bracekts('foo(bar)baz') == 'foo(bar)baz'
assert pipeline.space_bracekts('a(b)c[d]e{f}g') == 'a(b)c[d]e{f}g'
assert pipeline.space_bracekts(')a[b]{c}') == ')a[b] {c}'
def test_align_alternating():
assert pipeline.align_alternating('a |b') == 'a|b'
assert pipeline.align_alternating('foo |bar |baz') == 'foo|bar|baz'
assert pipeline.align_alternating('test | example') == 'test|example'
assert pipeline.align_alternating('hello | world') == 'hello|world'
assert pipeline.align_alternating('a | b | c') == 'a|b|c'
def test_bracket_to_weights():
assert pipeline.bracket_to_weights('(a)') == '(a:1.10)'
assert pipeline.bracket_to_weights('((a))') == '(a:1.21)'
assert pipeline.bracket_to_weights('((a, b))') == '(a, b:1.21)'
assert pipeline.bracket_to_weights('(a, (b))') == '(a, (b:1.10):1.10)'
assert pipeline.bracket_to_weights('((a), b)') == '((a:1.10), b:1.10)'
assert pipeline.bracket_to_weights('((a), ((b)))') == '((a:1.10), (b:1.21):1.10)'
def test_space_to_underscore():
assert pipeline.space_to_underscore('<lora:chicken butt>, multiple subjects') == '<lora:chicken butt>, multiple_subjects'
assert pipeline.space_to_underscore('one two three') == 'one_two_three'
assert pipeline.space_to_underscore('this is a test') == 'this_is_a_test'
assert pipeline.space_to_underscore('<embed:foo bar>, baz') == '<embed:foo bar>, baz'
assert pipeline.space_to_underscore('some_var_name', opposite=False) == 'some var name'