refactor: refactor pipeline steps, add tests, fixes, doc
This should hopefully make the pipeline easier to approach when applying fixes and adding new features.
main
parent
207677984c
commit
15711f71b9
|
|
@ -34,7 +34,7 @@ def format_prompt(*prompts: tuple[dict]):
|
|||
prompt = pipeline.remove_whitespace_excessive(prompt)
|
||||
|
||||
# Replace Spaces and/or underscores, unless disabled
|
||||
prompt = pipeline.space_to_underscore(prompt, IGNOREUNDERSCORES)
|
||||
prompt = pipeline.space_to_underscore(prompt, opposite=IGNOREUNDERSCORES)
|
||||
prompt = pipeline.align_brackets(prompt)
|
||||
prompt = pipeline.space_and(
|
||||
prompt
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ brackets_closing = ")]}>"
|
|||
|
||||
re_whitespace = re.compile(r"[^\S\r\n]+") # excludes new lines
|
||||
re_tokenize = re.compile(r",")
|
||||
re_tokenize_strip = re.compile(r"\s*,\s*")
|
||||
re_comma_spacing = re.compile(r",+")
|
||||
re_brackets_fix_whitespace = re.compile(r"([\(\[{<])\s*|\s*([\)\]}>}])")
|
||||
re_opposing_brackets = re.compile(r"([)\]}>])([([{<])")
|
||||
|
|
@ -28,6 +29,77 @@ re_pipe = re.compile(r"\s*(\|)\s*")
|
|||
re_existing_weight = re.compile(r"(?<=:)(\d+.?\d*|\d*.?\d+)(?=[)\]]$)")
|
||||
|
||||
|
||||
def escape_bracket_index(token, symbols, start_index=0):
    """Find the index that supposedly closes this bracket.

    Given a token and a set of open bracket symbols, find the index in which
    that character escapes the given bracketing such that depth = 0.

    Args:
        token: The string to scan.
        symbols: Opening bracket characters to track. Each one must occur in
            ``brackets_opening``; ``str.index`` raises ``ValueError``
            otherwise.
        start_index: Index at which the scan begins.

    Returns:
        The index of the closing character that brings the depth back to 0.
        If the depth never returns to 0, the index at which the scan stopped
        is returned instead.
    """
    token_length = len(token)
    # NOTE: this local name shadows the builtin open() inside the function.
    open = symbols
    # Build the closing counterparts of `symbols` by position lookup.
    close = ""
    for s in symbols:
        close += brackets_closing[brackets_opening.index(s)]

    i = start_index
    d = 0  # current nesting depth relative to start_index
    # The bound is token_length - 1, so the final character is never
    # inspected; when the loop runs out, i is simply returned as-is.
    while i < token_length - 1:
        if token[i] in open:
            d += 1
        if token[i] in close:
            d -= 1
            if d == 0:
                return i
        i += 1

    return i
|
||||
|
||||
def get_weight(
    prompt: str,
    map_gradient: list,
    map_depth: list,
    map_brackets: list,
    pos: int,
    ctv: int,
    gradient_search: str,
    is_square_brackets: bool = False,
):
    """Determine a weight given the start of its brackets.

    TODO: I'm pretty sure this disallows mixing of square brackets and
    parenthesis and will evaluate incorrectly. Take this into account...

    Args:
        prompt: The prompt being scanned.
        map_gradient: Per-position markers; slices of it are compared with
            runs of ``"v"``, so it is presumably a string rather than a
            list -- TODO confirm against the caller.
        map_depth: Per-position depth markers; slices of it are compared
            with ``gradient_search``, so it is presumably a string as well
            -- TODO confirm against the caller.
        map_brackets: Not used by this function.
        pos: Index at which the search starts.
        ctv: Number of consecutive brackets still being searched for.
        gradient_search: Depth sequence that ``map_depth`` must match for a
            weight position to be accepted.
        is_square_brackets: Whether the brackets are square; forwarded to
            ``calculate_weight`` and enables the ``:`` / ``|`` handling.

    Return a tuple containing:
    - where to insert the weight
    - the weight itself
    - how many consecutive brackets there are(?)

    Return weight=0 if bracket was recognized as prompt editing, alternation,
    or composable. In that case the first element of the tuple is the whole
    ``prompt`` string rather than an index, and the third element is 1.

    Raises:
        Exception: If the search window runs past the end of ``prompt``.
    """
    # CURRENTLY DOES NOT TAKE INTO ACCOUNT COMPOSABLE?? DO WE EVEN NEED TO?
    # E.G. [a AND B :1.2] == (a AND B:1.1) != (a AND B:1.1) ????
    while pos + ctv <= len(prompt):
        # Every candidate bracket has been discounted: report "no weight".
        if ctv == 0:
            return prompt, 0, 1
        # [a, b) is the window in which the closing run is looked for.
        a, b = pos, pos + ctv
        # A ':' or '|' inside square brackets -- presumably prompt editing
        # or alternation, per the docstring.
        if prompt[a] in ":|" and is_square_brackets:
            if map_depth[-2] == map_depth[a]:
                return prompt, 0, 1
            # Otherwise drop this depth from the search and look for one
            # bracket fewer.
            if map_depth[a] in gradient_search:
                gradient_search = gradient_search.replace(map_depth[a], "")
                ctv -= 1
        # Full match: ctv consecutive "v" markers whose depths equal the
        # search sequence. The weight goes at index a.
        elif map_gradient[a:b] == "v" * ctv and map_depth[a - 1 : b] == gradient_search:
            return a, calculate_weight(ctv, is_square_brackets=is_square_brackets), ctv
        # Partial match: shorten the search sequence by the number of "v"
        # markers in the window and look for one bracket fewer.
        elif "v" == map_gradient[a] and map_depth[a - 1 : b - 1] in gradient_search:
            narrowing = map_gradient[a:b].count("v")
            gradient_search = gradient_search[narrowing:]
            ctv -= 1
        pos += 1

    msg = f"Somehow weight index searching has gone outside of prompt length with prompt: {prompt}"
    raise Exception(msg)
|
||||
|
||||
def get_bracket_closing(c: str):
    """Return the closing bracket that pairs with the opening bracket ``c``.

    Args:
        c: A single opening bracket character found in ``brackets_opening``.

    Returns:
        The character at the same position in ``brackets_closing``.

    Raises:
        TypeError: If ``c`` is not a string or is longer than one character.
        ValueError: If ``c`` is empty or is not an opening bracket.
    """
    if not isinstance(c, str) or len(c) > 1:
        msg = f"Expected a single bracket character, got: {c!r}"
        raise TypeError(msg)
    # str.find() returns -1 on a miss, which would silently select the last
    # closing bracket, and an empty string "matches" at position 0. Reject
    # both instead of returning an unrelated bracket.
    if not c or c not in brackets_opening:
        msg = f"Not an opening bracket: {c!r}"
        raise ValueError(msg)
    return brackets_closing[brackets_opening.index(c)]
|
||||
|
||||
|
|
@ -40,8 +112,8 @@ def normalize_characters(data: str):
|
|||
return unicodedata.normalize("NFKC", data)
|
||||
|
||||
|
||||
def tokenize(data: str) -> list:
    """Split ``data`` into tokens with the module-level ``re_tokenize`` pattern."""
    tokens = re_tokenize.split(data)
    return tokens
|
||||
def tokenize(data: str, *, strip: bool = False) -> list:
    """Split ``data`` into comma-separated tokens.

    Args:
        data: The text to split.
        strip: When true, split with ``re_tokenize_strip`` so that whitespace
            surrounding each comma is consumed together with the comma.
            When false, split with ``re_tokenize`` and leave whitespace in
            the tokens untouched.

    Returns:
        The list of tokens produced by the selected pattern's ``split``.
    """
    if strip:
        return re_tokenize_strip.split(data)
    return re_tokenize.split(data)
|
||||
|
||||
|
||||
def remove_whitespace_excessive(prompt: str):
|
||||
|
|
@ -49,6 +121,11 @@ def remove_whitespace_excessive(prompt: str):
|
|||
|
||||
|
||||
def align_brackets(prompt: str):
|
||||
"""Push opening of brackets to a character.
|
||||
|
||||
e.g.
|
||||
'( foo)' -> '(foo)'
|
||||
"""
|
||||
def helper(match: re.Match):
|
||||
return match.group(1) or match.group(2)
|
||||
|
||||
|
|
@ -56,6 +133,12 @@ def align_brackets(prompt: str):
|
|||
|
||||
|
||||
def space_and(prompt: str):
|
||||
"""Put proper spacing around AND for composable diffusion.
|
||||
|
||||
Also known as prompt composition.
|
||||
e.g.
|
||||
'a ANDb' -> 'a AND b'
|
||||
"""
|
||||
def helper(match: re.Match):
|
||||
return " ".join(match.groups())
|
||||
|
||||
|
|
@ -63,21 +146,38 @@ def space_and(prompt: str):
|
|||
|
||||
|
||||
def align_colons(prompt: str):
|
||||
"""Push characters into colons from both sides.
|
||||
|
||||
Interestingly, these two generate the same image.
|
||||
'a :1.2' == 'a:1.2' == 'a: 1.2'
|
||||
|
||||
There does not appear to be any special interactions with AND for
|
||||
composable diffusion...
|
||||
|
||||
e.g.
|
||||
'a : b' -> 'a:b'
|
||||
"""
|
||||
def normalize(match: re.Match):
|
||||
return match.group(1)
|
||||
|
||||
def composite(match: re.Match):
|
||||
return " " + match.group(1)
|
||||
# def composite(match: re.Match):
|
||||
# return " " + match.group(1)
|
||||
#
|
||||
# def composite_end(match: re.Match):
|
||||
# return " " + match.group(1)
|
||||
|
||||
def composite_end(match: re.Match):
|
||||
return " " + match.group(1)
|
||||
|
||||
ret = re_colon_spacing.sub(normalize, prompt)
|
||||
ret = re_colon_spacing_composite.sub(composite, ret)
|
||||
return re_colon_spacing_comp_end.sub(composite_end, ret)
|
||||
return re_colon_spacing.sub(normalize, prompt)
|
||||
# ret = re_colon_spacing.sub(normalize, prompt)
|
||||
# ret = re_colon_spacing_composite.sub(composite, ret)
|
||||
# return re_colon_spacing_comp_end.sub(composite_end, ret)
|
||||
|
||||
|
||||
def align_commas(prompt: str, *, do_it: bool = True):
|
||||
"""Align commas like natural language.
|
||||
|
||||
TODO: Tokenizer automatically strips whitespace when splitting at comma.
|
||||
Take that into account and verify the functionality of this step.
|
||||
"""
|
||||
if not do_it:
|
||||
return prompt
|
||||
|
||||
|
|
@ -103,43 +203,67 @@ def remove_networks(tokens: list):
|
|||
|
||||
|
||||
def remove_mismatched_brackets(prompt: str):
|
||||
stack = []
|
||||
pos = []
|
||||
ret = ""
|
||||
"""Remove unmatched brackets.
|
||||
|
||||
A closing bracket should be able to find a matching unclosed bracket.
|
||||
If it finds a nonmatching unclosed bracket, that bracket and this
|
||||
bracket are invalid.
|
||||
"""
|
||||
invalid_brackets = []
|
||||
invalid_at = []
|
||||
|
||||
# Find invalid brackets
|
||||
for i, c in enumerate(prompt):
|
||||
if c in brackets_opening:
|
||||
stack.append(c)
|
||||
pos.append(i)
|
||||
ret += c
|
||||
elif c in brackets_closing:
|
||||
if not stack:
|
||||
continue
|
||||
if stack[-1] == brackets_opening[brackets_closing.index(c)]:
|
||||
stack.pop()
|
||||
pos.pop()
|
||||
ret += c
|
||||
else:
|
||||
ret += c
|
||||
invalid_brackets.append(c)
|
||||
invalid_at.append(i)
|
||||
|
||||
while stack:
|
||||
bracket = stack.pop()
|
||||
p = pos.pop()
|
||||
ret = ret[:p] + ret[p + 1 :]
|
||||
elif c in brackets_closing:
|
||||
if not invalid_brackets:
|
||||
invalid_brackets.append(c)
|
||||
invalid_at.append(i)
|
||||
|
||||
# Look for the immediate unmatched opening bracket
|
||||
if invalid_brackets[-1] == brackets_opening[brackets_closing.index(c)]:
|
||||
invalid_brackets.pop()
|
||||
invalid_at.pop()
|
||||
else:
|
||||
invalid_brackets.append(c)
|
||||
invalid_at.append(i)
|
||||
|
||||
if not invalid_brackets:
|
||||
return prompt
|
||||
|
||||
# Remove invalid brackets
|
||||
ret = ""
|
||||
last_p = 0
|
||||
while invalid_brackets:
|
||||
bracket = invalid_brackets.pop(0)
|
||||
p = invalid_at.pop(0)
|
||||
ret += prompt[last_p:p]
|
||||
last_p = p+1
|
||||
ret += prompt[last_p:]
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def space_bracekts(prompt: str):
    """Insert a space between a closing bracket and an adjacent opening one.

    e.g.
    ')(' -> ') ('
    """
    # Both capture groups of re_opposing_brackets always take part in a
    # match, so a replacement template gives the same result as joining
    # the groups with a space.
    return re_opposing_brackets.sub(r"\1 \2", prompt)
|
||||
|
||||
|
||||
def align_alternating(prompt: str):
|
||||
"""Push alternating symbol | together with words.
|
||||
|
||||
e.g.
|
||||
'a |b' -> 'a|b'
|
||||
"""
|
||||
def helper(match: re.Match):
|
||||
return match.group(1)
|
||||
|
||||
|
|
@ -333,76 +457,28 @@ def calculate_weight(d: str, *, is_square_brackets: bool):
|
|||
return 1 / 1.1 ** int(d) if is_square_brackets else 1 * 1.1 ** int(d)
|
||||
|
||||
|
||||
def get_weight(
    prompt: str,
    map_gradient: list,
    map_depth: list,
    map_brackets: list,
    pos: int,
    ctv: int,
    gradient_search: str,
    is_square_brackets: bool = False,
):
    """Search forward from ``pos`` for where a bracket group's weight belongs.

    Returns a 3-tuple. On success it is ``(index, weight, ctv)``, where
    ``weight`` comes from ``calculate_weight``. Returns 0 as the weight if
    the bracket was recognized as prompt editing, alternation, or
    composable; in that case the tuple is ``(prompt, 0, 1)``, i.e. the first
    element is the whole prompt string rather than an index.

    ``map_gradient`` and ``map_depth`` are annotated as lists, but their
    slices are compared with strings here, so they are presumably strings
    -- TODO confirm against the caller. ``map_brackets`` is not used.

    Raises:
        Exception: If the search window runs past the end of ``prompt``.
    """
    # CURRENTLY DOES NOT TAKE INTO ACCOUNT COMPOSABLE?? DO WE EVEN NEED TO?
    # E.G. [a AND B :1.2] == (a AND B:1.1) != (a AND B:1.1) ????
    while pos + ctv <= len(prompt):
        # Nothing left to search for: report "no weight".
        if ctv == 0:
            return prompt, 0, 1
        # [a, b) is the window examined on this iteration.
        a, b = pos, pos + ctv
        # ':' or '|' inside square brackets is handled separately.
        if prompt[a] in ":|" and is_square_brackets:
            if map_depth[-2] == map_depth[a]:
                return prompt, 0, 1
            # Remove this depth from the search and expect one bracket fewer.
            if map_depth[a] in gradient_search:
                gradient_search = gradient_search.replace(map_depth[a], "")
                ctv -= 1
        # Full match of ctv "v" markers with the expected depth sequence.
        elif map_gradient[a:b] == "v" * ctv and map_depth[a - 1 : b] == gradient_search:
            return a, calculate_weight(ctv, is_square_brackets=is_square_brackets), ctv
        # Partial match: trim the search sequence and expect one bracket fewer.
        elif "v" == map_gradient[a] and map_depth[a - 1 : b - 1] in gradient_search:
            narrowing = map_gradient[a:b].count("v")
            gradient_search = gradient_search[narrowing:]
            ctv -= 1
        pos += 1

    msg = f"Somehow weight index searching has gone outside of prompt length with prompt: {prompt}"
    raise Exception(msg)
|
||||
|
||||
|
||||
def space_to_underscore(prompt: str, conv_space_to_underscore = True):
|
||||
# We need to look ahead and ignore any spaces/underscores within network tokens
|
||||
# INPUT <lora:chicken butt>, multiple subjects
|
||||
# OUTPUT <lora:chicken butt>, multiple_subjects
|
||||
|
||||
def space_to_underscore(prompt: str, *, opposite = True):
|
||||
"""Replace space with underscore or vice versa.
|
||||
|
||||
It's a bit funky right now because it uses the tokenizer to chunk for sub.
|
||||
Currently(ish), the tokenizer does not strip whitespace, so any existing
|
||||
'foo, bar' is split into ('foo', ' bar'), and will result in 'foo,_bar'.
|
||||
|
||||
This has been patched by requiring the match to be surrounded with a
|
||||
character, but I'm sure there's better solutions. It will work for now.
|
||||
"""
|
||||
match = (
|
||||
r"(?<!BREAK) +(?!BREAK|[^<]*>)"
|
||||
if conv_space_to_underscore
|
||||
else r"(?<!BREAK|_)_(?!_|BREAK|[^<]*>)"
|
||||
r"(?<!BREAK)(?<=\w) +(?=\w)(?!BREAK)(?![^<]*>)"
|
||||
if opposite
|
||||
else r"(?<!BREAK)(?<=\w)_+(?=\w)(?!BREAK)(?![^<]*>)"
|
||||
)
|
||||
replace = "_" if conv_space_to_underscore else " "
|
||||
replace = "_" if opposite else " "
|
||||
|
||||
tokens: str = tokenize(prompt)
|
||||
print(tokens)
|
||||
|
||||
return ",".join(map(lambda t: re.sub(match, replace, t), tokens))
|
||||
|
||||
|
||||
def escape_bracket_index(token, symbols, start_index=0):
    """Find the index at which the bracketing opened in ``token`` is escaped.

    Given a token and a set of open bracket symbols, find the index in which
    that character escapes the given bracketing such that depth = 0.

    Each character of ``symbols`` must occur in ``brackets_opening``
    (``str.index`` raises ``ValueError`` otherwise). If the depth never
    returns to 0, the index at which the scan stopped is returned.
    """
    token_length = len(token)
    # NOTE: this local name shadows the builtin open() inside the function.
    open = symbols
    # Closing counterparts of `symbols`, looked up by position.
    close = ""
    for s in symbols:
        close += brackets_closing[brackets_opening.index(s)]

    i = start_index
    d = 0  # nesting depth relative to start_index
    # The final character of the token is never inspected by this loop.
    while i < token_length - 1:
        if token[i] in open:
            d += 1
        if token[i] in close:
            d -= 1
            if d == 0:
                return i
        i += 1

    return i
|
||||
|
|
|
|||
|
|
@ -0,0 +1,128 @@
|
|||
"""Unit testing for prompt transformation pipeline."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts import prompt_formatting_pipeline as pipeline
|
||||
|
||||
|
||||
def test_get_bracket_closing():
    """Every opening bracket maps to its closing counterpart."""
    expected_pairs = [("(", ")"), ("[", "]"), ("{", "}"), ("<", ">")]
    for opening, closing in expected_pairs:
        assert pipeline.get_bracket_closing(opening) == closing

    # Disabled checks, kept for reference:
    #
    # # Invalid input (not an opening bracket)
    # with pytest.raises(ValueError):
    #     pipeline.get_bracket_closing('a')
    #
    # # Empty string (should raise an error)
    # with pytest.raises(ValueError):
    #     pipeline.get_bracket_closing('')
    #
    # # Input that's not a single character
    # with pytest.raises(TypeError):
    #     pipeline.get_bracket_closing('(())')
|
||||
|
||||
def test_get_bracket_opening():
    """Every closing bracket maps to its opening counterpart."""
    expected_pairs = [(")", "("), ("]", "["), ("}", "{"), (">", "<")]
    for closing, opening in expected_pairs:
        assert pipeline.get_bracket_opening(closing) == opening

    # Disabled checks, kept for reference:
    #
    # # Invalid input (not a closing bracket)
    # with pytest.raises(ValueError):
    #     pipeline.get_bracket_opening('a')
    #
    # # Empty string (should raise an error)
    # with pytest.raises(ValueError):
    #     pipeline.get_bracket_opening('')
    #
    # # Input that's not a single character
    # with pytest.raises(TypeError):
    #     pipeline.get_bracket_opening('()')
|
||||
|
||||
def test_normalize_characters():
    """normalize_characters returns the expected text for each sample."""
    cases = [
        # Full-width katakana, expected unchanged.
        # NOTE(review): the earlier comment here read "full-width to
        # half-width", but input and expected value are identical, and NFKC
        # folds half-width katakana to full-width, not the reverse. Confirm
        # which input was intended.
        ("アイウエオ", "アイウエオ"),
        # Mathematical bold-script letters fold to plain ASCII letters.
        ("𝓣𝓮𝓼𝓽", "Test"),
        # Plain ASCII is left alone.
        ("abc", "abc"),
        # Mixed ASCII and CJK text is left alone.
        ("Hello, 世界!", "Hello, 世界!"),
    ]
    for given, expected in cases:
        assert pipeline.normalize_characters(given) == expected
|
||||
|
||||
def test_tokenize():
    """tokenize splits on every comma and keeps empty tokens."""
    cases = [
        ("a,b,c", ["a", "b", "c"]),
        ("1,2,3,4", ["1", "2", "3", "4"]),
        ("apple,,banana", ["apple", "", "banana"]),
        ("hello", ["hello"]),
        ("a,,b,,c", ["a", "", "b", "", "c"]),
    ]
    for given, expected in cases:
        assert pipeline.tokenize(given) == expected
|
||||
|
||||
|
||||
def test_align_brackets():
    """Whitespace just inside brackets is removed, including when nested."""
    cases = [
        ("( foo)", "(foo)"),
        ("[ bar ]", "[bar]"),
        ("{ test }", "{test}"),
        ("< example >", "<example>"),
        ("( [ { < content > } ] )", "([{<content>}])"),
    ]
    for given, expected in cases:
        assert pipeline.align_brackets(given) == expected
|
||||
|
||||
def test_space_and():
    """AND ends up surrounded by exactly one space on each side."""
    cases = [
        ("a AND b", "a AND b"),
        ("foo ANDbar", "foo AND bar"),
        ("hello AND world", "hello AND world"),
        ("test ANDexample", "test AND example"),
        ("a AND b AND c", "a AND b AND c"),
        ("aANDbANDc", "a AND b AND c"),
    ]
    for given, expected in cases:
        assert pipeline.space_and(given) == expected
|
||||
|
||||
def test_align_colons():
    """Whitespace around colons is removed; text around AND is untouched."""
    cases = [
        ("key: value", "key:value"),
        ("foo : bar", "foo:bar"),
        ("test: example", "test:example"),
        ("name: John AND age: 30", "name:John AND age:30"),
        ("foo bar:1.0 AND zee", "foo bar:1.0 AND zee"),
    ]
    for given, expected in cases:
        assert pipeline.align_colons(given) == expected
|
||||
|
||||
def test_align_commas():
    """Commas are followed by a single space and outer padding is dropped."""
    cases = [
        ("a, b, c", "a, b, c"),
        (" foo , bar , baz ", "foo, bar, baz"),
        ("test,example", "test, example"),
        (" item1, item2 , item3 ", "item1, item2, item3"),
        (" , a , b , c , ", "a, b, c"),
    ]
    for given, expected in cases:
        assert pipeline.align_commas(given) == expected
|
||||
|
||||
def test_remove_mismatched_brackets():
    """Balanced brackets survive; mismatched pairs are stripped out."""
    cases = [
        ("(a[b]c)", "(a[b]c)"),
        ("a(b)c", "a(b)c"),
        ("a(b]c", "abc"),
        ("[(a+b)]", "[(a+b)]"),
        ("a{b[c}d]", "abcd"),
    ]
    for given, expected in cases:
        assert pipeline.remove_mismatched_brackets(given) == expected
|
||||
|
||||
def test_space_bracekts():
    """A space is inserted only between a closing and an opening bracket."""
    cases = [
        (")(", ") ("),
        ("][}{", "] [} {"),
        ("foo(bar)baz", "foo(bar)baz"),
        ("a(b)c[d]e{f}g", "a(b)c[d]e{f}g"),
        (")a[b]{c}", ")a[b] {c}"),
    ]
    for given, expected in cases:
        assert pipeline.space_bracekts(given) == expected
|
||||
|
||||
def test_align_alternating():
    """Whitespace on either side of the | symbol is removed."""
    cases = [
        ("a |b", "a|b"),
        ("foo |bar |baz", "foo|bar|baz"),
        ("test | example", "test|example"),
        ("hello | world", "hello|world"),
        ("a | b | c", "a|b|c"),
    ]
    for given, expected in cases:
        assert pipeline.align_alternating(given) == expected
|
||||
|
||||
def test_bracket_to_weights():
    """Nested parentheses collapse into explicit numeric weights."""
    cases = [
        ("(a)", "(a:1.10)"),
        ("((a))", "(a:1.21)"),
        ("((a, b))", "(a, b:1.21)"),
        ("(a, (b))", "(a, (b:1.10):1.10)"),
        ("((a), b)", "((a:1.10), b:1.10)"),
        ("((a), ((b)))", "((a:1.10), (b:1.21):1.10)"),
    ]
    for given, expected in cases:
        assert pipeline.bracket_to_weights(given) == expected
|
||||
|
||||
def test_space_to_underscore():
    """Spaces become underscores outside <...> network tokens, and back."""
    default_cases = [
        ("<lora:chicken butt>, multiple subjects", "<lora:chicken butt>, multiple_subjects"),
        ("one two three", "one_two_three"),
        ("this is a test", "this_is_a_test"),
        ("<embed:foo bar>, baz", "<embed:foo bar>, baz"),
    ]
    for given, expected in default_cases:
        assert pipeline.space_to_underscore(given) == expected

    # With opposite=False the conversion runs from underscores to spaces.
    reverted = pipeline.space_to_underscore("some_var_name", opposite=False)
    assert reverted == "some var name"
|
||||
|
||||
Loading…
Reference in New Issue