sd_extension-prompt_formatter/tests/test_pipeline.py

131 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""Unit testing for prompt transformation pipeline."""
import sys
from pathlib import Path
from typing import TYPE_CHECKING
import pytest
from scripts import prompt_formatting_pipeline as pipeline
from scripts.prompt_formatting_definitions import UnderSpaceEnum
def test_get_bracket_closing():
# Test for each opening bracket
assert pipeline.get_bracket_closing('(') == ')'
assert pipeline.get_bracket_closing('[') == ']'
assert pipeline.get_bracket_closing('{') == '}'
assert pipeline.get_bracket_closing('<') == '>'
# # Test for invalid input (not an opening bracket)
# with pytest.raises(ValueError):
# pipeline.get_bracket_closing('a')
#
# # Test for empty string (should raise an error)
# with pytest.raises(ValueError):
# pipeline.get_bracket_closing('')
#
# # Test for input that's not a single character
# with pytest.raises(TypeError):
# pipeline.get_bracket_closing('(())')
def test_get_bracket_opening():
# Test for each closing bracket
assert pipeline.get_bracket_opening(')') == '('
assert pipeline.get_bracket_opening(']') == '['
assert pipeline.get_bracket_opening('}') == '{'
assert pipeline.get_bracket_opening('>') == '<'
# # Test for invalid input (not a closing bracket)
# with pytest.raises(ValueError):
# pipeline.get_bracket_opening('a')
#
# # Test for empty string (should raise an error)
# with pytest.raises(ValueError):
# pipeline.get_bracket_opening('')
#
# # Test for input that's not a single character
# with pytest.raises(TypeError):
# pipeline.get_bracket_opening('()')
def test_normalize_characters():
assert pipeline.normalize_characters('アイウエオ') == 'アイウエオ' # Full-width to half-width
assert pipeline.normalize_characters('𝓣𝓮𝓼𝓽') == 'Test' # Fraktur to regular
assert pipeline.normalize_characters('abc') == 'abc' # No change
assert pipeline.normalize_characters('Hello, 世界!') == 'Hello, 世界!' # Mixed characters
def test_tokenize():
assert pipeline.tokenize('a,b,c') == ['a', 'b', 'c']
assert pipeline.tokenize('1,2,3,4') == ['1', '2', '3', '4']
assert pipeline.tokenize('apple,,banana') == ['apple', '', 'banana']
assert pipeline.tokenize('hello') == ['hello']
assert pipeline.tokenize('a,,b,,c') == ['a', '', 'b', '', 'c']
def test_align_brackets():
assert pipeline.align_brackets('( foo)') == '(foo)'
assert pipeline.align_brackets('[ bar ]') == '[bar]'
assert pipeline.align_brackets('{ test }') == '{test}'
assert pipeline.align_brackets('< example >') == '<example>'
assert pipeline.align_brackets('( [ { < content > } ] )') == '([{<content>}])'
def test_space_and():
assert pipeline.space_and('a AND b') == 'a AND b'
assert pipeline.space_and('foo ANDbar') == 'foo AND bar'
assert pipeline.space_and('hello AND world') == 'hello AND world'
assert pipeline.space_and('test ANDexample') == 'test AND example'
assert pipeline.space_and('a AND b AND c') == 'a AND b AND c'
assert pipeline.space_and('aANDbANDc') == 'a AND b AND c'
def test_align_colons():
assert pipeline.align_colons('key: value') == 'key:value'
assert pipeline.align_colons('foo : bar') == 'foo:bar'
assert pipeline.align_colons('test: example') == 'test:example'
assert pipeline.align_colons('name: John AND age: 30') == 'name:John AND age:30'
assert pipeline.align_colons('foo bar:1.0 AND zee') == 'foo bar:1.0 AND zee'
def test_align_commas():
assert pipeline.align_commas('a, b, c') == 'a, b, c'
assert pipeline.align_commas(' foo , bar , baz ') == 'foo, bar, baz'
assert pipeline.align_commas('test,example') == 'test, example'
assert pipeline.align_commas(' item1, item2 , item3 ') == 'item1, item2, item3'
assert pipeline.align_commas(' , a , b , c , ') == 'a, b, c'
def test_remove_mismatched_brackets():
assert pipeline.remove_mismatched_brackets('(a[b]c)') == '(a[b]c)'
assert pipeline.remove_mismatched_brackets('a(b)c') == 'a(b)c'
assert pipeline.remove_mismatched_brackets('a(b]c') == 'abc'
assert pipeline.remove_mismatched_brackets('[(a+b)]') == '[(a+b)]'
assert pipeline.remove_mismatched_brackets('a{b[c}d]') == 'abcd'
def test_space_bracekts():
assert pipeline.space_bracekts(')(') == ') ('
assert pipeline.space_bracekts('][}{') == '] [} {'
assert pipeline.space_bracekts('foo(bar)baz') == 'foo(bar)baz'
assert pipeline.space_bracekts('a(b)c[d]e{f}g') == 'a(b)c[d]e{f}g'
assert pipeline.space_bracekts(')a[b]{c}') == ')a[b] {c}'
def test_align_alternating():
assert pipeline.align_alternating('a |b') == 'a|b'
assert pipeline.align_alternating('foo |bar |baz') == 'foo|bar|baz'
assert pipeline.align_alternating('test | example') == 'test|example'
assert pipeline.align_alternating('hello | world') == 'hello|world'
assert pipeline.align_alternating('a | b | c') == 'a|b|c'
def test_bracket_to_weights():
assert pipeline.bracket_to_weights('(a)') == '(a:1.10)'
assert pipeline.bracket_to_weights('((a))') == '(a:1.21)'
assert pipeline.bracket_to_weights('((a, b))') == '(a, b:1.21)'
assert pipeline.bracket_to_weights('(a, (b))') == '(a, (b:1.10):1.10)'
assert pipeline.bracket_to_weights('((a), b)') == '((a:1.10), b:1.10)'
assert pipeline.bracket_to_weights('((a), ((b)))') == '((a:1.10), (b:1.21):1.10)'
def test_space_to_underscore():
assert pipeline.space_to_underscore('<lora:chicken butt>, multiple subjects') == '<lora:chicken butt>, multiple subjects'
assert pipeline.space_to_underscore('one_two_three') == 'one two three'
assert pipeline.space_to_underscore('this is_a test') == 'this is a test'
assert pipeline.space_to_underscore('<embed:foo bar>, baz', UnderSpaceEnum.IGNORE) == '<embed:foo bar>, baz'
assert pipeline.space_to_underscore('one two three', UnderSpaceEnum.UNDERSCORE) == 'one_two_three'