nest 2 `AND_*` keywords or more inside `[...]` to group (#57)

* nest 2 to group

* fix some corner cases with accidental double brackets

* refactor; fix cfg rescale mean

* tests

* more test cases

* more test cases

---------

Co-authored-by: ljleb <set>
pull/60/head
ljleb 2024-01-25 21:18:59 -05:00 committed by GitHub
parent ccc2aedea5
commit e395a11eb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 149 additions and 142 deletions

View File

@ -23,10 +23,10 @@ def combine_denoised_hijack(
for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
cond_delta = prompt.accept(CondDeltaChildVisitor(), args, 0)
aux_cond_delta = prompt.accept(AuxCondDeltaChildVisitor(), args, cond_delta, 0)
cond_delta = prompt.accept(CondDeltaVisitor(), args, 0)
aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
denoised[batch_i] = cfg_cond * get_cfg_rescale_factor(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)
denoised[batch_i] = cfg_rescale(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)
return denoised
@ -41,22 +41,27 @@ def get_webui_denoised(
uncond = x_out[-text_uncond.shape[0]:]
sliced_batch_x_out = []
sliced_batch_cond_indices = []
index_in = 0
for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, len(sliced_batch_x_out))
sliced_batch_cond_indices.append(sliced_cond_indices)
sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, index_in, len(sliced_batch_x_out))
if sliced_cond_indices:
sliced_batch_cond_indices.append(sliced_cond_indices)
sliced_batch_x_out.extend(sliced_x_out)
index_in += prompt.accept(neutral_prompt_parser.FlatSizeVisitor())
sliced_batch_x_out += list(uncond)
sliced_batch_x_out = torch.stack(sliced_batch_x_out, dim=0)
sliced_batch_cond_indices = [il for il in sliced_batch_cond_indices if il]
return original_function(sliced_batch_x_out, sliced_batch_cond_indices, text_uncond, cond_scale)
def get_cfg_rescale_factor(cfg_cond, cond):
def cfg_rescale(cfg_cond, cond):
global_state.apply_and_clear_cfg_rescale_override()
return global_state.cfg_rescale * (torch.std(cond) / torch.std(cfg_cond) - 1) + 1
cfg_cond_mean = cfg_cond.mean()
cfg_resacle_mean = (1 - global_state.cfg_rescale) * cfg_cond_mean + global_state.cfg_rescale * cond.mean()
cfg_rescale_factor = global_state.cfg_rescale * (cond.std() / cfg_cond.std() - 1) + 1
return cfg_resacle_mean + (cfg_cond - cfg_cond_mean) * cfg_rescale_factor
@dataclasses.dataclass
@ -68,67 +73,36 @@ class CombineDenoiseArgs:
@dataclasses.dataclass
class GatherWebuiCondsVisitor:
def visit_leaf_prompt(self, *args, **kwargs) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
return [], []
def visit_leaf_prompt(
self,
that: neutral_prompt_parser.CompositePrompt,
args: CombineDenoiseArgs,
index_in: int,
index_out: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
return [args.x_out[args.cond_indices[index_in][0]]], [(index_out, args.cond_indices[index_in][1])]
def visit_composite_prompt(
self,
that: neutral_prompt_parser.CompositePrompt,
args: CombineDenoiseArgs,
index_offset: int,
index_in: int,
index_out: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
sliced_x_out = []
sliced_cond_indices = []
index_in = 0
for child in that.children:
index_out = index_offset + len(sliced_x_out)
child_x_out, child_cond_indices = child.accept(GatherWebuiCondsVisitor.SingleCondVisitor(), args.x_out, args.cond_indices[index_in], index_out)
sliced_x_out.extend(child_x_out)
sliced_cond_indices.extend(child_cond_indices)
if child.conciliation is None:
index_offset = index_out + len(sliced_x_out)
child_x_out, child_cond_indices = child.accept(GatherWebuiCondsVisitor(), args, index_in, index_offset)
sliced_x_out.extend(child_x_out)
sliced_cond_indices.extend(child_cond_indices)
index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())
return sliced_x_out, sliced_cond_indices
@dataclasses.dataclass
class SingleCondVisitor:
def visit_leaf_prompt(
self,
that: neutral_prompt_parser.LeafPrompt,
x_out: torch.Tensor,
cond_info: Tuple[int, float],
index: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
return [x_out[cond_info[0]]], [(index, cond_info[1])]
def visit_composite_prompt(self, *args, **kwargs) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
return [], []
@dataclasses.dataclass
class CondDeltaChildVisitor:
def visit_leaf_prompt(
self,
that: neutral_prompt_parser.LeafPrompt,
args: CombineDenoiseArgs,
index: int,
) -> torch.Tensor:
return torch.zeros_like(args.x_out[0])
def visit_composite_prompt(
self,
that: neutral_prompt_parser.CompositePrompt,
args: CombineDenoiseArgs,
index: int,
) -> torch.Tensor:
cond_delta = torch.zeros_like(args.x_out[0])
for child in that.children:
cond_delta += child.weight * child.accept(CondDeltaVisitor(), args, index)
index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
return cond_delta
@dataclasses.dataclass
class CondDeltaVisitor:
@ -143,8 +117,9 @@ class CondDeltaVisitor:
console_warn(f'''
An unexpected noise weight was encountered at prompt #{index}
Expected :{that.weight}, but got :{cond_info[1]}
This is likely due to another extension also monkey patching the webui noise blending function
Please open a github issue so that the conflict can be resolved
This is likely due to another extension also monkey patching the webui `combine_denoised` function
Please open a bug report here so that the conflict can be resolved:
https://github.com/ljleb/sd-webui-neutral-prompt/issues
''')
return args.x_out[cond_info[0]] - args.uncond
@ -157,17 +132,19 @@ class CondDeltaVisitor:
) -> torch.Tensor:
cond_delta = torch.zeros_like(args.x_out[0])
if that.conciliation is None:
for child in that.children:
child_cond_delta = child.accept(CondDeltaChildVisitor(), args, index)
child_cond_delta += child.accept(AuxCondDeltaChildVisitor(), args, child_cond_delta, index)
for child in that.children:
if child.conciliation is None:
child_cond_delta = child.accept(CondDeltaVisitor(), args, index)
child_cond_delta += child.accept(AuxCondDeltaVisitor(), args, child_cond_delta, index)
cond_delta += child.weight * child_cond_delta
index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
return cond_delta
@dataclasses.dataclass
class AuxCondDeltaChildVisitor:
class AuxCondDeltaVisitor:
def visit_leaf_prompt(
self,
that: neutral_prompt_parser.LeafPrompt,
@ -188,9 +165,10 @@ class AuxCondDeltaChildVisitor:
salient_cond_deltas = []
for child in that.children:
child_cond_delta = child.accept(CondDeltaChildVisitor(), args, index)
child_cond_delta += child.accept(self, args, child_cond_delta, index)
if isinstance(child, neutral_prompt_parser.CompositePrompt):
if child.conciliation is not None:
child_cond_delta = child.accept(CondDeltaVisitor(), args, index)
child_cond_delta += child.accept(AuxCondDeltaVisitor(), args, child_cond_delta, index)
if child.conciliation == neutral_prompt_parser.ConciliationStrategy.PERPENDICULAR:
aux_cond_delta += child.weight * get_perpendicular_component(cond_delta, child_cond_delta)
elif child.conciliation == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK:
@ -221,12 +199,12 @@ def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float]
The blended result combines `normal` and vector information in salient regions.
"""
salience_maps = [get_salience(normal)] + [get_salience(vector) for vector, weight in vectors]
salience_maps = [get_salience(normal)] + [get_salience(vector) for vector, _ in vectors]
mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
result = torch.zeros_like(normal)
for mask_i, (vector, weight) in enumerate(vectors, start=1):
vector_mask = ((mask == mask_i).float())
vector_mask = (mask == mask_i).float()
result += weight * vector_mask * (vector - normal)
return result

View File

@ -5,23 +5,6 @@ from enum import Enum
from typing import List, Tuple, Any, Optional
@dataclasses.dataclass
class PromptExpr(abc.ABC):
weight: float
@abc.abstractmethod
def accept(self, visitor, *args, **kwargs) -> Any:
pass
@dataclasses.dataclass
class LeafPrompt(PromptExpr):
prompt: str
def accept(self, visitor, *args, **kwargs):
return visitor.visit_leaf_prompt(self, *args, **kwargs)
class PromptKeyword(Enum):
AND = 'AND'
AND_PERP = 'AND_PERP'
@ -41,10 +24,27 @@ class ConciliationStrategy(Enum):
conciliation_strategies = [e.value for e in ConciliationStrategy]
@dataclasses.dataclass
class PromptExpr(abc.ABC):
weight: float
conciliation: Optional[ConciliationStrategy]
@abc.abstractmethod
def accept(self, visitor, *args, **kwargs) -> Any:
pass
@dataclasses.dataclass
class LeafPrompt(PromptExpr):
prompt: str
def accept(self, visitor, *args, **kwargs):
return visitor.visit_leaf_prompt(self, *args, **kwargs)
@dataclasses.dataclass
class CompositePrompt(PromptExpr):
children: List[PromptExpr]
conciliation: Optional[ConciliationStrategy]
def accept(self, visitor, *args, **kwargs):
return visitor.visit_composite_prompt(self, *args, **kwargs)
@ -61,57 +61,53 @@ class FlatSizeVisitor:
def parse_root(string: str) -> CompositePrompt:
tokens = tokenize(string)
prompts = parse_prompts(tokens)
return CompositePrompt(1., prompts, None)
return CompositePrompt(1., None, prompts)
def parse_prompts(tokens: List[str]) -> List[PromptExpr]:
prompts = [parse_prompt(tokens, first=True)]
def parse_prompts(tokens: List[str], *, nested: bool = False) -> List[PromptExpr]:
prompts = [parse_prompt(tokens, first=True, nested=nested)]
while tokens:
if tokens[0] in [']']:
if nested and tokens[0] in [']']:
break
prompts.append(parse_prompt(tokens, first=False))
prompts.append(parse_prompt(tokens, first=False, nested=nested))
return prompts
def parse_prompt(tokens: List[str], *, first: bool) -> PromptExpr:
if first:
prompt_type = PromptKeyword.AND.value
else:
assert tokens[0] in prompt_keywords
def parse_prompt(tokens: List[str], *, first: bool, nested: bool = False) -> PromptExpr:
if not first and tokens[0] in prompt_keywords:
prompt_type = tokens.pop(0)
else:
prompt_type = PromptKeyword.AND.value
tokens_copy = tokens.copy()
if tokens_copy and tokens_copy[0] == '[':
tokens_copy.pop(0)
prompts = parse_prompts(tokens_copy)
if tokens_copy:
assert tokens_copy.pop(0) == ']'
if not tokens_copy or tokens_copy[0] in prompt_keywords + [']']:
tokens[:] = tokens_copy
weight = parse_weight(tokens)
conciliation = ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None
return CompositePrompt(weight, prompts, conciliation)
tokens_copy = tokens.copy()
if tokens_copy and tokens_copy[0] == '[':
tokens_copy.pop(0)
prompts = parse_prompts(tokens_copy, nested=True)
if tokens_copy:
assert tokens_copy.pop(0) == ']'
if len(prompts) > 1:
tokens[:] = tokens_copy
weight = parse_weight(tokens)
conciliation = ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None
return CompositePrompt(weight, conciliation, prompts)
prompt_text, weight = parse_prompt_text(tokens)
prompt = LeafPrompt(weight, prompt_text)
if prompt_type in conciliation_strategies:
prompt.weight = 1.
prompt = CompositePrompt(weight, [prompt], ConciliationStrategy(prompt_type))
return prompt
prompt_text, weight = parse_prompt_text(tokens, nested=nested)
return LeafPrompt(weight, ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None, prompt_text)
def parse_prompt_text(tokens: List[str]) -> Tuple[str, float]:
def parse_prompt_text(tokens: List[str], *, nested: bool = False) -> Tuple[str, float]:
text = ''
depth = 0
weight = 1.
while tokens:
if tokens[0] == ']':
if depth == 0:
break
depth -= 1
if nested:
break
else:
depth -= 1
elif tokens[0] == '[':
depth += 1
elif tokens[0] == ':':
@ -130,12 +126,9 @@ def parse_prompt_text(tokens: List[str]) -> Tuple[str, float]:
def parse_weight(tokens: List[str]) -> float:
weight = 1.
if tokens and tokens[0] == ':':
if len(tokens) >= 2 and tokens[0] == ':' and is_float(tokens[1]):
tokens.pop(0)
if tokens:
weight_str = tokens.pop(0)
if is_float(weight_str):
weight = float(weight_str)
weight = float(tokens.pop(0))
return weight

View File

@ -9,10 +9,10 @@ class TestPromptParser(unittest.TestCase):
def setUp(self):
self.simple_prompt = neutral_prompt_parser.parse_root("hello :1.0")
self.and_prompt = neutral_prompt_parser.parse_root("hello AND goodbye :2.0")
self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0]")
self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0]")
self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP [welcome :3.0]]")
self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT [welcome :3.0]]")
self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP goodbye :2.0")
self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT goodbye :2.0")
self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]")
self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]")
self.invalid_weight = neutral_prompt_parser.parse_root("hello :not_a_float")
def test_simple_prompt_child_count(self):
@ -24,6 +24,15 @@ class TestPromptParser(unittest.TestCase):
def test_simple_prompt_child_prompt(self):
self.assertEqual(self.simple_prompt.children[0].prompt, "hello ")
def test_square_weight_prompt(self):
prompt = "a [b c d e : f g h :1.5]"
parsed = neutral_prompt_parser.parse_root(prompt)
self.assertEqual(parsed.children[0].prompt, prompt)
composed_prompt = f"{prompt} AND_PERP other prompt"
parsed = neutral_prompt_parser.parse_root(composed_prompt)
self.assertEqual(parsed.children[0].prompt, prompt)
def test_and_prompt_child_count(self):
self.assertEqual(len(self.and_prompt.children), 2)
@ -38,12 +47,12 @@ class TestPromptParser(unittest.TestCase):
def test_and_perp_prompt_child_types(self):
self.assertIsInstance(self.and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.CompositePrompt)
self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.LeafPrompt)
def test_and_perp_prompt_nested_child(self):
nested_child = self.and_perp_prompt.children[1].children[0]
nested_child = self.and_perp_prompt.children[1]
self.assertEqual(nested_child.weight, 2.0)
self.assertEqual(nested_child.prompt, "goodbye ")
self.assertEqual(nested_child.prompt.strip(), "goodbye")
def test_nested_and_perp_prompt_child_count(self):
self.assertEqual(len(self.nested_and_perp_prompt.children), 2)
@ -56,12 +65,12 @@ class TestPromptParser(unittest.TestCase):
nested_child = self.nested_and_perp_prompt.children[1].children[0]
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
nested_child = self.nested_and_perp_prompt.children[1].children[1]
self.assertIsInstance(nested_child, neutral_prompt_parser.CompositePrompt)
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
def test_nested_and_perp_prompt_nested_child(self):
nested_child = self.nested_and_perp_prompt.children[1].children[1].children[0]
nested_child = self.nested_and_perp_prompt.children[1].children[1]
self.assertEqual(nested_child.weight, 3.0)
self.assertEqual(nested_child.prompt, "welcome ")
self.assertEqual(nested_child.prompt.strip(), "welcome")
def test_invalid_weight_child_count(self):
self.assertEqual(len(self.invalid_weight.children), 1)
@ -77,12 +86,12 @@ class TestPromptParser(unittest.TestCase):
def test_and_salt_prompt_child_types(self):
self.assertIsInstance(self.and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.CompositePrompt)
self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.LeafPrompt)
def test_and_salt_prompt_nested_child(self):
nested_child = self.and_salt_prompt.children[1].children[0]
nested_child = self.and_salt_prompt.children[1]
self.assertEqual(nested_child.weight, 2.0)
self.assertEqual(nested_child.prompt, "goodbye ")
self.assertEqual(nested_child.prompt.strip(), "goodbye")
def test_nested_and_salt_prompt_child_count(self):
self.assertEqual(len(self.nested_and_salt_prompt.children), 2)
@ -95,12 +104,12 @@ class TestPromptParser(unittest.TestCase):
nested_child = self.nested_and_salt_prompt.children[1].children[0]
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
nested_child = self.nested_and_salt_prompt.children[1].children[1]
self.assertIsInstance(nested_child, neutral_prompt_parser.CompositePrompt)
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
def test_nested_and_salt_prompt_nested_child(self):
nested_child = self.nested_and_salt_prompt.children[1].children[1].children[0]
nested_child = self.nested_and_salt_prompt.children[1].children[1]
self.assertEqual(nested_child.weight, 3.0)
self.assertEqual(nested_child.prompt, "welcome ")
self.assertEqual(nested_child.prompt.strip(), "welcome")
def test_start_with_prompt_editing(self):
prompt = "[(long shot:1.2):0.1] detail.."

View File

@ -23,15 +23,42 @@ class TestMaliciousPromptParser(unittest.TestCase):
self.assertEqual(result.children[0].weight, 1.0)
self.assertEqual(result.children[1].weight, -2.0)
def test_debalanced_square_brackets(self):
prompt = "a [ b " * 100
result = self.parser.parse_root(prompt)
self.assertEqual(result.children[0].prompt, prompt)
prompt = "a ] b " * 100
result = self.parser.parse_root(prompt)
self.assertEqual(result.children[0].prompt, prompt)
repeats = 10
prompt = "a [ [ b AND c ] " * repeats
result = self.parser.parse_root(prompt)
self.assertEqual([x.prompt for x in result.children], ["a [[ b ", *[" c ] a [[ b "] * (repeats - 1), " c ]"])
repeats = 10
prompt = "a [ b AND c ] ] " * repeats
result = self.parser.parse_root(prompt)
self.assertEqual([x.prompt for x in result.children], ["a [ b ", *[" c ]] a [ b "] * (repeats - 1), " c ]]"])
def test_erroneous_syntax(self):
result = self.parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0")
self.assertEqual(result.children[0].weight, 1.0)
self.assertEqual(result.children[1].children[0].prompt, "goodbye ")
self.assertEqual(result.children[1].children[0].weight, 2.0)
self.assertEqual(result.children[1].prompt, "[goodbye ")
self.assertEqual(result.children[1].weight, 2.0)
result = self.parser.parse_root("hello :1.0 AND_PERP goodbye :2.0]")
self.assertEqual(result.children[0].weight, 1.0)
self.assertEqual(result.children[1].children[0].prompt, " goodbye ")
self.assertEqual(result.children[1].prompt, " goodbye ")
result = self.parser.parse_root("hello :1.0 AND_PERP goodbye] :2.0")
self.assertEqual(result.children[1].prompt, " goodbye]")
self.assertEqual(result.children[1].weight, 2.0)
result = self.parser.parse_root("hello :1.0 AND_PERP a [ goodbye :2.0")
self.assertEqual(result.children[1].weight, 2.0)
self.assertEqual(result.children[1].prompt, " a [ goodbye ")
result = self.parser.parse_root("hello :1.0 AND_PERP AND goodbye :2.0")
self.assertEqual(result.children[0].weight, 1.0)
@ -58,13 +85,13 @@ class TestMaliciousPromptParser(unittest.TestCase):
self.assertIsInstance(result.children[1], neutral_prompt_parser.CompositePrompt)
def test_complex_nested_prompts(self):
complex_prompt = "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP [greetings :5.0]]"
complex_prompt = "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP greetings:5.0]"
result = self.parser.parse_root(complex_prompt)
self.assertEqual(result.children[0].weight, 1.0)
self.assertEqual(result.children[1].weight, 2.0)
self.assertEqual(result.children[2].children[0].weight, 3.0)
self.assertEqual(result.children[2].children[1].weight, 4.0)
self.assertEqual(result.children[2].children[2].children[0].weight, 5.0)
self.assertEqual(result.children[2].children[2].weight, 5.0)
def test_string_with_random_characters(self):
random_chars = "ASDFGHJKL:@#$/.,|}{><~`12[3]456AND_PERP7890"