parent
0803e74ee2
commit
ab8ec9e37b
|
|
@ -0,0 +1,53 @@
|
||||||
|
[tool.ruff]
|
||||||
|
# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
|
||||||
|
select = ["E", "F", "W", "C90", "I", "D", "N", "D201", "ASYNC", "PL", "EM", "EXE", "FBT", "ICN", "INP", "ISC", "PGH", "PIE", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP"]
|
||||||
|
ignore = ["D101", "D102", "D103", "D105", "D106", "D107", "D203", "D213", "SIM300"]
|
||||||
|
|
||||||
|
# Allow autofix for all enabled rules (when `--fix`) is provided.
|
||||||
|
fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
|
||||||
|
unfixable = ["F401", "F841"]
|
||||||
|
|
||||||
|
# Exclude a variety of commonly ignored directories.
|
||||||
|
exclude = [
|
||||||
|
".bzr",
|
||||||
|
".direnv",
|
||||||
|
".eggs",
|
||||||
|
".git",
|
||||||
|
".git-rewrite",
|
||||||
|
".hg",
|
||||||
|
".mypy_cache",
|
||||||
|
".nox",
|
||||||
|
".pants.d",
|
||||||
|
".pytype",
|
||||||
|
".ruff_cache",
|
||||||
|
".svn",
|
||||||
|
".tox",
|
||||||
|
".venv",
|
||||||
|
"__pypackages__",
|
||||||
|
"_build",
|
||||||
|
"buck-out",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"node_modules",
|
||||||
|
"venv",
|
||||||
|
]
|
||||||
|
per-file-ignores = {}
|
||||||
|
|
||||||
|
# Same as Black.
|
||||||
|
line-length = 88
|
||||||
|
|
||||||
|
# Allow unused variables when underscore-prefixed.
|
||||||
|
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
||||||
|
|
||||||
|
# Assume Python 3.10.
|
||||||
|
target-version = "py310"
|
||||||
|
|
||||||
|
[tool.ruff.mccabe]
|
||||||
|
# Unlike Flake8, default to a complexity level of 10.
|
||||||
|
max-complexity = 10
|
||||||
|
|
||||||
|
[tool.ruff.isort]
|
||||||
|
lines-after-imports = 2
|
||||||
|
|
||||||
|
[rools.ruff.pydocstyle]
|
||||||
|
convention = "pep257"
|
||||||
|
|
@ -1,51 +1,53 @@
|
||||||
import gradio as gr
|
|
||||||
import regex as re
|
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
|
||||||
from modules import script_callbacks, shared
|
import gradio as gr
|
||||||
import modules.scripts as scripts
|
import regex as re
|
||||||
|
from modules import script_callbacks, scripts, shared
|
||||||
|
|
||||||
'''
|
|
||||||
|
"""
|
||||||
Formatting settings
|
Formatting settings
|
||||||
'''
|
"""
|
||||||
SPACE_COMMAS = True
|
SPACE_COMMAS = True
|
||||||
BRACKET2WEIGHT = True
|
BRACKET2WEIGHT = True
|
||||||
SPACE2UNDERSCORE = False
|
SPACE2UNDERSCORE = False
|
||||||
|
|
||||||
'''
|
"""
|
||||||
Regex stuff
|
Regex stuff
|
||||||
'''
|
"""
|
||||||
brackets_opening = '([{<'
|
brackets_opening = "([{<"
|
||||||
brackets_closing = ')]}>'
|
brackets_closing = ")]}>"
|
||||||
|
|
||||||
re_tokenize = re.compile(r'\s*,\s*')
|
re_tokenize = re.compile(r"\s*,\s*")
|
||||||
re_comma_spacing = re.compile(r',+')
|
re_comma_spacing = re.compile(r",+")
|
||||||
re_brackets_fix_whitespace = re.compile(r'([\(\[{<])\s*|\s*([\)\]}>}])')
|
re_brackets_fix_whitespace = re.compile(r"([\(\[{<])\s*|\s*([\)\]}>}])")
|
||||||
re_opposing_brackets = re.compile(r'([)\]}>])([([{<])')
|
re_opposing_brackets = re.compile(r"([)\]}>])([([{<])")
|
||||||
re_networks = re.compile(r'<.+?>')
|
re_networks = re.compile(r"<.+?>")
|
||||||
re_bracket_open = re.compile(r'[(\[]')
|
re_bracket_open = re.compile(r"[(\[]")
|
||||||
re_brackets_open = re.compile(r'\(+|\[+')
|
re_brackets_open = re.compile(r"\(+|\[+")
|
||||||
re_brackets_closing = re.compile(r'\)+|\]+')
|
re_brackets_closing = re.compile(r"\)+|\]+")
|
||||||
re_colon_spacing = re.compile(r'\s*(:)\s*')
|
re_colon_spacing = re.compile(r"\s*(:)\s*")
|
||||||
re_colon_spacing_composite = re.compile(r'\s*(:)\s*(?=\d*?\.?\d*?\s*?AND)')
|
re_colon_spacing_composite = re.compile(r"\s*(:)\s*(?=\d*?\.?\d*?\s*?AND)")
|
||||||
re_colon_spacing_comp_end = re.compile(r'(?<=AND[^:]*?)(:)(?=[^:]*$)')
|
re_colon_spacing_comp_end = re.compile(r"(?<=AND[^:]*?)(:)(?=[^:]*$)")
|
||||||
re_paren_weights_exist = re.compile(r'\(.*(?<!:):\d.?\d*\)+')
|
re_paren_weights_exist = re.compile(r"\(.*(?<!:):\d.?\d*\)+")
|
||||||
re_is_prompt_editing = re.compile(r'\[.*:.*\]')
|
re_is_prompt_editing = re.compile(r"\[.*:.*\]")
|
||||||
re_is_prompt_alternating = re.compile(r'\[.*|.*\]')
|
re_is_prompt_alternating = re.compile(r"\[.*|.*\]")
|
||||||
re_is_wildcard = re.compile(r'{.*}')
|
re_is_wildcard = re.compile(r"{.*}")
|
||||||
re_AND = re.compile(r'(.*?)\s*(AND)\s*(.*?)')
|
re_and = re.compile(r"(.*?)\s*(AND)\s*(.*?)")
|
||||||
re_pipe = re.compile(r'\s*(\|)\s*')
|
re_pipe = re.compile(r"\s*(\|)\s*")
|
||||||
re_existing_weight = re.compile(r'(?<=:)(\d+.?\d*|\d*.?\d+)(?=[)\]]$)')
|
re_existing_weight = re.compile(r"(?<=:)(\d+.?\d*|\d*.?\d+)(?=[)\]]$)")
|
||||||
|
|
||||||
'''
|
"""
|
||||||
References
|
References
|
||||||
'''
|
"""
|
||||||
ui_prompts = []
|
ui_prompts = []
|
||||||
|
|
||||||
|
|
||||||
'''
|
"""
|
||||||
Functions
|
Functions
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_bracket_closing(c: str):
|
def get_bracket_closing(c: str):
|
||||||
return brackets_closing[brackets_opening.find(c)]
|
return brackets_closing[brackets_opening.find(c)]
|
||||||
|
|
||||||
|
|
@ -63,7 +65,7 @@ def tokenize(data: str) -> list:
|
||||||
|
|
||||||
|
|
||||||
def remove_whitespace_excessive(prompt: str):
|
def remove_whitespace_excessive(prompt: str):
|
||||||
return ' '.join(prompt.split())
|
return " ".join(prompt.split())
|
||||||
|
|
||||||
|
|
||||||
def align_brackets(prompt: str):
|
def align_brackets(prompt: str):
|
||||||
|
|
@ -73,37 +75,36 @@ def align_brackets(prompt: str):
|
||||||
return re_brackets_fix_whitespace.sub(helper, prompt)
|
return re_brackets_fix_whitespace.sub(helper, prompt)
|
||||||
|
|
||||||
|
|
||||||
def space_AND(prompt: str):
|
def space_and(prompt: str):
|
||||||
def helper(match: re.Match):
|
def helper(match: re.Match):
|
||||||
return ' '.join(match.groups())
|
return " ".join(match.groups())
|
||||||
|
|
||||||
return re_AND.sub(helper, prompt)
|
return re_and.sub(helper, prompt)
|
||||||
|
|
||||||
|
|
||||||
def align_colons(prompt: str):
|
def align_colons(prompt: str):
|
||||||
def normalize(match: re.Match):
|
def normalize(match: re.Match):
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
|
|
||||||
def composite(match: re.Match):
|
def composite(match: re.Match):
|
||||||
return ' ' + match.group(1)
|
return " " + match.group(1)
|
||||||
|
|
||||||
def composite_end(match: re.Match):
|
def composite_end(match: re.Match):
|
||||||
return ' ' + match.group(1)
|
return " " + match.group(1)
|
||||||
|
|
||||||
ret = re_colon_spacing.sub(normalize, prompt)
|
ret = re_colon_spacing.sub(normalize, prompt)
|
||||||
ret = re_colon_spacing_composite.sub(composite, ret)
|
ret = re_colon_spacing_composite.sub(composite, ret)
|
||||||
ret = re_colon_spacing_comp_end.sub(composite_end, ret)
|
return re_colon_spacing_comp_end.sub(composite_end, ret)
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def align_commas(prompt: str):
|
def align_commas(prompt: str):
|
||||||
if not SPACE_COMMAS:
|
if not SPACE_COMMAS:
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
split = re_comma_spacing.split(prompt)
|
split = re_comma_spacing.split(prompt)
|
||||||
split = map(str.strip, split)
|
split = map(str.strip, split)
|
||||||
split = filter(None, split)
|
split = filter(None, split)
|
||||||
return ', '.join(split)
|
return ", ".join(split)
|
||||||
|
|
||||||
|
|
||||||
def extract_networks(tokens: list):
|
def extract_networks(tokens: list):
|
||||||
|
|
@ -111,14 +112,14 @@ def extract_networks(tokens: list):
|
||||||
|
|
||||||
|
|
||||||
def remove_networks(tokens: list):
|
def remove_networks(tokens: list):
|
||||||
return list(filter(lambda token : not re_networks.match(token), tokens))
|
return list(filter(lambda token: not re_networks.match(token), tokens))
|
||||||
|
|
||||||
|
|
||||||
def remove_mismatched_brackets(prompt: str):
|
def remove_mismatched_brackets(prompt: str):
|
||||||
stack = []
|
stack = []
|
||||||
pos = []
|
pos = []
|
||||||
ret = ''
|
ret = ""
|
||||||
|
|
||||||
for i, c in enumerate(prompt):
|
for i, c in enumerate(prompt):
|
||||||
if c in brackets_opening:
|
if c in brackets_opening:
|
||||||
stack.append(c)
|
stack.append(c)
|
||||||
|
|
@ -133,25 +134,25 @@ def remove_mismatched_brackets(prompt: str):
|
||||||
ret += c
|
ret += c
|
||||||
else:
|
else:
|
||||||
ret += c
|
ret += c
|
||||||
|
|
||||||
while stack:
|
while stack:
|
||||||
bracket = stack.pop()
|
bracket = stack.pop()
|
||||||
p = pos.pop()
|
p = pos.pop()
|
||||||
ret = ret[:p] + ret[p+1:]
|
ret = ret[:p] + ret[p + 1 :]
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def space_bracekts(prompt: str):
|
def space_bracekts(prompt: str):
|
||||||
def helper(match : re.Match):
|
def helper(match: re.Match):
|
||||||
# print(' '.join(match.groups()))
|
# print(' '.join(match.groups()))
|
||||||
return ' '.join(match.groups())
|
return " ".join(match.groups())
|
||||||
|
|
||||||
# print(prompt)
|
# print(prompt)
|
||||||
return re_opposing_brackets.sub(helper, prompt)
|
return re_opposing_brackets.sub(helper, prompt)
|
||||||
|
|
||||||
|
|
||||||
def align_alternating(prompt:str):
|
def align_alternating(prompt: str):
|
||||||
def helper(match: re.Match):
|
def helper(match: re.Match):
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
|
|
||||||
|
|
@ -159,8 +160,9 @@ def align_alternating(prompt:str):
|
||||||
|
|
||||||
|
|
||||||
def bracket_to_weights(prompt: str):
|
def bracket_to_weights(prompt: str):
|
||||||
"""
|
"""Convert excessive brackets to weight.
|
||||||
when scanning, we need a way to ignore prompt editing, composable, and alternating
|
|
||||||
|
When scanning, we need a way to ignore prompt editing, composable, and alternating
|
||||||
we still need to weigh their individual words within them, however...
|
we still need to weigh their individual words within them, however...
|
||||||
|
|
||||||
use a depth counter to ensure that we find closing brackets
|
use a depth counter to ensure that we find closing brackets
|
||||||
|
|
@ -171,7 +173,7 @@ def bracket_to_weights(prompt: str):
|
||||||
the length of the parts we're working on... however, if we do this, then we can't
|
the length of the parts we're working on... however, if we do this, then we can't
|
||||||
remove consecutive brackets of the same type, we we would need to remove bracketing
|
remove consecutive brackets of the same type, we we would need to remove bracketing
|
||||||
to the left of the part of the string we're working on.
|
to the left of the part of the string we're working on.
|
||||||
|
|
||||||
well, i think we should be fine with a while pos != end of string, and if we find
|
well, i think we should be fine with a while pos != end of string, and if we find
|
||||||
a weight to add, break from the enumerate loop and resume at position to re-enumerate
|
a weight to add, break from the enumerate loop and resume at position to re-enumerate
|
||||||
the new string
|
the new string
|
||||||
|
|
@ -194,26 +196,26 @@ def bracket_to_weights(prompt: str):
|
||||||
remove excessive bracket
|
remove excessive bracket
|
||||||
convert bracket to ()
|
convert bracket to ()
|
||||||
|
|
||||||
IF BRACKETS ARE CONSECUTIVE, AND AFTER THEIR SLOPE, BOTH THEIR
|
IF BRACKETS ARE CONSECUTIVE, AND AFTER THEIR SLOPE, BOTH THEIR
|
||||||
INNER-NEXT DEPTH ARE THE SAME, IT IS A WEIGHT.
|
INNER-NEXT DEPTH ARE THE SAME, IT IS A WEIGHT.
|
||||||
|
|
||||||
Example using map_depth.
|
Example using map_depth.
|
||||||
c, ((a, b))
|
c, ((a, b))
|
||||||
(( ))
|
(( ))
|
||||||
00012222210
|
00012222210
|
||||||
---^^----vv
|
---^^----vv
|
||||||
2 ____ 2
|
2 ____ 2
|
||||||
1 /===>\ 1
|
1 /===>\\ 1
|
||||||
0___/=====>\0
|
0___/=====>\0
|
||||||
Because 01 can meet on the other side, these are matching
|
Because 01 can meet on the other side, these are matching
|
||||||
|
|
||||||
c, (a, (b))
|
c, (a, (b))
|
||||||
( ( ))
|
( ( ))
|
||||||
00011112210
|
00011112210
|
||||||
---^---^-vv
|
---^---^-vv
|
||||||
2 _ 2
|
2 _ 2
|
||||||
1 ___/>\ 1
|
1 ___/>\\ 1
|
||||||
0___/=====>\0
|
0___/=====>\0
|
||||||
0 and 1 match, but since gradients are not exactly mirrored,
|
0 and 1 match, but since gradients are not exactly mirrored,
|
||||||
thier weights should not be combined.
|
thier weights should not be combined.
|
||||||
|
|
||||||
|
|
@ -221,28 +223,28 @@ def bracket_to_weights(prompt: str):
|
||||||
(( ) )
|
(( ) )
|
||||||
00012211110
|
00012211110
|
||||||
---^^-v---v
|
---^^-v---v
|
||||||
2 _ 2
|
2 _ 2
|
||||||
1 /=\___ 1
|
1 /=\\___ 1
|
||||||
0___/=====>\0
|
0___/=====>\0
|
||||||
Similar idea to above example.
|
Similar idea to above example.
|
||||||
|
|
||||||
c, ((a), ((b)))
|
c, ((a), ((b)))
|
||||||
(( ) (( )))
|
(( ) (( )))
|
||||||
000122111233210
|
000122111233210
|
||||||
---^^-v--^^-vvv
|
---^^-v--^^-vvv
|
||||||
3 _ 3
|
3 _ 3
|
||||||
2 _ />\ 2
|
2 _ />\\ 2
|
||||||
1 />\__/==>\ 1
|
1 />\\__/==>\\ 1
|
||||||
0___/=========>\0
|
0___/=========>\0
|
||||||
Tricky one. Here, 01 open together, so there's a potential that their
|
Tricky one. Here, 01 open together, so there's a potential that their
|
||||||
weights should be combined if they close together, but instead 1 closes
|
weights should be combined if they close together, but instead 1 closes
|
||||||
early. We only need to check for closure initial checking depth - 1.
|
early. We only need to check for closure initial checking depth - 1.
|
||||||
|
|
||||||
"""
|
""" # noqa: D301
|
||||||
if not BRACKET2WEIGHT:
|
if not BRACKET2WEIGHT:
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
re_existing_weight = re.compile(r'(:\d+.?\d*)[)\]]$')
|
re_existing_weight = re.compile(r"(:\d+.?\d*)[)\]]$")
|
||||||
depths, gradients, brackets = get_mappings(prompt)
|
depths, gradients, brackets = get_mappings(prompt)
|
||||||
|
|
||||||
pos = 0
|
pos = 0
|
||||||
|
|
@ -251,67 +253,89 @@ def bracket_to_weights(prompt: str):
|
||||||
|
|
||||||
while pos < len(ret):
|
while pos < len(ret):
|
||||||
current_position = ret[pos:]
|
current_position = ret[pos:]
|
||||||
if ret[pos] in '([':
|
if ret[pos] in "([":
|
||||||
open_bracketing = re_brackets_open.match(ret, pos)
|
open_bracketing = re_brackets_open.match(ret, pos)
|
||||||
consecutive = len(open_bracketing.group(0))
|
consecutive = len(open_bracketing.group(0))
|
||||||
gradient_search = ''.join(map(str, reversed(range(int(depths[pos])-1, int(depths[pos])+consecutive))))
|
gradient_search = "".join(
|
||||||
is_square_brackets = True if '[' in open_bracketing.group(0) else False
|
map(
|
||||||
|
str,
|
||||||
|
reversed(
|
||||||
|
range(int(depths[pos]) - 1, int(depths[pos]) + consecutive)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
is_square_brackets = "[" in open_bracketing.group(0)
|
||||||
|
|
||||||
insert_at, weight, valid_consecutive = get_weight(
|
insert_at, weight, valid_consecutive = get_weight(
|
||||||
ret,
|
ret,
|
||||||
gradients,
|
gradients,
|
||||||
depths,
|
depths,
|
||||||
brackets,
|
brackets,
|
||||||
open_bracketing.end(),
|
open_bracketing.end(),
|
||||||
consecutive,
|
consecutive,
|
||||||
gradient_search,
|
gradient_search,
|
||||||
is_square_brackets)
|
is_square_brackets,
|
||||||
|
)
|
||||||
|
|
||||||
if weight:
|
if weight:
|
||||||
# If weight already exists, ignore
|
# If weight already exists, ignore
|
||||||
current_weight = re_existing_weight.search(ret[:insert_at + 1])
|
current_weight = re_existing_weight.search(ret[: insert_at + 1])
|
||||||
if current_weight:
|
if current_weight:
|
||||||
ret = ret[:open_bracketing.start()] + '(' + ret[open_bracketing.start()+valid_consecutive:insert_at] + ')' + ret[insert_at + consecutive:]
|
ret = (
|
||||||
|
ret[: open_bracketing.start()]
|
||||||
|
+ "("
|
||||||
|
+ ret[open_bracketing.start() + valid_consecutive : insert_at]
|
||||||
|
+ ")"
|
||||||
|
+ ret[insert_at + consecutive :]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
ret = ret[:open_bracketing.start()] + '(' + ret[open_bracketing.start()+valid_consecutive:insert_at] + f':{weight:.2f}' + ')' + ret[insert_at + consecutive:]
|
ret = (
|
||||||
|
ret[: open_bracketing.start()]
|
||||||
|
+ "("
|
||||||
|
+ ret[open_bracketing.start() + valid_consecutive : insert_at]
|
||||||
|
+ f":{weight:.2f}"
|
||||||
|
+ ")"
|
||||||
|
+ ret[insert_at + consecutive :]
|
||||||
|
)
|
||||||
|
|
||||||
depths, gradients, brackets = get_mappings(ret)
|
depths, gradients, brackets = get_mappings(ret)
|
||||||
pos += 1
|
pos += 1
|
||||||
|
|
||||||
match = re_bracket_open.search(ret, pos)
|
match = re_bracket_open.search(ret, pos)
|
||||||
|
|
||||||
if not match: # no more potential weight brackets to parse
|
if not match: # no more potential weight brackets to parse
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
pos = match.start()
|
pos = match.start()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def depth_to_map(s: str):
|
def depth_to_map(s: str):
|
||||||
ret = ''
|
ret = ""
|
||||||
depth = 0
|
depth = 0
|
||||||
for c in s:
|
for c in s:
|
||||||
if c in '([':
|
if c in "([":
|
||||||
depth += 1
|
depth += 1
|
||||||
if c in ')]':
|
if c in ")]":
|
||||||
depth -= 1
|
depth -= 1
|
||||||
ret += str(depth)
|
ret += str(depth)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def depth_to_gradeint(s: str):
|
def depth_to_gradeint(s: str):
|
||||||
ret = ''
|
ret = ""
|
||||||
for c in s:
|
for c in s:
|
||||||
if c in '([':
|
if c in "([":
|
||||||
ret += str('^')
|
ret += "^"
|
||||||
elif c in ')]':
|
elif c in ")]":
|
||||||
ret += str('v')
|
ret += "v"
|
||||||
else:
|
else:
|
||||||
ret += str('-')
|
ret += "-"
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def filter_brackets(s: str):
|
def filter_brackets(s: str):
|
||||||
return ''.join(list(map(lambda c : c if c in '[]()' else ' ', s)))
|
return "".join(list(map(lambda c: c if c in "[]()" else " ", s)))
|
||||||
|
|
||||||
|
|
||||||
def get_mappings(s: str):
|
def get_mappings(s: str):
|
||||||
|
|
@ -319,56 +343,66 @@ def get_mappings(s: str):
|
||||||
|
|
||||||
|
|
||||||
def calculate_weight(d: str, is_square_brackets: bool):
|
def calculate_weight(d: str, is_square_brackets: bool):
|
||||||
return 1 / 1.1 ** int(d) if is_square_brackets else 1 * 1.1 ** int(d)
|
return 1 / 1.1 ** int(d) if is_square_brackets else 1 * 1.1 ** int(d)
|
||||||
|
|
||||||
|
|
||||||
def get_weight(prompt: str, map_gradient: list, map_depth: list, map_brackets: list, pos: int, ctv: int, gradient_search: str, is_square_brackets :bool=False):
|
def get_weight(
|
||||||
'''
|
prompt: str,
|
||||||
Returns 0 if bracket was recognized as prompt editing, alternation, or composable
|
map_gradient: list,
|
||||||
'''
|
map_depth: list,
|
||||||
|
map_brackets: list,
|
||||||
|
pos: int,
|
||||||
|
ctv: int,
|
||||||
|
gradient_search: str,
|
||||||
|
is_square_brackets: bool = False,
|
||||||
|
):
|
||||||
|
"""Returns 0 if bracket was recognized as prompt editing, alternation, or composable."""
|
||||||
# CURRENTLY DOES NOT TAKE INTO ACCOUNT COMPOSABLE?? DO WE EVEN NEED TO?
|
# CURRENTLY DOES NOT TAKE INTO ACCOUNT COMPOSABLE?? DO WE EVEN NEED TO?
|
||||||
# E.G. [a AND B :1.2] == (a AND B:1.1) != (a AND B:1.1) ????
|
# E.G. [a AND B :1.2] == (a AND B:1.1) != (a AND B:1.1) ????
|
||||||
while pos+ctv <= len(prompt):
|
while pos + ctv <= len(prompt):
|
||||||
if ctv == 0:
|
if ctv == 0:
|
||||||
return prompt, 0, 1
|
return prompt, 0, 1
|
||||||
a, b = pos, pos+ctv
|
a, b = pos, pos + ctv
|
||||||
if prompt[a] in ':|' and is_square_brackets:
|
if prompt[a] in ":|" and is_square_brackets:
|
||||||
if map_depth[-2] == map_depth[a]:
|
if map_depth[-2] == map_depth[a]:
|
||||||
return prompt, 0, 1
|
return prompt, 0, 1
|
||||||
if map_depth[a] in gradient_search:
|
if map_depth[a] in gradient_search:
|
||||||
gradient_search = gradient_search.replace(map_depth[a], '')
|
gradient_search = gradient_search.replace(map_depth[a], "")
|
||||||
ctv -= 1
|
ctv -= 1
|
||||||
elif (map_gradient[a:b] == 'v' * ctv and
|
elif map_gradient[a:b] == "v" * ctv and map_depth[a - 1 : b] == gradient_search:
|
||||||
map_depth[a-1:b] == gradient_search):
|
|
||||||
return a, calculate_weight(ctv, is_square_brackets), ctv
|
return a, calculate_weight(ctv, is_square_brackets), ctv
|
||||||
elif ('v' == map_gradient[a] and
|
elif "v" == map_gradient[a] and map_depth[a - 1 : b - 1] in gradient_search:
|
||||||
map_depth[a-1:b-1] in gradient_search):
|
narrowing = map_gradient[a:b].count("v")
|
||||||
narrowing = map_gradient[a:b].count('v')
|
|
||||||
gradient_search = gradient_search[narrowing:]
|
gradient_search = gradient_search[narrowing:]
|
||||||
ctv -= 1
|
ctv -= 1
|
||||||
pos += 1
|
pos += 1
|
||||||
|
|
||||||
raise Exception(f'Somehow weight index searching has gone outside of prompt length with prompt: {prompt}')
|
msg = f"Somehow weight index searching has gone outside of prompt length with prompt: {prompt}"
|
||||||
|
raise Exception(msg)
|
||||||
|
|
||||||
|
|
||||||
def space_to_underscore(prompt: str):
|
def space_to_underscore(prompt: str):
|
||||||
# We need to look ahead and ignore any spaces/underscores within network tokens
|
# We need to look ahead and ignore any spaces/underscores within network tokens
|
||||||
# INPUT <lora:chicken butt>, multiple subjects
|
# INPUT <lora:chicken butt>, multiple subjects
|
||||||
# OUTPUT <lora:chicken butt>, multiple_subjects
|
# OUTPUT <lora:chicken butt>, multiple_subjects
|
||||||
match = r'(?<!BREAK) +(?!BREAK|[^<]*>)' if SPACE2UNDERSCORE else r'(?<!BREAK)_+(?!BREAK|[^<]*>)'
|
match = (
|
||||||
replace = '_' if SPACE2UNDERSCORE else ' '
|
r"(?<!BREAK) +(?!BREAK|[^<]*>)"
|
||||||
|
if SPACE2UNDERSCORE
|
||||||
|
else r"(?<!BREAK|_)_(?!_|BREAK|[^<]*>)"
|
||||||
|
)
|
||||||
|
replace = "_" if SPACE2UNDERSCORE else " "
|
||||||
|
|
||||||
tokens: str = tokenize(prompt)
|
tokens: str = tokenize(prompt)
|
||||||
|
|
||||||
return ','.join(map(lambda t: re.sub(match, replace, t), tokens))
|
return ",".join(map(lambda t: re.sub(match, replace, t), tokens))
|
||||||
|
|
||||||
|
|
||||||
def escape_bracket_index(token, symbols, start_index = 0):
|
def escape_bracket_index(token, symbols, start_index=0):
|
||||||
# Given a token and a set of open bracket symbols, find the index in which that character
|
# Given a token and a set of open bracket symbols, find the index in which that character
|
||||||
# escapes the given bracketing such that depth = 0.
|
# escapes the given bracketing such that depth = 0.
|
||||||
token_length = len(token)
|
token_length = len(token)
|
||||||
open = symbols
|
open = symbols
|
||||||
close = ''
|
close = ""
|
||||||
for s in symbols:
|
for s in symbols:
|
||||||
close += brackets_closing[brackets_opening.index(s)]
|
close += brackets_closing[brackets_opening.index(s)]
|
||||||
|
|
||||||
|
|
@ -382,7 +416,7 @@ def escape_bracket_index(token, symbols, start_index = 0):
|
||||||
if d == 0:
|
if d == 0:
|
||||||
return i
|
return i
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
return i
|
return i
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -392,8 +426,8 @@ def format_prompt(*prompts: list):
|
||||||
ret = []
|
ret = []
|
||||||
|
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
if not prompt or prompt.strip() == '':
|
if not prompt or prompt.strip() == "":
|
||||||
ret.append('')
|
ret.append("")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Clean up the string
|
# Clean up the string
|
||||||
|
|
@ -404,7 +438,7 @@ def format_prompt(*prompts: list):
|
||||||
prompt = remove_whitespace_excessive(prompt)
|
prompt = remove_whitespace_excessive(prompt)
|
||||||
prompt = space_to_underscore(prompt)
|
prompt = space_to_underscore(prompt)
|
||||||
prompt = align_brackets(prompt)
|
prompt = align_brackets(prompt)
|
||||||
prompt = space_AND(prompt) # for proper compositing alignment on colons
|
prompt = space_and(prompt) # for proper compositing alignment on colons
|
||||||
prompt = space_bracekts(prompt)
|
prompt = space_bracekts(prompt)
|
||||||
prompt = align_colons(prompt)
|
prompt = align_colons(prompt)
|
||||||
prompt = align_commas(prompt)
|
prompt = align_commas(prompt)
|
||||||
|
|
@ -412,56 +446,60 @@ def format_prompt(*prompts: list):
|
||||||
prompt = bracket_to_weights(prompt)
|
prompt = bracket_to_weights(prompt)
|
||||||
|
|
||||||
ret.append(prompt)
|
ret.append(prompt)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def on_before_component(component: gr.component, **kwargs: dict):
|
def on_before_component(component: gr.component, **kwargs: dict):
|
||||||
if 'elem_id' in kwargs:
|
if "elem_id" in kwargs:
|
||||||
if kwargs['elem_id'] in ['txt2img_prompt', 'txt2img_neg_prompt', 'img2img_prompt', 'img2img_neg_prompt']:
|
if kwargs["elem_id"] in [
|
||||||
|
"txt2img_prompt",
|
||||||
|
"txt2img_neg_prompt",
|
||||||
|
"img2img_prompt",
|
||||||
|
"img2img_neg_prompt",
|
||||||
|
]:
|
||||||
ui_prompts.append(component)
|
ui_prompts.append(component)
|
||||||
elif kwargs['elem_id'] == 'paste':
|
return None
|
||||||
|
elif kwargs["elem_id"] == "paste":
|
||||||
with gr.Blocks(analytics_enabled=False) as ui_component:
|
with gr.Blocks(analytics_enabled=False) as ui_component:
|
||||||
button = gr.Button(value='🪄', elem_classes='tool', elem_id='format')
|
button = gr.Button(value="🪄", elem_classes="tool", elem_id="format")
|
||||||
button.click(
|
button.click(fn=format_prompt, inputs=ui_prompts, outputs=ui_prompts)
|
||||||
fn=format_prompt,
|
|
||||||
inputs=ui_prompts,
|
|
||||||
outputs=ui_prompts
|
|
||||||
)
|
|
||||||
return ui_component
|
return ui_component
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def on_ui_settings():
|
def on_ui_settings():
|
||||||
section = ('pformat', 'Prompt Formatter')
|
section = ("pformat", "Prompt Formatter")
|
||||||
shared.opts.add_option(
|
shared.opts.add_option(
|
||||||
'pformat_space_commas',
|
"pformat_space_commas",
|
||||||
shared.OptionInfo(
|
shared.OptionInfo(
|
||||||
True,
|
True,
|
||||||
'Add a spaces after comma',
|
"Add a spaces after comma",
|
||||||
gr.Checkbox,
|
gr.Checkbox,
|
||||||
{'interactive': True},
|
{"interactive": True},
|
||||||
section=section
|
section=section,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
shared.opts.add_option(
|
shared.opts.add_option(
|
||||||
'pfromat_bracket2weight',
|
"pfromat_bracket2weight",
|
||||||
shared.OptionInfo(
|
shared.OptionInfo(
|
||||||
True,
|
True,
|
||||||
'Convert excessive brackets to weights',
|
"Convert excessive brackets to weights",
|
||||||
gr.Checkbox,
|
gr.Checkbox,
|
||||||
{'interactive': True},
|
{"interactive": True},
|
||||||
section=section
|
section=section,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
shared.opts.add_option(
|
shared.opts.add_option(
|
||||||
'pfromat_space2underscore',
|
"pfromat_space2underscore",
|
||||||
shared.OptionInfo(
|
shared.OptionInfo(
|
||||||
False,
|
False,
|
||||||
'Convert spaces to underscores (default: underscore to spaces)',
|
"Convert spaces to underscores (default: underscore to spaces)",
|
||||||
gr.Checkbox,
|
gr.Checkbox,
|
||||||
{'interactive': True},
|
{"interactive": True},
|
||||||
section=section
|
section=section,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
sync_settings()
|
sync_settings()
|
||||||
|
|
@ -475,4 +513,4 @@ def sync_settings():
|
||||||
|
|
||||||
|
|
||||||
script_callbacks.on_before_component(on_before_component)
|
script_callbacks.on_before_component(on_before_component)
|
||||||
script_callbacks.on_ui_settings(on_ui_settings)
|
script_callbacks.on_ui_settings(on_ui_settings)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue