mirror of https://github.com/Filexor/Clip_IO.git
commit
b6c2e9627d
15
README.md
15
README.md
|
|
@ -52,16 +52,23 @@ If "DirectiveOrder" is absent, it will be treated as order is 0.
|
|||
Local objects for eval are:
|
||||
i: torch.Tensor : input conditioning
|
||||
o: torch.Tensor : output conditioning
|
||||
c: dict : dict for carrying over
|
||||
g: dict : dict for carrying over
|
||||
p: modules.processing.StableDiffusionProcessing : [See source code of Stable diffusion Web UI.](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/baf6946e06249c5af9851c60171692c44ef633e0/modules/processing.py#L105)
|
||||
t: int : 0th dimension (token-wise) of index of input conditioning
|
||||
d: int : 1st dimension (dimension-wise) of index of input conditioning
|
||||
torch module and all objects in math module
|
||||
##### exec
|
||||
"exec" does component-wise python's exec.
|
||||
Local objects for exec are:
|
||||
Global objects for exec are:
|
||||
i: torch.Tensor : input conditioning
|
||||
o: torch.Tensor : output conditioning
|
||||
c: dict : dict for carrying over
|
||||
g: dict : dict for carrying over
|
||||
p: modules.processing.StableDiffusionProcessing : [See source code of Stable diffusion Web UI.](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/baf6946e06249c5af9851c60171692c44ef633e0/modules/processing.py#L105)
|
||||
torch module and all objects in math module
|
||||
**NOTE: If you want to change seed, change both p.seed: int and p.seeds: list[int] .**
|
||||
torch module and all objects in math module
|
||||
##### prompt
|
||||
"prompt" is prompt with some additional options.
|
||||
syntax example: `?prompt("""prompt must be triple single/double quoted""", clip_skip=2, no_padding=False)`
|
||||
arguments can be written as positional or omitted: `?prompt('''this also work''',,True)`
|
||||
prompt and clip_skip can be with list or tuple syntax: `?prompt(("This part of clip_skip is 2", 'Here is 1', "1 is broadcasted because clip_skip is exhausted"),[2,1])`
|
||||
`no_padding` does same behavior of `Don't add bos / eos / pad tokens` in Clip Output.
|
||||
|
|
@ -4,7 +4,7 @@ from collections import namedtuple
|
|||
from enum import IntEnum
|
||||
|
||||
import gradio
|
||||
import torch
|
||||
import torch as torch
|
||||
import lark
|
||||
import open_clip
|
||||
|
||||
|
|
@ -20,6 +20,7 @@ class Clip_IO(scripts.Script):
|
|||
mode_positive = "Disabled"
|
||||
mode_negative = "Disabled"
|
||||
conditioning_cache = {}
|
||||
global_carry = {}
|
||||
|
||||
evacuate_get_learned_conditioning = None
|
||||
evacuate_get_multicond_learned_conditioning = None
|
||||
|
|
@ -43,12 +44,16 @@ class Clip_IO(scripts.Script):
|
|||
mode_positive = gradio.Dropdown(["Disabled", "Simple", "Directive"], value = "Disabled", max_choices = 1, label = "Positive prompt mode")
|
||||
mode_negative = gradio.Dropdown(["Disabled", "Simple", "Directive"], value = "Disabled", max_choices = 1, label = "Negative prompt mode")
|
||||
pass
|
||||
with gradio.Accordion("Pre/Post-process", open = False):
|
||||
pre_batch_process = gradio.TextArea(max_lines=1024, label="Pre-batch-process")
|
||||
post_batch_process = gradio.TextArea(max_lines=1024, label="Post-batch-process-process")
|
||||
pass
|
||||
pass
|
||||
if not is_img2img:
|
||||
return [enabled, mode_positive, mode_negative]
|
||||
return [enabled, mode_positive, mode_negative, post_batch_process, pre_batch_process]
|
||||
pass
|
||||
else:
|
||||
return [enabled, mode_positive, mode_negative]
|
||||
return [enabled, mode_positive, mode_negative, post_batch_process, pre_batch_process]
|
||||
pass
|
||||
return []
|
||||
pass
|
||||
|
|
@ -148,6 +153,21 @@ class Clip_IO(scripts.Script):
|
|||
SPACE: /\s+/
|
||||
"""
|
||||
|
||||
syntax_directive_prompt = r"""
|
||||
start: (prompt | prompts) ("," (argument | arguments))* ("," (keyword_argument | keyword_arguments))*
|
||||
prompt: PROMPT
|
||||
prompts: "(" SPACE? PROMPT SPACE? ( "," SPACE? PROMPT SPACE? )* ")" | "[" SPACE? PROMPT SPACE? ( "," SPACE? PROMPT SPACE? )* "]"
|
||||
PROMPT: /"{3}/ /.*?/ /"{3}/ | /'{3}/ /.*?/ /'{3}/ | /"(?!"")[^"]*?"/ | /'(?!'')[^']*?'/
|
||||
argument: [ARGUMENT]
|
||||
arguments: "(" [ARGUMENT] ("," [ARGUMENT] )* ")" | "[" [ARGUMENT] ("," [ARGUMENT] )* "]"
|
||||
ARGUMENT: /[^()\[\]=,]+/
|
||||
keyword_argument: KEYWORD "=" VALUE
|
||||
keyword_arguments: KEYWORD "=" ( "(" [VALUE] ("," [VALUE] )* ")" | "[" [VALUE] ("," [VALUE] )* "]" )
|
||||
KEYWORD: /[^=,]+/
|
||||
VALUE: /[^()\[\]=,]+/
|
||||
SPACE: /\s+/
|
||||
"""
|
||||
|
||||
class Directive:
|
||||
class Names(IntEnum):
|
||||
eval
|
||||
|
|
@ -266,6 +286,10 @@ class Clip_IO(scripts.Script):
|
|||
pass
|
||||
|
||||
Process().transform(lark.Lark(Clip_IO.syntax_directive).parse(input))
|
||||
if len(conds) == 0:
|
||||
tmp = Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(("a",), manual_chunk=False)
|
||||
conds.append(torch.zeros(0, tmp.shape[1]).to(devices.device))
|
||||
pass
|
||||
i = torch.vstack(conds)
|
||||
o = i.clone()
|
||||
c = {}
|
||||
|
|
@ -278,28 +302,161 @@ class Clip_IO(scripts.Script):
|
|||
try:
|
||||
for t in range(i.shape[0]):
|
||||
for d in range(i.shape[1]):
|
||||
local = {"i": i, "o": o, "c": c, "p": p, "t": t, "d": d, "torch": torch.__dict__} | math.__dict__
|
||||
local = {"i": i, "o": o, "g": Clip_IO.global_carry, "c": c, "p": p, "t": t, "d": d, "sd_model": shared.sd_model, "torch": torch.__dict__} | math.__dict__
|
||||
o[t, d] = eval(dir.inner, None, local)
|
||||
pass
|
||||
pass
|
||||
pass
|
||||
except Exception as e:
|
||||
print(repr(e))
|
||||
o = i
|
||||
raise e
|
||||
pass
|
||||
finally:
|
||||
i = o.clone()
|
||||
pass
|
||||
elif dir.name == "exec":
|
||||
try:
|
||||
local = {"i": i, "o": o, "c": c, "p": p, "torch": torch.__dict__} | math.__dict__
|
||||
exec(dir.inner, None, local)
|
||||
globals = {"i": i, "o": o, "g": Clip_IO.global_carry, "c": c, "p": p, "sd_model": shared.sd_model, "torch": torch, "math": math}
|
||||
exec(dir.inner, globals, None)
|
||||
except Exception as e:
|
||||
print(repr(e))
|
||||
o = i
|
||||
raise e
|
||||
pass
|
||||
finally:
|
||||
i = local["o"].clone()
|
||||
i = globals["o"].clone()
|
||||
pass
|
||||
elif dir.name == "prompt":
|
||||
# prompt(prompt: str, clip_skip: int|None=None, no_padding=True)
|
||||
prompt: tuple[str]
|
||||
keyword_arguments: dict = {"clip_skip": [None], "no_padding": False}
|
||||
class prompt_transformer(lark.visitors.Transformer):
|
||||
keyword_position = 0
|
||||
|
||||
def prompt(self, token: list[lark.Token]):
|
||||
nonlocal prompt
|
||||
prompt = (token[0],)
|
||||
pass
|
||||
|
||||
def prompts(self, tokens: list[lark.Token]):
|
||||
nonlocal prompt
|
||||
prompt = tuple(tokens)
|
||||
pass
|
||||
|
||||
def PROMPT(self, token: lark.Token):
|
||||
if token.startswith('"""') and token.endswith('"""') or token.startswith("'''") and token.endswith("'''"):
|
||||
token = token[3:-3]
|
||||
pass
|
||||
elif token.startswith('"') and token.endswith('"') or token.startswith("'") and token.endswith("'"):
|
||||
token = token[1:-1]
|
||||
pass
|
||||
return token
|
||||
pass
|
||||
|
||||
def argument(self, token: list[lark.Token]):
|
||||
token = token[0]
|
||||
if token is None:
|
||||
self.keyword_position += 1
|
||||
return
|
||||
match self.keyword_position:
|
||||
case 0:
|
||||
if token.strip(" ").lower() == "none":
|
||||
keyword_arguments["clip_skip"] = [[None]]
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
keyword_arguments["clip_skip"] = [int(token.strip(" "))]
|
||||
pass
|
||||
except Exception:
|
||||
print(f'Given argument "{token}" is neither integer nor None.')
|
||||
pass
|
||||
self.keyword_position += 1
|
||||
pass
|
||||
case 1:
|
||||
if token.strip(" ").lower() == "true":
|
||||
value = True
|
||||
pass
|
||||
elif token.strip(" ").lower() == "false":
|
||||
value = False
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(f'Given argument "{token}" is neither True nor False.')
|
||||
pass
|
||||
keyword_arguments["no_padding"] = [value]
|
||||
self.keyword_position += 1
|
||||
pass
|
||||
case _:
|
||||
self.keyword_position += 1
|
||||
pass
|
||||
pass
|
||||
|
||||
def arguments(self, token: list[lark.Token]):
|
||||
match self.keyword_position:
|
||||
case 0:
|
||||
keyword_arguments["clip_skip"] = []
|
||||
for token in token:
|
||||
if token.strip(" ").lower() == "none" or token.strip(" ") == "":
|
||||
keyword_arguments["clip_skip"].append(None)
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
keyword_arguments["clip_skip"].append(int(token.strip(" ")))
|
||||
pass
|
||||
except Exception:
|
||||
print(f'Given argument "{token}" is neither integer nor None.')
|
||||
pass
|
||||
self.keyword_position += 1
|
||||
pass
|
||||
pass
|
||||
case 1:
|
||||
keyword_arguments["no_padding"] = []
|
||||
for token in token:
|
||||
if token.strip(" ").lower() == "true":
|
||||
value = True
|
||||
pass
|
||||
elif token.strip(" ").lower() == "false":
|
||||
value = False
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(f'Given argument "{token}" is neither True nor False.')
|
||||
pass
|
||||
keyword_arguments["no_padding"].append(value)
|
||||
self.keyword_position += 1
|
||||
pass
|
||||
pass
|
||||
case _:
|
||||
self.keyword_position += 1
|
||||
pass
|
||||
pass
|
||||
pass
|
||||
|
||||
def keyword_argument(self, tree: lark.tree.Tree):
|
||||
keyword_arguments[tree.children[0].strip(" ")] = tree.children[1].strip(" ")
|
||||
pass
|
||||
pass
|
||||
prompt_transformer().transform(lark.Lark(Clip_IO.syntax_directive_prompt).parse(dir.inner))
|
||||
o = torch.vstack([o, Clip_IO.FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt, manual_chunk= keyword_arguments["no_padding"], clip_skips=[keyword_arguments["clip_skip"]])])
|
||||
i = o.clone()
|
||||
pass
|
||||
elif dir.name == "execfile":
|
||||
try:
|
||||
if dir.inner.startswith('"') and dir.inner.endswith('"') or dir.inner.startswith("'") and dir.inner.endswith("'"):
|
||||
dir.inner = dir.inner[1:-1]
|
||||
pass
|
||||
if not os.path.exists(dir.inner):
|
||||
dir.inner = os.path.join(os.path.dirname(__file__), "../program", dir.inner)
|
||||
if not os.path.exists(dir.inner):
|
||||
dir.inner = dir.inner + ".py"
|
||||
pass
|
||||
pass
|
||||
with open(dir.inner) as program:
|
||||
globals = {"i": i, "o": o, "g": Clip_IO.global_carry, "c": c, "p": p, "sd_model": shared.sd_model, "torch": torch, "math": math}
|
||||
exec(program, globals, None)
|
||||
pass
|
||||
pass
|
||||
except Exception as e:
|
||||
o = i
|
||||
raise e
|
||||
pass
|
||||
pass
|
||||
else:
|
||||
warnings.warn(f'Directive "{dir.name}" does not exist.')
|
||||
|
|
@ -316,7 +473,7 @@ class Clip_IO(scripts.Script):
|
|||
pass
|
||||
pass
|
||||
|
||||
def my_get_learned_conditioning(model, prompts, steps, p: processing.StableDiffusionProcessing = None, is_negative = True):
|
||||
def my_get_learned_conditioning(model, prompts: prompt_parser.SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False, p: processing.StableDiffusionProcessing = None, is_negative = True):
|
||||
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
||||
and the sampling step at which this condition is to be replaced by the next one.
|
||||
|
||||
|
|
@ -340,7 +497,7 @@ class Clip_IO(scripts.Script):
|
|||
prompt_schedules = [[[steps, prompt]] for prompt in prompts]
|
||||
pass
|
||||
else:
|
||||
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompts, steps)
|
||||
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
|
||||
pass
|
||||
|
||||
res = []
|
||||
|
|
@ -352,7 +509,7 @@ class Clip_IO(scripts.Script):
|
|||
res.append(cached)
|
||||
continue
|
||||
|
||||
texts: list[str] = [x[1] for x in prompt_schedule]
|
||||
texts = prompt_parser.SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
|
||||
if Clip_IO.enabled and (Clip_IO.mode_positive == "Simple" and not is_negative or Clip_IO.mode_negative == "Simple" and is_negative):
|
||||
conds = []
|
||||
for text in texts:
|
||||
|
|
@ -371,7 +528,15 @@ class Clip_IO(scripts.Script):
|
|||
|
||||
cond_schedule = []
|
||||
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
||||
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, conds[i].to(devices.device)))
|
||||
if isinstance(conds, dict):
|
||||
cond = {k: v[i] for k, v in conds.items()}
|
||||
pass
|
||||
else:
|
||||
cond = conds[i]
|
||||
pass
|
||||
|
||||
cond_schedule.append(prompt_parser.ScheduledPromptConditioning(end_at_step, cond))
|
||||
pass
|
||||
|
||||
cache[prompt] = cond_schedule
|
||||
res.append(cond_schedule)
|
||||
|
|
@ -379,7 +544,7 @@ class Clip_IO(scripts.Script):
|
|||
return res
|
||||
pass
|
||||
|
||||
def my_get_multicond_learned_conditioning(model, prompts, steps, p: processing.StableDiffusionProcessing = None) -> prompt_parser.MulticondLearnedConditioning:
|
||||
def my_get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False, p: processing.StableDiffusionProcessing = None) -> prompt_parser.MulticondLearnedConditioning:
|
||||
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
||||
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
||||
|
||||
|
|
@ -388,11 +553,12 @@ class Clip_IO(scripts.Script):
|
|||
|
||||
res_indexes, prompt_flat_list, prompt_indexes = prompt_parser.get_multicond_prompt_list(prompts)
|
||||
|
||||
learned_conditioning = prompt_parser.get_learned_conditioning(model, prompt_flat_list, steps, p, is_negative = False)
|
||||
learned_conditioning = prompt_parser.get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling, p, is_negative = False)
|
||||
|
||||
res = []
|
||||
for indexes in res_indexes:
|
||||
res.append([prompt_parser.ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
|
||||
pass
|
||||
|
||||
return prompt_parser.MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
|
||||
pass
|
||||
|
|
@ -474,7 +640,7 @@ class Clip_IO(scripts.Script):
|
|||
pass
|
||||
|
||||
def get_my_get_conds_with_caching(p: processing.StableDiffusionProcessing):
|
||||
def my_get_conds_with_caching(function, required_prompts, steps, caches, extra_network_data):
|
||||
def my_get_conds_with_caching(function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
|
||||
"""
|
||||
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
||||
using a cache to store the result if the same arguments have been used before.
|
||||
|
|
@ -486,18 +652,29 @@ class Clip_IO(scripts.Script):
|
|||
|
||||
caches is a list with items described above.
|
||||
"""
|
||||
if shared.opts.use_old_scheduling:
|
||||
old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False)
|
||||
new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True)
|
||||
if old_schedules != new_schedules:
|
||||
p.extra_generation_params["Old prompt editing timelines"] = True
|
||||
pass
|
||||
pass
|
||||
|
||||
cached_params = p.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)
|
||||
|
||||
for cache in caches:
|
||||
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
|
||||
if cache[0] is not None and cached_params == cache[0] and not Clip_IO.enabled:
|
||||
return cache[1]
|
||||
|
||||
cache = caches[0]
|
||||
|
||||
with devices.autocast():
|
||||
cache[1] = function(shared.sd_model, required_prompts, steps, p)
|
||||
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
|
||||
|
||||
cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
|
||||
cache[0] = cached_params
|
||||
return cache[1]
|
||||
pass
|
||||
|
||||
return my_get_conds_with_caching
|
||||
|
||||
def get_inner_function(outer, new_inner):
|
||||
|
|
@ -596,7 +773,6 @@ class Clip_IO(scripts.Script):
|
|||
prompt_parser.get_learned_conditioning = Clip_IO.my_get_learned_conditioning
|
||||
prompt_parser.get_multicond_learned_conditioning = Clip_IO.my_get_multicond_learned_conditioning
|
||||
#Clip_IO.replace_inner_function(processing.process_images_inner, Clip_IO.get_my_get_conds_with_caching())
|
||||
pass
|
||||
else:
|
||||
Clip_IO.enabled = False
|
||||
pass
|
||||
|
|
@ -620,38 +796,75 @@ class Clip_IO(scripts.Script):
|
|||
Clip_IO.evacuate_get_conds_with_caching = p.get_conds_with_caching
|
||||
p.get_conds_with_caching = Clip_IO.get_my_get_conds_with_caching(p)
|
||||
pass
|
||||
try:
|
||||
globals = {"g": Clip_IO.global_carry, "p": p, "sd_model": shared.sd_model, "torch": torch, "math": math}
|
||||
exec(args[4], globals, None)
|
||||
except Exception as e:
|
||||
raise e
|
||||
pass
|
||||
pass
|
||||
pass
|
||||
pass
|
||||
|
||||
def postprocess_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
|
||||
if Clip_IO.enabled:
|
||||
Clip_IO.mode_positive = "Disabled"
|
||||
Clip_IO.mode_negative = "Disabled"
|
||||
if getattr(p, "get_conds_with_caching", None) is not None:
|
||||
p.get_conds_with_caching = Clip_IO.evacuate_get_conds_with_caching
|
||||
try:
|
||||
globals = {"g": Clip_IO.global_carry, "p": p, "sd_model": shared.sd_model, "torch": torch, "math": math}
|
||||
exec(args[3], globals, None)
|
||||
except Exception as e:
|
||||
raise e
|
||||
pass
|
||||
finally:
|
||||
Clip_IO.mode_positive = "Disabled"
|
||||
Clip_IO.mode_negative = "Disabled"
|
||||
if getattr(p, "get_conds_with_caching", None) is not None:
|
||||
p.get_conds_with_caching = Clip_IO.evacuate_get_conds_with_caching
|
||||
pass
|
||||
pass
|
||||
pass
|
||||
pass
|
||||
|
||||
def tokenize_line_manual_chunk(prompt: str, clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords) -> list[PromptChunk]:
|
||||
if opts.enable_emphasis:
|
||||
parsed = prompt_parser.parse_prompt_attention(prompt)
|
||||
pass
|
||||
else:
|
||||
parsed = [[prompt, 1.0]]
|
||||
def tokenize_line_manual_chunk(prompts: tuple[str, ...], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[list[PromptChunk],list[tuple[int,int]]]:
|
||||
parsed = []
|
||||
for prompt in prompts:
|
||||
parsed.append(["SEPARATION", -1])
|
||||
if opts.enable_emphasis:
|
||||
to_appends = prompt_parser.parse_prompt_attention(prompt)
|
||||
for to_append in to_appends:
|
||||
parsed.append(to_append)
|
||||
pass
|
||||
else:
|
||||
parsed.append([prompt, 1.0])
|
||||
pass
|
||||
pass
|
||||
|
||||
tokenized = clip.tokenize([text for text, _ in parsed])
|
||||
|
||||
chunks: list[PromptChunk] = []
|
||||
chunk = PromptChunk()
|
||||
token_count = 0
|
||||
last_comma = -1
|
||||
separation_starts: list[tuple[int,int]] = []
|
||||
|
||||
def next_chunk(is_last=False):
|
||||
nonlocal token_count
|
||||
nonlocal last_comma
|
||||
nonlocal chunk
|
||||
|
||||
# We don't have to fill the chunk.
|
||||
if not manual_chunk:
|
||||
if is_last:
|
||||
token_count += len(chunk.tokens)
|
||||
else:
|
||||
token_count += clip.chunk_length
|
||||
|
||||
to_add = clip.chunk_length - len(chunk.tokens)
|
||||
if to_add > 0:
|
||||
chunk.tokens += [clip.id_end] * to_add
|
||||
chunk.multipliers += [1.0] * to_add
|
||||
|
||||
chunk.tokens = [clip.id_start] + chunk.tokens + [clip.id_end]
|
||||
chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
|
||||
pass
|
||||
|
||||
last_comma = -1
|
||||
chunks.append(chunk)
|
||||
|
|
@ -659,6 +872,11 @@ class Clip_IO(scripts.Script):
|
|||
pass
|
||||
|
||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||
if text == 'SEPARATION' and weight == -1:
|
||||
separation_starts.append((len(chunks), len(chunk.tokens)))
|
||||
continue
|
||||
pass
|
||||
|
||||
if text == 'BREAK' and weight == -1:
|
||||
next_chunk()
|
||||
continue
|
||||
|
|
@ -674,7 +892,7 @@ class Clip_IO(scripts.Script):
|
|||
|
||||
# this is when we are at the end of alloted 77 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
|
||||
# is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
|
||||
elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == clip.chunk_length + 2 and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||
elif manual_chunk and opts.comma_padding_backtrack != 0 and len(chunk.tokens) == clip.chunk_length + 2 and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||
break_location = last_comma + 1
|
||||
|
||||
reloc_tokens = chunk.tokens[break_location:]
|
||||
|
|
@ -688,7 +906,25 @@ class Clip_IO(scripts.Script):
|
|||
chunk.multipliers = reloc_mults
|
||||
pass
|
||||
|
||||
if len(chunk.tokens) == clip.chunk_length + 2:
|
||||
elif not manual_chunk and opts.comma_padding_backtrack != 0 and len(chunk.tokens) == clip.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||
break_location = last_comma + 1
|
||||
|
||||
reloc_tokens = chunk.tokens[break_location:]
|
||||
reloc_mults = chunk.multipliers[break_location:]
|
||||
|
||||
chunk.tokens = chunk.tokens[:break_location]
|
||||
chunk.multipliers = chunk.multipliers[:break_location]
|
||||
|
||||
next_chunk()
|
||||
chunk.tokens = reloc_tokens
|
||||
chunk.multipliers = reloc_mults
|
||||
pass
|
||||
|
||||
if manual_chunk and len(chunk.tokens) == clip.chunk_length + 2:
|
||||
next_chunk()
|
||||
pass
|
||||
|
||||
elif not manual_chunk and len(chunk.tokens) == clip.chunk_length:
|
||||
next_chunk()
|
||||
pass
|
||||
|
||||
|
|
@ -718,19 +954,21 @@ class Clip_IO(scripts.Script):
|
|||
next_chunk(is_last=True)
|
||||
pass
|
||||
|
||||
return chunks
|
||||
return chunks, separation_starts
|
||||
|
||||
def process_texts_manual_chunk(prompts: list[str], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords) -> list[list[PromptChunk]]:
|
||||
def process_texts_manual_chunk(prompts: list[tuple[str, ...]], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[list[list[PromptChunk]],list[list[tuple[int,int]]]]:
|
||||
cache = {}
|
||||
batch_chunks: list[list[PromptChunk]] = []
|
||||
separation_starts_list: list[list[tuple[int,int]]] = []
|
||||
for prompt in prompts:
|
||||
if prompt in cache:
|
||||
chunks = cache[prompt]
|
||||
else:
|
||||
chunks = Clip_IO.tokenize_line_manual_chunk(prompt, clip)
|
||||
chunks, separation_starts = Clip_IO.tokenize_line_manual_chunk(prompt, clip, manual_chunk)
|
||||
cache[prompt] = chunks
|
||||
|
||||
batch_chunks.append(chunks)
|
||||
separation_starts_list.append(separation_starts)
|
||||
|
||||
if False:
|
||||
# We have to ensure all chunk in batch_chunks have same length.
|
||||
|
|
@ -755,14 +993,18 @@ class Clip_IO(scripts.Script):
|
|||
pass
|
||||
pass
|
||||
|
||||
return batch_chunks
|
||||
return batch_chunks, separation_starts_list
|
||||
|
||||
def get_chunks(prompt: str, clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> list[list[PromptChunk]]:
|
||||
def get_chunks(prompt: str | tuple[str, ...], clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords, manual_chunk: bool) -> tuple[list[list[PromptChunk]],list[list[tuple[int,int]]]]:
|
||||
"Return: PromptChunks, token separation position"
|
||||
if isinstance(prompt, str):
|
||||
prompt = (prompt,)
|
||||
pass
|
||||
if opts.use_old_emphasis_implementation:
|
||||
raise NotImplementedError
|
||||
pass
|
||||
if manual_chunk:
|
||||
return Clip_IO.process_texts_manual_chunk([prompt], clip)
|
||||
if True: # manual_chunk:
|
||||
return Clip_IO.process_texts_manual_chunk([prompt], clip, manual_chunk)
|
||||
pass
|
||||
else:
|
||||
batch_chunks, _ = clip.process_texts([prompt])
|
||||
|
|
@ -820,7 +1062,7 @@ class Clip_IO(scripts.Script):
|
|||
def on_save_embeddings_as_pt(prompt: str, filename: str, overwrite: bool, manual_chunk: bool):
|
||||
try:
|
||||
clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
|
||||
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
|
||||
batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk)
|
||||
embeddings: torch.Tensor = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
|
||||
|
||||
filename = os.path.join(os.path.dirname(__file__), "../conditioning", filename)
|
||||
|
|
@ -847,7 +1089,7 @@ class Clip_IO(scripts.Script):
|
|||
def on_save_embeddings_as_csv(prompt: str, filename: str, transpose: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
|
||||
try:
|
||||
clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
|
||||
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
|
||||
batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk)
|
||||
embeddings, tokens = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
|
||||
|
||||
embeddings: list[list[str]] = embeddings[0].tolist()
|
||||
|
|
@ -894,7 +1136,7 @@ class Clip_IO(scripts.Script):
|
|||
try:
|
||||
with devices.autocast():
|
||||
clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
|
||||
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
|
||||
batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk)
|
||||
chunk_count = max([len(x) for x in batch_chunks])
|
||||
zs = []
|
||||
for i in range(chunk_count):
|
||||
|
|
@ -939,11 +1181,92 @@ class Clip_IO(scripts.Script):
|
|||
return f'File {filename} is successfully saved. {datetime.datetime.now().isoformat()}'
|
||||
pass
|
||||
|
||||
def encode_with_transformers(clip, tokens, chunk_count: int, clip_skips: list[list[int | None]], separation_starts_list:list[list[tuple[int,int]]]):
|
||||
clip.wrapped.transformer.to(shared.device)
|
||||
outputs = clip.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
|
||||
zs = []
|
||||
for z in outputs.hidden_states:
|
||||
zs.append(clip.wrapped.transformer.text_model.final_layer_norm(z))
|
||||
pass
|
||||
zs[-1] = outputs.last_hidden_state
|
||||
|
||||
if opts.CLIP_stop_at_last_layers > 1:
|
||||
zo = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
zo = clip.wrapped.transformer.text_model.final_layer_norm(zo)
|
||||
else:
|
||||
zo = outputs.last_hidden_state
|
||||
|
||||
for i in range(len(tokens)):
|
||||
chunk_start = separation_starts_list[i][0][0]
|
||||
chunk_end = separation_starts_list[i][1][0] if len(separation_starts_list[i]) >= 2 else chunk_count + 1
|
||||
while True:
|
||||
while chunk_count < chunk_start:
|
||||
separation_starts_list[i].pop(0)
|
||||
clip_skips[i].pop(0)
|
||||
pass
|
||||
|
||||
if chunk_count == chunk_start:
|
||||
token_start = separation_starts_list[i][0][1]
|
||||
pass
|
||||
elif chunk_count > chunk_start:
|
||||
token_start = 0
|
||||
pass
|
||||
|
||||
if chunk_count < chunk_end:
|
||||
token_end = zo.shape[1]
|
||||
pass
|
||||
elif chunk_count == chunk_end:
|
||||
token_end = separation_starts_list[i][1][1] if len(separation_starts_list[i]) >= 2 else zo.shape[1]
|
||||
pass
|
||||
|
||||
zo[i,token_start:token_end,:] = zs[-clip_skips[i][0]][i,token_start:token_end,:] if clip_skips[i][0] is not None else zo[i,token_start:token_end,:]
|
||||
if token_end == zo.shape[1]:
|
||||
break
|
||||
pass
|
||||
if chunk_count == chunk_end:
|
||||
separation_starts_list[i].pop(0)
|
||||
if len(clip_skips[i]) >=2:
|
||||
clip_skips[i].pop(0)
|
||||
pass
|
||||
else:
|
||||
clip_skips[i][0] = None
|
||||
pass
|
||||
pass
|
||||
pass
|
||||
|
||||
return zo
|
||||
|
||||
def FrozenCLIPEmbedderWithCustomWordsBase_forword(prompt: str | tuple[str, ...], clip = shared.sd_model.cond_stage_model, manual_chunk = False, clip_skips: list[list[int | None]] = [[None]]) -> torch.Tensor:
|
||||
batch_chunks, separation_starts_list = Clip_IO.get_chunks(prompt, clip, manual_chunk)
|
||||
chunk_count = max([len(x) for x in batch_chunks])
|
||||
zs = []
|
||||
for i in range(chunk_count):
|
||||
batch_chunk = [chunks[i] if i < len(chunks) else clip.empty_chunk() for chunks in batch_chunks]
|
||||
remade_batch_tokens = [x.tokens for x in batch_chunk]
|
||||
tokens = torch.asarray([x.tokens for x in batch_chunk]).to(devices.device)
|
||||
clip.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||
|
||||
if clip.id_end != clip.id_pad and not manual_chunk:
|
||||
for batch_pos in range(len(remade_batch_tokens)):
|
||||
index = remade_batch_tokens[batch_pos].index(clip.id_end)
|
||||
tokens[batch_pos, index+1:tokens.shape[1]] = clip.id_pad
|
||||
|
||||
z = Clip_IO.encode_with_transformers(clip, tokens, i, clip_skips, separation_starts_list) if isinstance(clip, FrozenCLIPEmbedderWithCustomWords) else clip.encode_with_transformers(tokens)
|
||||
if True: # if not no_emphasis:
|
||||
batch_multipliers = torch.asarray([x.multipliers for x in batch_chunk]).to(devices.device)
|
||||
original_mean = z.mean()
|
||||
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z = z * (original_mean / new_mean) # z = z * (original_mean / new_mean) if not no_norm else z
|
||||
zs.append(z[0])
|
||||
return torch.vstack(zs)
|
||||
|
||||
def on_save_conditioning_as_csv(prompt: str, filename: str, transpose: bool, no_emphasis: bool, no_norm: bool, add_token: bool, overwrite: bool, manual_chunk: bool):
|
||||
try:
|
||||
with devices.autocast():
|
||||
clip: FrozenCLIPEmbedderWithCustomWords | FrozenOpenCLIPEmbedderWithCustomWords = shared.sd_model.cond_stage_model
|
||||
batch_chunks = Clip_IO.get_chunks(prompt, clip, manual_chunk)
|
||||
batch_chunks, _ = Clip_IO.get_chunks(prompt, clip, manual_chunk)
|
||||
_, token_list = Clip_IO.get_flat_embeddings(batch_chunks, clip, manual_chunk)
|
||||
chunk_count = max([len(x) for x in batch_chunks])
|
||||
zs = []
|
||||
|
|
|
|||
Loading…
Reference in New Issue