# sd-webui-neutral-prompt/test/perp_parser/basic_test.py
# (source listing: 114 lines, 5.5 KiB, Python)

import unittest
import pathlib
import sys
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
from lib_neutral_prompt import neutral_prompt_parser
class TestPromptParser(unittest.TestCase):
    """Unit tests for ``neutral_prompt_parser.parse_root``.

    The fixtures cover a single weighted prompt, a plain ``AND`` prompt,
    the bracketed ``AND_PERP`` / ``AND_SALT`` operators (flat and nested),
    a prompt whose weight is not a float, and a prompt that starts with
    prompt-editing syntax.
    """

    @classmethod
    def setUpClass(cls):
        # The parsed trees are only read by the tests below, so each fixture
        # is parsed once per class instead of all seven being re-parsed
        # before every single test method.
        parse = neutral_prompt_parser.parse_root
        cls.simple_prompt = parse("hello :1.0")
        cls.and_prompt = parse("hello AND goodbye :2.0")
        cls.and_perp_prompt = parse("hello :1.0 AND_PERP [goodbye :2.0]")
        cls.and_salt_prompt = parse("hello :1.0 AND_SALT [goodbye :2.0]")
        cls.nested_and_perp_prompt = parse("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP [welcome :3.0]]")
        cls.nested_and_salt_prompt = parse("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT [welcome :3.0]]")
        cls.invalid_weight = parse("hello :not_a_float")

    # -- shared assertions -------------------------------------------------

    def _assert_leaf(self, node, weight, prompt):
        """Assert that ``node`` has exactly the given ``weight`` and ``prompt`` text."""
        self.assertEqual(node.weight, weight)
        self.assertEqual(node.prompt, prompt)

    def _assert_leaf_then_composite(self, children):
        """Assert that ``children[0]`` is a LeafPrompt and ``children[1]`` a CompositePrompt."""
        self.assertIsInstance(children[0], neutral_prompt_parser.LeafPrompt)
        self.assertIsInstance(children[1], neutral_prompt_parser.CompositePrompt)

    # -- single prompt -----------------------------------------------------

    def test_simple_prompt_child_count(self):
        self.assertEqual(len(self.simple_prompt.children), 1)

    def test_simple_prompt_child_weight(self):
        self.assertEqual(self.simple_prompt.children[0].weight, 1.0)

    def test_simple_prompt_child_prompt(self):
        # The text before the ``:1.0`` suffix keeps its trailing space.
        self.assertEqual(self.simple_prompt.children[0].prompt, "hello ")

    # -- plain AND ---------------------------------------------------------

    def test_and_prompt_child_count(self):
        self.assertEqual(len(self.and_prompt.children), 2)

    def test_and_prompt_child_weights_and_prompts(self):
        # "hello" has no explicit weight and is expected to get 1.0; the
        # spaces around ``AND`` stay attached to the neighbouring prompt text.
        self._assert_leaf(self.and_prompt.children[0], 1.0, "hello ")
        self._assert_leaf(self.and_prompt.children[1], 2.0, " goodbye ")

    # -- AND_PERP ----------------------------------------------------------

    def test_and_perp_prompt_child_count(self):
        self.assertEqual(len(self.and_perp_prompt.children), 2)

    def test_and_perp_prompt_child_types(self):
        # ``AND_PERP [...]`` produces a composite node after the leading leaf.
        self._assert_leaf_then_composite(self.and_perp_prompt.children)

    def test_and_perp_prompt_nested_child(self):
        self._assert_leaf(self.and_perp_prompt.children[1].children[0], 2.0, "goodbye ")

    def test_nested_and_perp_prompt_child_count(self):
        self.assertEqual(len(self.nested_and_perp_prompt.children), 2)

    def test_nested_and_perp_prompt_child_types(self):
        self._assert_leaf_then_composite(self.nested_and_perp_prompt.children)

    def test_nested_and_perp_prompt_nested_child_types(self):
        # The bracketed composite holds a leaf followed by the innermost composite.
        self._assert_leaf_then_composite(self.nested_and_perp_prompt.children[1].children)

    def test_nested_and_perp_prompt_nested_child(self):
        innermost = self.nested_and_perp_prompt.children[1].children[1]
        self._assert_leaf(innermost.children[0], 3.0, "welcome ")

    # -- invalid weight ----------------------------------------------------

    def test_invalid_weight_child_count(self):
        self.assertEqual(len(self.invalid_weight.children), 1)

    def test_invalid_weight_child_weight(self):
        # A weight that is not a float leaves the prompt at weight 1.0 ...
        self.assertEqual(self.invalid_weight.children[0].weight, 1.0)

    def test_invalid_weight_child_prompt(self):
        # ... and the ``:not_a_float`` suffix is kept verbatim as prompt text.
        self.assertEqual(self.invalid_weight.children[0].prompt, "hello :not_a_float")

    # -- AND_SALT (mirrors the AND_PERP cases) -----------------------------

    def test_and_salt_prompt_child_count(self):
        self.assertEqual(len(self.and_salt_prompt.children), 2)

    def test_and_salt_prompt_child_types(self):
        self._assert_leaf_then_composite(self.and_salt_prompt.children)

    def test_and_salt_prompt_nested_child(self):
        self._assert_leaf(self.and_salt_prompt.children[1].children[0], 2.0, "goodbye ")

    def test_nested_and_salt_prompt_child_count(self):
        self.assertEqual(len(self.nested_and_salt_prompt.children), 2)

    def test_nested_and_salt_prompt_child_types(self):
        self._assert_leaf_then_composite(self.nested_and_salt_prompt.children)

    def test_nested_and_salt_prompt_nested_child_types(self):
        self._assert_leaf_then_composite(self.nested_and_salt_prompt.children[1].children)

    def test_nested_and_salt_prompt_nested_child(self):
        innermost = self.nested_and_salt_prompt.children[1].children[1]
        self._assert_leaf(innermost.children[0], 3.0, "welcome ")

    # -- prompt editing ----------------------------------------------------

    def test_start_with_prompt_editing(self):
        # A prompt that begins with prompt-editing brackets is expected to
        # come back unchanged as a single weight-1.0 prompt; the ``:1.2`` and
        # ``:0.1`` inside it must not be taken as the prompt's weight.
        prompt = "[(long shot:1.2):0.1] detail.."
        res = neutral_prompt_parser.parse_root(prompt)
        self._assert_leaf(res.children[0], 1.0, prompt)
if __name__ == '__main__':
    # Executing this file directly runs every TestCase defined in it.
    unittest.main()