diff --git a/lib_neutral_prompt/cfg_denoiser_hijack.py b/lib_neutral_prompt/cfg_denoiser_hijack.py index 2dde023..696a060 100644 --- a/lib_neutral_prompt/cfg_denoiser_hijack.py +++ b/lib_neutral_prompt/cfg_denoiser_hijack.py @@ -23,10 +23,10 @@ def combine_denoised_hijack( for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)): args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices) - cond_delta = prompt.accept(CondDeltaChildVisitor(), args, 0) - aux_cond_delta = prompt.accept(AuxCondDeltaChildVisitor(), args, cond_delta, 0) + cond_delta = prompt.accept(CondDeltaVisitor(), args, 0) + aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0) cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale - denoised[batch_i] = cfg_cond * get_cfg_rescale_factor(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta) + denoised[batch_i] = cfg_rescale(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta) return denoised @@ -41,22 +41,27 @@ def get_webui_denoised( uncond = x_out[-text_uncond.shape[0]:] sliced_batch_x_out = [] sliced_batch_cond_indices = [] + index_in = 0 for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)): args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices) - sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, len(sliced_batch_x_out)) - sliced_batch_cond_indices.append(sliced_cond_indices) + sliced_x_out, sliced_cond_indices = prompt.accept(GatherWebuiCondsVisitor(), args, index_in, len(sliced_batch_x_out)) + if sliced_cond_indices: + sliced_batch_cond_indices.append(sliced_cond_indices) sliced_batch_x_out.extend(sliced_x_out) + index_in += prompt.accept(neutral_prompt_parser.FlatSizeVisitor()) sliced_batch_x_out += list(uncond) sliced_batch_x_out = torch.stack(sliced_batch_x_out, dim=0) - sliced_batch_cond_indices = [il for il in sliced_batch_cond_indices if il] return original_function(sliced_batch_x_out, sliced_batch_cond_indices, text_uncond, cond_scale) -def get_cfg_rescale_factor(cfg_cond, cond): +def cfg_rescale(cfg_cond, cond): global_state.apply_and_clear_cfg_rescale_override() - return global_state.cfg_rescale * (torch.std(cond) / torch.std(cfg_cond) - 1) + 1 + cfg_cond_mean = cfg_cond.mean() + cfg_resacle_mean = (1 - global_state.cfg_rescale) * cfg_cond_mean + global_state.cfg_rescale * cond.mean() + cfg_rescale_factor = global_state.cfg_rescale * (cond.std() / cfg_cond.std() - 1) + 1 + return cfg_resacle_mean + (cfg_cond - cfg_cond_mean) * cfg_rescale_factor @dataclasses.dataclass @@ -68,67 +73,36 @@ class CombineDenoiseArgs: @dataclasses.dataclass class GatherWebuiCondsVisitor: - def visit_leaf_prompt(self, *args, **kwargs) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]: - return [], [] + def visit_leaf_prompt( + self, + that: neutral_prompt_parser.CompositePrompt, + args: CombineDenoiseArgs, + index_in: int, + index_out: int, + ) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]: + return [args.x_out[args.cond_indices[index_in][0]]], [(index_out, args.cond_indices[index_in][1])] def visit_composite_prompt( self, that: neutral_prompt_parser.CompositePrompt, args: CombineDenoiseArgs, - index_offset: int, + index_in: int, + index_out: int, ) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]: sliced_x_out = [] sliced_cond_indices = [] - index_in = 0 for child in that.children: - index_out = index_offset + len(sliced_x_out) - child_x_out, child_cond_indices = child.accept(GatherWebuiCondsVisitor.SingleCondVisitor(), args.x_out, args.cond_indices[index_in], index_out) - sliced_x_out.extend(child_x_out) - sliced_cond_indices.extend(child_cond_indices) + if child.conciliation is None: + index_offset = index_out + len(sliced_x_out) + child_x_out, child_cond_indices = child.accept(GatherWebuiCondsVisitor(), args, index_in, index_offset) + sliced_x_out.extend(child_x_out) + sliced_cond_indices.extend(child_cond_indices) + index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor()) return sliced_x_out, sliced_cond_indices - @dataclasses.dataclass - class SingleCondVisitor: - def visit_leaf_prompt( - self, - that: neutral_prompt_parser.LeafPrompt, - x_out: torch.Tensor, - cond_info: Tuple[int, float], - index: int, - ) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]: - return [x_out[cond_info[0]]], [(index, cond_info[1])] - - def visit_composite_prompt(self, *args, **kwargs) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]: - return [], [] - - -@dataclasses.dataclass -class CondDeltaChildVisitor: - def visit_leaf_prompt( - self, - that: neutral_prompt_parser.LeafPrompt, - args: CombineDenoiseArgs, - index: int, - ) -> torch.Tensor: - return torch.zeros_like(args.x_out[0]) - - def visit_composite_prompt( - self, - that: neutral_prompt_parser.CompositePrompt, - args: CombineDenoiseArgs, - index: int, - ) -> torch.Tensor: - cond_delta = torch.zeros_like(args.x_out[0]) - - for child in that.children: - cond_delta += child.weight * child.accept(CondDeltaVisitor(), args, index) - index += child.accept(neutral_prompt_parser.FlatSizeVisitor()) - - return cond_delta - @dataclasses.dataclass class CondDeltaVisitor: @@ -143,8 +117,9 @@ class CondDeltaVisitor: console_warn(f''' An unexpected noise weight was encountered at prompt #{index} Expected :{that.weight}, but got :{cond_info[1]} - This is likely due to another extension also monkey patching the webui noise blending function - Please open a github issue so that the conflict can be resolved + This is likely due to another extension also monkey patching the webui `combine_denoised` function + Please open a bug report here so that the conflict can be resolved: + https://github.com/ljleb/sd-webui-neutral-prompt/issues ''') return args.x_out[cond_info[0]] - args.uncond @@ -157,17 +132,19 @@ class CondDeltaVisitor: ) -> torch.Tensor: cond_delta = torch.zeros_like(args.x_out[0]) - if that.conciliation is None: - for child in that.children: - child_cond_delta = child.accept(CondDeltaChildVisitor(), args, index) - child_cond_delta += child.accept(AuxCondDeltaChildVisitor(), args, child_cond_delta, index) + for child in that.children: + if child.conciliation is None: + child_cond_delta = child.accept(CondDeltaVisitor(), args, index) + child_cond_delta += child.accept(AuxCondDeltaVisitor(), args, child_cond_delta, index) cond_delta += child.weight * child_cond_delta + index += child.accept(neutral_prompt_parser.FlatSizeVisitor()) + return cond_delta @dataclasses.dataclass -class AuxCondDeltaChildVisitor: +class AuxCondDeltaVisitor: def visit_leaf_prompt( self, that: neutral_prompt_parser.LeafPrompt, @@ -188,9 +165,10 @@ class AuxCondDeltaChildVisitor: salient_cond_deltas = [] for child in that.children: - child_cond_delta = child.accept(CondDeltaChildVisitor(), args, index) - child_cond_delta += child.accept(self, args, child_cond_delta, index) - if isinstance(child, neutral_prompt_parser.CompositePrompt): + if child.conciliation is not None: + child_cond_delta = child.accept(CondDeltaVisitor(), args, index) + child_cond_delta += child.accept(AuxCondDeltaVisitor(), args, child_cond_delta, index) + if child.conciliation == neutral_prompt_parser.ConciliationStrategy.PERPENDICULAR: aux_cond_delta += child.weight * get_perpendicular_component(cond_delta, child_cond_delta) elif child.conciliation == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK: @@ -221,12 +199,12 @@ def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float] The blended result combines `normal` and vector information in salient regions. """ - salience_maps = [get_salience(normal)] + [get_salience(vector) for vector, weight in vectors] + salience_maps = [get_salience(normal)] + [get_salience(vector) for vector, _ in vectors] mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0) result = torch.zeros_like(normal) for mask_i, (vector, weight) in enumerate(vectors, start=1): - vector_mask = ((mask == mask_i).float()) + vector_mask = (mask == mask_i).float() result += weight * vector_mask * (vector - normal) return result diff --git a/lib_neutral_prompt/neutral_prompt_parser.py b/lib_neutral_prompt/neutral_prompt_parser.py index ccd5612..60d997e 100644 --- a/lib_neutral_prompt/neutral_prompt_parser.py +++ b/lib_neutral_prompt/neutral_prompt_parser.py @@ -5,23 +5,6 @@ from enum import Enum from typing import List, Tuple, Any, Optional -@dataclasses.dataclass -class PromptExpr(abc.ABC): - weight: float - - @abc.abstractmethod - def accept(self, visitor, *args, **kwargs) -> Any: - pass - - -@dataclasses.dataclass -class LeafPrompt(PromptExpr): - prompt: str - - def accept(self, visitor, *args, **kwargs): - return visitor.visit_leaf_prompt(self, *args, **kwargs) - - class PromptKeyword(Enum): AND = 'AND' AND_PERP = 'AND_PERP' @@ -41,10 +24,27 @@ class ConciliationStrategy(Enum): conciliation_strategies = [e.value for e in ConciliationStrategy] +@dataclasses.dataclass +class PromptExpr(abc.ABC): + weight: float + conciliation: Optional[ConciliationStrategy] + + @abc.abstractmethod + def accept(self, visitor, *args, **kwargs) -> Any: + pass + + +@dataclasses.dataclass +class LeafPrompt(PromptExpr): + prompt: str + + def accept(self, visitor, *args, **kwargs): + return visitor.visit_leaf_prompt(self, *args, **kwargs) + + @dataclasses.dataclass class CompositePrompt(PromptExpr): children: List[PromptExpr] - conciliation: Optional[ConciliationStrategy] def accept(self, visitor, *args, **kwargs): return visitor.visit_composite_prompt(self, *args, **kwargs) @@ -61,57 +61,53 @@ class FlatSizeVisitor: def parse_root(string: str) -> CompositePrompt: tokens = tokenize(string) prompts = parse_prompts(tokens) - return CompositePrompt(1., prompts, None) + return CompositePrompt(1., None, prompts) -def parse_prompts(tokens: List[str]) -> List[PromptExpr]: - prompts = [parse_prompt(tokens, first=True)] +def parse_prompts(tokens: List[str], *, nested: bool = False) -> List[PromptExpr]: + prompts = [parse_prompt(tokens, first=True, nested=nested)] while tokens: - if tokens[0] in [']']: + if nested and tokens[0] in [']']: break - prompts.append(parse_prompt(tokens, first=False)) + prompts.append(parse_prompt(tokens, first=False, nested=nested)) return prompts -def parse_prompt(tokens: List[str], *, first: bool) -> PromptExpr: - if first: - prompt_type = PromptKeyword.AND.value - else: - assert tokens[0] in prompt_keywords +def parse_prompt(tokens: List[str], *, first: bool, nested: bool = False) -> PromptExpr: + if not first and tokens[0] in prompt_keywords: prompt_type = tokens.pop(0) + else: + prompt_type = PromptKeyword.AND.value - tokens_copy = tokens.copy() - if tokens_copy and tokens_copy[0] == '[': - tokens_copy.pop(0) - prompts = parse_prompts(tokens_copy) - if tokens_copy: - assert tokens_copy.pop(0) == ']' - if not tokens_copy or tokens_copy[0] in prompt_keywords + [']']: - tokens[:] = tokens_copy - weight = parse_weight(tokens) - conciliation = ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None - return CompositePrompt(weight, prompts, conciliation) + tokens_copy = tokens.copy() + if tokens_copy and tokens_copy[0] == '[': + tokens_copy.pop(0) + prompts = parse_prompts(tokens_copy, nested=True) + if tokens_copy: + assert tokens_copy.pop(0) == ']' + if len(prompts) > 1: + tokens[:] = tokens_copy + weight = parse_weight(tokens) + conciliation = ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None + return CompositePrompt(weight, conciliation, prompts) - prompt_text, weight = parse_prompt_text(tokens) - prompt = LeafPrompt(weight, prompt_text) - if prompt_type in conciliation_strategies: - prompt.weight = 1. - prompt = CompositePrompt(weight, [prompt], ConciliationStrategy(prompt_type)) - - return prompt + prompt_text, weight = parse_prompt_text(tokens, nested=nested) + return LeafPrompt(weight, ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None, prompt_text) -def parse_prompt_text(tokens: List[str]) -> Tuple[str, float]: +def parse_prompt_text(tokens: List[str], *, nested: bool = False) -> Tuple[str, float]: text = '' depth = 0 weight = 1. while tokens: if tokens[0] == ']': if depth == 0: - break - depth -= 1 + if nested: + break + else: + depth -= 1 elif tokens[0] == '[': depth += 1 elif tokens[0] == ':': @@ -130,12 +126,9 @@ def parse_prompt_text(tokens: List[str]) -> Tuple[str, float]: def parse_weight(tokens: List[str]) -> float: weight = 1. - if tokens and tokens[0] == ':': + if len(tokens) >= 2 and tokens[0] == ':' and is_float(tokens[1]): tokens.pop(0) - if tokens: - weight_str = tokens.pop(0) - if is_float(weight_str): - weight = float(weight_str) + weight = float(tokens.pop(0)) return weight diff --git a/test/perp_parser/basic_test.py b/test/perp_parser/basic_test.py index 6b4d974..7d059b5 100644 --- a/test/perp_parser/basic_test.py +++ b/test/perp_parser/basic_test.py @@ -9,10 +9,10 @@ class TestPromptParser(unittest.TestCase): def setUp(self): self.simple_prompt = neutral_prompt_parser.parse_root("hello :1.0") self.and_prompt = neutral_prompt_parser.parse_root("hello AND goodbye :2.0") - self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0]") - self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0]") - self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP [welcome :3.0]]") - self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT [welcome :3.0]]") + self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP goodbye :2.0") + self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT goodbye :2.0") + self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]") + self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]") self.invalid_weight = neutral_prompt_parser.parse_root("hello :not_a_float") def test_simple_prompt_child_count(self): @@ -24,6 +24,15 @@ class TestPromptParser(unittest.TestCase): def test_simple_prompt_child_prompt(self): self.assertEqual(self.simple_prompt.children[0].prompt, "hello ") + def test_square_weight_prompt(self): + prompt = "a [b c d e : f g h :1.5]" + parsed = neutral_prompt_parser.parse_root(prompt) + self.assertEqual(parsed.children[0].prompt, prompt) + + composed_prompt = f"{prompt} AND_PERP other prompt" + parsed = neutral_prompt_parser.parse_root(composed_prompt) + self.assertEqual(parsed.children[0].prompt, prompt) + def test_and_prompt_child_count(self): self.assertEqual(len(self.and_prompt.children), 2) @@ -38,12 +47,12 @@ class TestPromptParser(unittest.TestCase): def test_and_perp_prompt_child_types(self): self.assertIsInstance(self.and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt) - self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.CompositePrompt) + self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.LeafPrompt) def test_and_perp_prompt_nested_child(self): - nested_child = self.and_perp_prompt.children[1].children[0] + nested_child = self.and_perp_prompt.children[1] self.assertEqual(nested_child.weight, 2.0) - self.assertEqual(nested_child.prompt, "goodbye ") + self.assertEqual(nested_child.prompt.strip(), "goodbye") def test_nested_and_perp_prompt_child_count(self): self.assertEqual(len(self.nested_and_perp_prompt.children), 2) @@ -56,12 +65,12 @@ class TestPromptParser(unittest.TestCase): nested_child = self.nested_and_perp_prompt.children[1].children[0] self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt) nested_child = self.nested_and_perp_prompt.children[1].children[1] - self.assertIsInstance(nested_child, neutral_prompt_parser.CompositePrompt) + self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt) def test_nested_and_perp_prompt_nested_child(self): - nested_child = self.nested_and_perp_prompt.children[1].children[1].children[0] + nested_child = self.nested_and_perp_prompt.children[1].children[1] self.assertEqual(nested_child.weight, 3.0) - self.assertEqual(nested_child.prompt, "welcome ") + self.assertEqual(nested_child.prompt.strip(), "welcome") def test_invalid_weight_child_count(self): self.assertEqual(len(self.invalid_weight.children), 1) @@ -77,12 +86,12 @@ class TestPromptParser(unittest.TestCase): def test_and_salt_prompt_child_types(self): self.assertIsInstance(self.and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt) - self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.CompositePrompt) + self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.LeafPrompt) def test_and_salt_prompt_nested_child(self): - nested_child = self.and_salt_prompt.children[1].children[0] + nested_child = self.and_salt_prompt.children[1] self.assertEqual(nested_child.weight, 2.0) - self.assertEqual(nested_child.prompt, "goodbye ") + self.assertEqual(nested_child.prompt.strip(), "goodbye") def test_nested_and_salt_prompt_child_count(self): self.assertEqual(len(self.nested_and_salt_prompt.children), 2) @@ -95,12 +104,12 @@ class TestPromptParser(unittest.TestCase): nested_child = self.nested_and_salt_prompt.children[1].children[0] self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt) nested_child = self.nested_and_salt_prompt.children[1].children[1] - self.assertIsInstance(nested_child, neutral_prompt_parser.CompositePrompt) + self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt) def test_nested_and_salt_prompt_nested_child(self): - nested_child = self.nested_and_salt_prompt.children[1].children[1].children[0] + nested_child = self.nested_and_salt_prompt.children[1].children[1] self.assertEqual(nested_child.weight, 3.0) - self.assertEqual(nested_child.prompt, "welcome ") + self.assertEqual(nested_child.prompt.strip(), "welcome") def test_start_with_prompt_editing(self): prompt = "[(long shot:1.2):0.1] detail.." diff --git a/test/perp_parser/malicious_test.py b/test/perp_parser/malicious_test.py index 82f359c..c6a43f5 100644 --- a/test/perp_parser/malicious_test.py +++ b/test/perp_parser/malicious_test.py @@ -23,15 +23,42 @@ class TestMaliciousPromptParser(unittest.TestCase): self.assertEqual(result.children[0].weight, 1.0) self.assertEqual(result.children[1].weight, -2.0) + def test_debalanced_square_brackets(self): + prompt = "a [ b " * 100 + result = self.parser.parse_root(prompt) + self.assertEqual(result.children[0].prompt, prompt) + + prompt = "a ] b " * 100 + result = self.parser.parse_root(prompt) + self.assertEqual(result.children[0].prompt, prompt) + + repeats = 10 + prompt = "a [ [ b AND c ] " * repeats + result = self.parser.parse_root(prompt) + self.assertEqual([x.prompt for x in result.children], ["a [[ b ", *[" c ] a [[ b "] * (repeats - 1), " c ]"]) + + repeats = 10 + prompt = "a [ b AND c ] ] " * repeats + result = self.parser.parse_root(prompt) + self.assertEqual([x.prompt for x in result.children], ["a [ b ", *[" c ]] a [ b "] * (repeats - 1), " c ]]"]) + def test_erroneous_syntax(self): result = self.parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0") self.assertEqual(result.children[0].weight, 1.0) - self.assertEqual(result.children[1].children[0].prompt, "goodbye ") - self.assertEqual(result.children[1].children[0].weight, 2.0) + self.assertEqual(result.children[1].prompt, "[goodbye ") + self.assertEqual(result.children[1].weight, 2.0) result = self.parser.parse_root("hello :1.0 AND_PERP goodbye :2.0]") self.assertEqual(result.children[0].weight, 1.0) - self.assertEqual(result.children[1].children[0].prompt, " goodbye ") + self.assertEqual(result.children[1].prompt, " goodbye ") + + result = self.parser.parse_root("hello :1.0 AND_PERP goodbye] :2.0") + self.assertEqual(result.children[1].prompt, " goodbye]") + self.assertEqual(result.children[1].weight, 2.0) + + result = self.parser.parse_root("hello :1.0 AND_PERP a [ goodbye :2.0") + self.assertEqual(result.children[1].weight, 2.0) + self.assertEqual(result.children[1].prompt, " a [ goodbye ") result = self.parser.parse_root("hello :1.0 AND_PERP AND goodbye :2.0") self.assertEqual(result.children[0].weight, 1.0) @@ -58,13 +85,13 @@ class TestMaliciousPromptParser(unittest.TestCase): self.assertIsInstance(result.children[1], neutral_prompt_parser.CompositePrompt) def test_complex_nested_prompts(self): - complex_prompt = "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP [greetings :5.0]]" + complex_prompt = "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP greetings:5.0]" result = self.parser.parse_root(complex_prompt) self.assertEqual(result.children[0].weight, 1.0) self.assertEqual(result.children[1].weight, 2.0) self.assertEqual(result.children[2].children[0].weight, 3.0) self.assertEqual(result.children[2].children[1].weight, 4.0) - self.assertEqual(result.children[2].children[2].children[0].weight, 5.0) + self.assertEqual(result.children[2].children[2].weight, 5.0) def test_string_with_random_characters(self): random_chars = "ASDFGHJKL:@#$/.,|}{><~`12[3]456AND_PERP7890"