nest 2 `AND_*` keywords or more inside `[...]` to group (#57)
* nest 2 to group * fix some corner cases with accidental double brackets * refactor; fix cfg rescale mean * tests * more test cases * more test cases --------- Co-authored-by: ljleb <set>pull/60/head
parent
ccc2aedea5
commit
e395a11eb2
|
|
@ -23,10 +23,10 @@ def combine_denoised_hijack(
|
|||
|
||||
for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
|
||||
args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
|
||||
cond_delta = prompt.accept(CondDeltaChildVisitor(), args, 0)
|
||||
aux_cond_delta = prompt.accept(AuxCondDeltaChildVisitor(), args, cond_delta, 0)
|
||||
cond_delta = prompt.accept(CondDeltaVisitor(), args, 0)
|
||||
aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
|
||||
cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
|
||||
denoised[batch_i] = cfg_cond * get_cfg_rescale_factor(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)
|
||||
denoised[batch_i] = cfg_rescale(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)
|
||||
|
||||
return denoised
|
||||
|
||||
|
|
@ -41,22 +41,27 @@ def get_webui_denoised(
|
|||
uncond = x_out[-text_uncond.shape[0]:]
|
||||
sliced_batch_x_out = []
|
||||
sliced_batch_cond_indices = []
|
||||
index_in = 0
|
||||
|
||||
for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
|
||||
args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
|
||||
sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, len(sliced_batch_x_out))
|
||||
sliced_batch_cond_indices.append(sliced_cond_indices)
|
||||
sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, index_in, len(sliced_batch_x_out))
|
||||
if sliced_cond_indices:
|
||||
sliced_batch_cond_indices.append(sliced_cond_indices)
|
||||
sliced_batch_x_out.extend(sliced_x_out)
|
||||
index_in += prompt.accept(neutral_prompt_parser.FlatSizeVisitor())
|
||||
|
||||
sliced_batch_x_out += list(uncond)
|
||||
sliced_batch_x_out = torch.stack(sliced_batch_x_out, dim=0)
|
||||
sliced_batch_cond_indices = [il for il in sliced_batch_cond_indices if il]
|
||||
return original_function(sliced_batch_x_out, sliced_batch_cond_indices, text_uncond, cond_scale)
|
||||
|
||||
|
||||
def get_cfg_rescale_factor(cfg_cond, cond):
|
||||
def cfg_rescale(cfg_cond, cond):
|
||||
global_state.apply_and_clear_cfg_rescale_override()
|
||||
return global_state.cfg_rescale * (torch.std(cond) / torch.std(cfg_cond) - 1) + 1
|
||||
cfg_cond_mean = cfg_cond.mean()
|
||||
cfg_resacle_mean = (1 - global_state.cfg_rescale) * cfg_cond_mean + global_state.cfg_rescale * cond.mean()
|
||||
cfg_rescale_factor = global_state.cfg_rescale * (cond.std() / cfg_cond.std() - 1) + 1
|
||||
return cfg_resacle_mean + (cfg_cond - cfg_cond_mean) * cfg_rescale_factor
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
@ -68,67 +73,36 @@ class CombineDenoiseArgs:
|
|||
|
||||
@dataclasses.dataclass
|
||||
class GatherWebuiCondsVisitor:
|
||||
def visit_leaf_prompt(self, *args, **kwargs) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
|
||||
return [], []
|
||||
def visit_leaf_prompt(
|
||||
self,
|
||||
that: neutral_prompt_parser.CompositePrompt,
|
||||
args: CombineDenoiseArgs,
|
||||
index_in: int,
|
||||
index_out: int,
|
||||
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
|
||||
return [args.x_out[args.cond_indices[index_in][0]]], [(index_out, args.cond_indices[index_in][1])]
|
||||
|
||||
def visit_composite_prompt(
|
||||
self,
|
||||
that: neutral_prompt_parser.CompositePrompt,
|
||||
args: CombineDenoiseArgs,
|
||||
index_offset: int,
|
||||
index_in: int,
|
||||
index_out: int,
|
||||
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
|
||||
sliced_x_out = []
|
||||
sliced_cond_indices = []
|
||||
|
||||
index_in = 0
|
||||
for child in that.children:
|
||||
index_out = index_offset + len(sliced_x_out)
|
||||
child_x_out, child_cond_indices = child.accept(GatherWebuiCondsVisitor.SingleCondVisitor(), args.x_out, args.cond_indices[index_in], index_out)
|
||||
sliced_x_out.extend(child_x_out)
|
||||
sliced_cond_indices.extend(child_cond_indices)
|
||||
if child.conciliation is None:
|
||||
index_offset = index_out + len(sliced_x_out)
|
||||
child_x_out, child_cond_indices = child.accept(GatherWebuiCondsVisitor(), args, index_in, index_offset)
|
||||
sliced_x_out.extend(child_x_out)
|
||||
sliced_cond_indices.extend(child_cond_indices)
|
||||
|
||||
index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())
|
||||
|
||||
return sliced_x_out, sliced_cond_indices
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SingleCondVisitor:
|
||||
def visit_leaf_prompt(
|
||||
self,
|
||||
that: neutral_prompt_parser.LeafPrompt,
|
||||
x_out: torch.Tensor,
|
||||
cond_info: Tuple[int, float],
|
||||
index: int,
|
||||
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
|
||||
return [x_out[cond_info[0]]], [(index, cond_info[1])]
|
||||
|
||||
def visit_composite_prompt(self, *args, **kwargs) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
|
||||
return [], []
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CondDeltaChildVisitor:
|
||||
def visit_leaf_prompt(
|
||||
self,
|
||||
that: neutral_prompt_parser.LeafPrompt,
|
||||
args: CombineDenoiseArgs,
|
||||
index: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.zeros_like(args.x_out[0])
|
||||
|
||||
def visit_composite_prompt(
|
||||
self,
|
||||
that: neutral_prompt_parser.CompositePrompt,
|
||||
args: CombineDenoiseArgs,
|
||||
index: int,
|
||||
) -> torch.Tensor:
|
||||
cond_delta = torch.zeros_like(args.x_out[0])
|
||||
|
||||
for child in that.children:
|
||||
cond_delta += child.weight * child.accept(CondDeltaVisitor(), args, index)
|
||||
index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
|
||||
|
||||
return cond_delta
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CondDeltaVisitor:
|
||||
|
|
@ -143,8 +117,9 @@ class CondDeltaVisitor:
|
|||
console_warn(f'''
|
||||
An unexpected noise weight was encountered at prompt #{index}
|
||||
Expected :{that.weight}, but got :{cond_info[1]}
|
||||
This is likely due to another extension also monkey patching the webui noise blending function
|
||||
Please open a github issue so that the conflict can be resolved
|
||||
This is likely due to another extension also monkey patching the webui `combine_denoised` function
|
||||
Please open a bug report here so that the conflict can be resolved:
|
||||
https://github.com/ljleb/sd-webui-neutral-prompt/issues
|
||||
''')
|
||||
|
||||
return args.x_out[cond_info[0]] - args.uncond
|
||||
|
|
@ -157,17 +132,19 @@ class CondDeltaVisitor:
|
|||
) -> torch.Tensor:
|
||||
cond_delta = torch.zeros_like(args.x_out[0])
|
||||
|
||||
if that.conciliation is None:
|
||||
for child in that.children:
|
||||
child_cond_delta = child.accept(CondDeltaChildVisitor(), args, index)
|
||||
child_cond_delta += child.accept(AuxCondDeltaChildVisitor(), args, child_cond_delta, index)
|
||||
for child in that.children:
|
||||
if child.conciliation is None:
|
||||
child_cond_delta = child.accept(CondDeltaVisitor(), args, index)
|
||||
child_cond_delta += child.accept(AuxCondDeltaVisitor(), args, child_cond_delta, index)
|
||||
cond_delta += child.weight * child_cond_delta
|
||||
|
||||
index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
|
||||
|
||||
return cond_delta
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AuxCondDeltaChildVisitor:
|
||||
class AuxCondDeltaVisitor:
|
||||
def visit_leaf_prompt(
|
||||
self,
|
||||
that: neutral_prompt_parser.LeafPrompt,
|
||||
|
|
@ -188,9 +165,10 @@ class AuxCondDeltaChildVisitor:
|
|||
salient_cond_deltas = []
|
||||
|
||||
for child in that.children:
|
||||
child_cond_delta = child.accept(CondDeltaChildVisitor(), args, index)
|
||||
child_cond_delta += child.accept(self, args, child_cond_delta, index)
|
||||
if isinstance(child, neutral_prompt_parser.CompositePrompt):
|
||||
if child.conciliation is not None:
|
||||
child_cond_delta = child.accept(CondDeltaVisitor(), args, index)
|
||||
child_cond_delta += child.accept(AuxCondDeltaVisitor(), args, child_cond_delta, index)
|
||||
|
||||
if child.conciliation == neutral_prompt_parser.ConciliationStrategy.PERPENDICULAR:
|
||||
aux_cond_delta += child.weight * get_perpendicular_component(cond_delta, child_cond_delta)
|
||||
elif child.conciliation == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK:
|
||||
|
|
@ -221,12 +199,12 @@ def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float]
|
|||
The blended result combines `normal` and vector information in salient regions.
|
||||
"""
|
||||
|
||||
salience_maps = [get_salience(normal)] + [get_salience(vector) for vector, weight in vectors]
|
||||
salience_maps = [get_salience(normal)] + [get_salience(vector) for vector, _ in vectors]
|
||||
mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
|
||||
|
||||
result = torch.zeros_like(normal)
|
||||
for mask_i, (vector, weight) in enumerate(vectors, start=1):
|
||||
vector_mask = ((mask == mask_i).float())
|
||||
vector_mask = (mask == mask_i).float()
|
||||
result += weight * vector_mask * (vector - normal)
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -5,23 +5,6 @@ from enum import Enum
|
|||
from typing import List, Tuple, Any, Optional
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PromptExpr(abc.ABC):
|
||||
weight: float
|
||||
|
||||
@abc.abstractmethod
|
||||
def accept(self, visitor, *args, **kwargs) -> Any:
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LeafPrompt(PromptExpr):
|
||||
prompt: str
|
||||
|
||||
def accept(self, visitor, *args, **kwargs):
|
||||
return visitor.visit_leaf_prompt(self, *args, **kwargs)
|
||||
|
||||
|
||||
class PromptKeyword(Enum):
|
||||
AND = 'AND'
|
||||
AND_PERP = 'AND_PERP'
|
||||
|
|
@ -41,10 +24,27 @@ class ConciliationStrategy(Enum):
|
|||
conciliation_strategies = [e.value for e in ConciliationStrategy]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PromptExpr(abc.ABC):
|
||||
weight: float
|
||||
conciliation: Optional[ConciliationStrategy]
|
||||
|
||||
@abc.abstractmethod
|
||||
def accept(self, visitor, *args, **kwargs) -> Any:
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LeafPrompt(PromptExpr):
|
||||
prompt: str
|
||||
|
||||
def accept(self, visitor, *args, **kwargs):
|
||||
return visitor.visit_leaf_prompt(self, *args, **kwargs)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CompositePrompt(PromptExpr):
|
||||
children: List[PromptExpr]
|
||||
conciliation: Optional[ConciliationStrategy]
|
||||
|
||||
def accept(self, visitor, *args, **kwargs):
|
||||
return visitor.visit_composite_prompt(self, *args, **kwargs)
|
||||
|
|
@ -61,57 +61,53 @@ class FlatSizeVisitor:
|
|||
def parse_root(string: str) -> CompositePrompt:
|
||||
tokens = tokenize(string)
|
||||
prompts = parse_prompts(tokens)
|
||||
return CompositePrompt(1., prompts, None)
|
||||
return CompositePrompt(1., None, prompts)
|
||||
|
||||
|
||||
def parse_prompts(tokens: List[str]) -> List[PromptExpr]:
|
||||
prompts = [parse_prompt(tokens, first=True)]
|
||||
def parse_prompts(tokens: List[str], *, nested: bool = False) -> List[PromptExpr]:
|
||||
prompts = [parse_prompt(tokens, first=True, nested=nested)]
|
||||
while tokens:
|
||||
if tokens[0] in [']']:
|
||||
if nested and tokens[0] in [']']:
|
||||
break
|
||||
|
||||
prompts.append(parse_prompt(tokens, first=False))
|
||||
prompts.append(parse_prompt(tokens, first=False, nested=nested))
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
def parse_prompt(tokens: List[str], *, first: bool) -> PromptExpr:
|
||||
if first:
|
||||
prompt_type = PromptKeyword.AND.value
|
||||
else:
|
||||
assert tokens[0] in prompt_keywords
|
||||
def parse_prompt(tokens: List[str], *, first: bool, nested: bool = False) -> PromptExpr:
|
||||
if not first and tokens[0] in prompt_keywords:
|
||||
prompt_type = tokens.pop(0)
|
||||
else:
|
||||
prompt_type = PromptKeyword.AND.value
|
||||
|
||||
tokens_copy = tokens.copy()
|
||||
if tokens_copy and tokens_copy[0] == '[':
|
||||
tokens_copy.pop(0)
|
||||
prompts = parse_prompts(tokens_copy)
|
||||
if tokens_copy:
|
||||
assert tokens_copy.pop(0) == ']'
|
||||
if not tokens_copy or tokens_copy[0] in prompt_keywords + [']']:
|
||||
tokens[:] = tokens_copy
|
||||
weight = parse_weight(tokens)
|
||||
conciliation = ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None
|
||||
return CompositePrompt(weight, prompts, conciliation)
|
||||
tokens_copy = tokens.copy()
|
||||
if tokens_copy and tokens_copy[0] == '[':
|
||||
tokens_copy.pop(0)
|
||||
prompts = parse_prompts(tokens_copy, nested=True)
|
||||
if tokens_copy:
|
||||
assert tokens_copy.pop(0) == ']'
|
||||
if len(prompts) > 1:
|
||||
tokens[:] = tokens_copy
|
||||
weight = parse_weight(tokens)
|
||||
conciliation = ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None
|
||||
return CompositePrompt(weight, conciliation, prompts)
|
||||
|
||||
prompt_text, weight = parse_prompt_text(tokens)
|
||||
prompt = LeafPrompt(weight, prompt_text)
|
||||
if prompt_type in conciliation_strategies:
|
||||
prompt.weight = 1.
|
||||
prompt = CompositePrompt(weight, [prompt], ConciliationStrategy(prompt_type))
|
||||
|
||||
return prompt
|
||||
prompt_text, weight = parse_prompt_text(tokens, nested=nested)
|
||||
return LeafPrompt(weight, ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None, prompt_text)
|
||||
|
||||
|
||||
def parse_prompt_text(tokens: List[str]) -> Tuple[str, float]:
|
||||
def parse_prompt_text(tokens: List[str], *, nested: bool = False) -> Tuple[str, float]:
|
||||
text = ''
|
||||
depth = 0
|
||||
weight = 1.
|
||||
while tokens:
|
||||
if tokens[0] == ']':
|
||||
if depth == 0:
|
||||
break
|
||||
depth -= 1
|
||||
if nested:
|
||||
break
|
||||
else:
|
||||
depth -= 1
|
||||
elif tokens[0] == '[':
|
||||
depth += 1
|
||||
elif tokens[0] == ':':
|
||||
|
|
@ -130,12 +126,9 @@ def parse_prompt_text(tokens: List[str]) -> Tuple[str, float]:
|
|||
|
||||
def parse_weight(tokens: List[str]) -> float:
|
||||
weight = 1.
|
||||
if tokens and tokens[0] == ':':
|
||||
if len(tokens) >= 2 and tokens[0] == ':' and is_float(tokens[1]):
|
||||
tokens.pop(0)
|
||||
if tokens:
|
||||
weight_str = tokens.pop(0)
|
||||
if is_float(weight_str):
|
||||
weight = float(weight_str)
|
||||
weight = float(tokens.pop(0))
|
||||
return weight
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ class TestPromptParser(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self.simple_prompt = neutral_prompt_parser.parse_root("hello :1.0")
|
||||
self.and_prompt = neutral_prompt_parser.parse_root("hello AND goodbye :2.0")
|
||||
self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0]")
|
||||
self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0]")
|
||||
self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP [welcome :3.0]]")
|
||||
self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT [welcome :3.0]]")
|
||||
self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP goodbye :2.0")
|
||||
self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT goodbye :2.0")
|
||||
self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]")
|
||||
self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]")
|
||||
self.invalid_weight = neutral_prompt_parser.parse_root("hello :not_a_float")
|
||||
|
||||
def test_simple_prompt_child_count(self):
|
||||
|
|
@ -24,6 +24,15 @@ class TestPromptParser(unittest.TestCase):
|
|||
def test_simple_prompt_child_prompt(self):
|
||||
self.assertEqual(self.simple_prompt.children[0].prompt, "hello ")
|
||||
|
||||
def test_square_weight_prompt(self):
|
||||
prompt = "a [b c d e : f g h :1.5]"
|
||||
parsed = neutral_prompt_parser.parse_root(prompt)
|
||||
self.assertEqual(parsed.children[0].prompt, prompt)
|
||||
|
||||
composed_prompt = f"{prompt} AND_PERP other prompt"
|
||||
parsed = neutral_prompt_parser.parse_root(composed_prompt)
|
||||
self.assertEqual(parsed.children[0].prompt, prompt)
|
||||
|
||||
def test_and_prompt_child_count(self):
|
||||
self.assertEqual(len(self.and_prompt.children), 2)
|
||||
|
||||
|
|
@ -38,12 +47,12 @@ class TestPromptParser(unittest.TestCase):
|
|||
|
||||
def test_and_perp_prompt_child_types(self):
|
||||
self.assertIsInstance(self.and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
|
||||
self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.CompositePrompt)
|
||||
self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.LeafPrompt)
|
||||
|
||||
def test_and_perp_prompt_nested_child(self):
|
||||
nested_child = self.and_perp_prompt.children[1].children[0]
|
||||
nested_child = self.and_perp_prompt.children[1]
|
||||
self.assertEqual(nested_child.weight, 2.0)
|
||||
self.assertEqual(nested_child.prompt, "goodbye ")
|
||||
self.assertEqual(nested_child.prompt.strip(), "goodbye")
|
||||
|
||||
def test_nested_and_perp_prompt_child_count(self):
|
||||
self.assertEqual(len(self.nested_and_perp_prompt.children), 2)
|
||||
|
|
@ -56,12 +65,12 @@ class TestPromptParser(unittest.TestCase):
|
|||
nested_child = self.nested_and_perp_prompt.children[1].children[0]
|
||||
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
|
||||
nested_child = self.nested_and_perp_prompt.children[1].children[1]
|
||||
self.assertIsInstance(nested_child, neutral_prompt_parser.CompositePrompt)
|
||||
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
|
||||
|
||||
def test_nested_and_perp_prompt_nested_child(self):
|
||||
nested_child = self.nested_and_perp_prompt.children[1].children[1].children[0]
|
||||
nested_child = self.nested_and_perp_prompt.children[1].children[1]
|
||||
self.assertEqual(nested_child.weight, 3.0)
|
||||
self.assertEqual(nested_child.prompt, "welcome ")
|
||||
self.assertEqual(nested_child.prompt.strip(), "welcome")
|
||||
|
||||
def test_invalid_weight_child_count(self):
|
||||
self.assertEqual(len(self.invalid_weight.children), 1)
|
||||
|
|
@ -77,12 +86,12 @@ class TestPromptParser(unittest.TestCase):
|
|||
|
||||
def test_and_salt_prompt_child_types(self):
|
||||
self.assertIsInstance(self.and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
|
||||
self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.CompositePrompt)
|
||||
self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.LeafPrompt)
|
||||
|
||||
def test_and_salt_prompt_nested_child(self):
|
||||
nested_child = self.and_salt_prompt.children[1].children[0]
|
||||
nested_child = self.and_salt_prompt.children[1]
|
||||
self.assertEqual(nested_child.weight, 2.0)
|
||||
self.assertEqual(nested_child.prompt, "goodbye ")
|
||||
self.assertEqual(nested_child.prompt.strip(), "goodbye")
|
||||
|
||||
def test_nested_and_salt_prompt_child_count(self):
|
||||
self.assertEqual(len(self.nested_and_salt_prompt.children), 2)
|
||||
|
|
@ -95,12 +104,12 @@ class TestPromptParser(unittest.TestCase):
|
|||
nested_child = self.nested_and_salt_prompt.children[1].children[0]
|
||||
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
|
||||
nested_child = self.nested_and_salt_prompt.children[1].children[1]
|
||||
self.assertIsInstance(nested_child, neutral_prompt_parser.CompositePrompt)
|
||||
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
|
||||
|
||||
def test_nested_and_salt_prompt_nested_child(self):
|
||||
nested_child = self.nested_and_salt_prompt.children[1].children[1].children[0]
|
||||
nested_child = self.nested_and_salt_prompt.children[1].children[1]
|
||||
self.assertEqual(nested_child.weight, 3.0)
|
||||
self.assertEqual(nested_child.prompt, "welcome ")
|
||||
self.assertEqual(nested_child.prompt.strip(), "welcome")
|
||||
|
||||
def test_start_with_prompt_editing(self):
|
||||
prompt = "[(long shot:1.2):0.1] detail.."
|
||||
|
|
|
|||
|
|
@ -23,15 +23,42 @@ class TestMaliciousPromptParser(unittest.TestCase):
|
|||
self.assertEqual(result.children[0].weight, 1.0)
|
||||
self.assertEqual(result.children[1].weight, -2.0)
|
||||
|
||||
def test_debalanced_square_brackets(self):
|
||||
prompt = "a [ b " * 100
|
||||
result = self.parser.parse_root(prompt)
|
||||
self.assertEqual(result.children[0].prompt, prompt)
|
||||
|
||||
prompt = "a ] b " * 100
|
||||
result = self.parser.parse_root(prompt)
|
||||
self.assertEqual(result.children[0].prompt, prompt)
|
||||
|
||||
repeats = 10
|
||||
prompt = "a [ [ b AND c ] " * repeats
|
||||
result = self.parser.parse_root(prompt)
|
||||
self.assertEqual([x.prompt for x in result.children], ["a [[ b ", *[" c ] a [[ b "] * (repeats - 1), " c ]"])
|
||||
|
||||
repeats = 10
|
||||
prompt = "a [ b AND c ] ] " * repeats
|
||||
result = self.parser.parse_root(prompt)
|
||||
self.assertEqual([x.prompt for x in result.children], ["a [ b ", *[" c ]] a [ b "] * (repeats - 1), " c ]]"])
|
||||
|
||||
def test_erroneous_syntax(self):
|
||||
result = self.parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0")
|
||||
self.assertEqual(result.children[0].weight, 1.0)
|
||||
self.assertEqual(result.children[1].children[0].prompt, "goodbye ")
|
||||
self.assertEqual(result.children[1].children[0].weight, 2.0)
|
||||
self.assertEqual(result.children[1].prompt, "[goodbye ")
|
||||
self.assertEqual(result.children[1].weight, 2.0)
|
||||
|
||||
result = self.parser.parse_root("hello :1.0 AND_PERP goodbye :2.0]")
|
||||
self.assertEqual(result.children[0].weight, 1.0)
|
||||
self.assertEqual(result.children[1].children[0].prompt, " goodbye ")
|
||||
self.assertEqual(result.children[1].prompt, " goodbye ")
|
||||
|
||||
result = self.parser.parse_root("hello :1.0 AND_PERP goodbye] :2.0")
|
||||
self.assertEqual(result.children[1].prompt, " goodbye]")
|
||||
self.assertEqual(result.children[1].weight, 2.0)
|
||||
|
||||
result = self.parser.parse_root("hello :1.0 AND_PERP a [ goodbye :2.0")
|
||||
self.assertEqual(result.children[1].weight, 2.0)
|
||||
self.assertEqual(result.children[1].prompt, " a [ goodbye ")
|
||||
|
||||
result = self.parser.parse_root("hello :1.0 AND_PERP AND goodbye :2.0")
|
||||
self.assertEqual(result.children[0].weight, 1.0)
|
||||
|
|
@ -58,13 +85,13 @@ class TestMaliciousPromptParser(unittest.TestCase):
|
|||
self.assertIsInstance(result.children[1], neutral_prompt_parser.CompositePrompt)
|
||||
|
||||
def test_complex_nested_prompts(self):
|
||||
complex_prompt = "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP [greetings :5.0]]"
|
||||
complex_prompt = "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP greetings:5.0]"
|
||||
result = self.parser.parse_root(complex_prompt)
|
||||
self.assertEqual(result.children[0].weight, 1.0)
|
||||
self.assertEqual(result.children[1].weight, 2.0)
|
||||
self.assertEqual(result.children[2].children[0].weight, 3.0)
|
||||
self.assertEqual(result.children[2].children[1].weight, 4.0)
|
||||
self.assertEqual(result.children[2].children[2].children[0].weight, 5.0)
|
||||
self.assertEqual(result.children[2].children[2].weight, 5.0)
|
||||
|
||||
def test_string_with_random_characters(self):
|
||||
random_chars = "ASDFGHJKL:@#$/.,|}{><~`12[3]456AND_PERP7890"
|
||||
|
|
|
|||
Loading…
Reference in New Issue