1735 lines
78 KiB
Python
1735 lines
78 KiB
Python
# stable-diffusion-webui-embedding-merge
|
||
|
||
'''
|
||
WebUI Dependencies:
|
||
|
||
1) Class <modules.textual_inversion.textual_inversion.Embedding> is used to create embeddings.
|
||
required fields: <.vec> = actual tensor, <.vectors> = first dim size, <.shape> = last dim size.
|
||
2) Object <modules.sd_hijack.model_hijack.embedding_db> is abused to create ephemeral embeddings.
|
||
Work with fields <.word_embeddings> and <.ids_lookup> is replicated from
|
||
</modules/textual_inversion/textual_inversion.py>, refer to <register_embedding()> here.
|
||
UPD: not needed anymore, since upstream implemented <register_embedding_by_name()>
|
||
3) Saving of embeddings is done by crafting a proper shape for .pt file manually, and then
|
||
<modules.sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload)> is called.
|
||
4) <modules.sd_hijack.StableDiffusionModelHijack.get_prompt_lengths(text)> is hooked but not replaced.
|
||
5) Part of <encode_embedding_init_text()> from <sd_hijack_clip.py> and <sd_hijack_open_clip.py> is converted to
|
||
<tokens_to_vectors()> here; it uses <shared.sd_model.cond_stage_model.wrapped> and then calls either
|
||
<.model.token_embedding.wrapped()> for SD2, or <.transformer.text_model.embeddings.token_embedding.wrapped()> for SD1
|
||
6) Code from <https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer> is heavily copied:
|
||
it grabs <shared.sd_model.cond_stage_model.wrapped> and checks it against
|
||
<FrozenCLIPEmbedder> and <FrozenOpenCLIPEmbedder>, refer to <tokens_to_text()> here.
|
||
7) <shared.sd_model.cond_stage_model.tokenize_line(line)> is called many times when parsing prompts.
|
||
The code is very dependent on what it returns! Source in </modules/sd_hijack_clip.py>
|
||
Also <shared.sd_model.cond_stage_model.tokenize()> can be called.
|
||
8) Method <p.cached_params()> is faked to be always unique if any runtime embeddings are detected to prevent wrong caching.
|
||
'''
|
||
|
||
import re
|
||
import os
|
||
import torch
|
||
import json
|
||
import html
|
||
import time
|
||
import types
|
||
import traceback
|
||
import threading
|
||
import gradio
|
||
import modules
|
||
from modules import shared, scripts, script_callbacks, devices, processing, sd_models
|
||
from modules.shared import opts, cmd_opts
|
||
from modules.textual_inversion.textual_inversion import Embedding
|
||
import open_clip.tokenizer
|
||
|
||
def _webui_embedding_merge_():
|
||
|
||
class Exception_From_EmbeddingMergeExtension(Exception):
    """Error type raised for failures that originate in this extension."""


class Exception_From_EmbeddingMergeExtension_():
    """Stand-in object whose every (missing) attribute lookup raises.

    The constructor argument is kept in ``self._``; reading any other
    attribute re-raises it wrapped in Exception_From_EmbeddingMergeExtension,
    presumably so that an error can be deferred until the object is used.
    """

    def __init__(self, _):
        # A real instance attribute, so __getattr__ is never consulted for it.
        self._ = _

    def __getattr__(self, attribute_name):
        message = self._
        raise Exception_From_EmbeddingMergeExtension(message)
|
||
|
||
def gr_tab():
    """Build the extension's 'EM' tab.

    Lays out the usage documentation (collapsed accordions), a textbox for a
    prompt / merge expression / generation info, the 'Parse!' button, the
    table-grouping radio, an HTML output area, an optional name box for saving
    the result as an embedding, and a .safetensors checkbox. Both the button
    and the radio call gr_func (defined elsewhere in this file) with the same
    inputs; only the trailing hidden boolean differs (True for the button,
    False for the radio).

    Returns:
        [(block, 'EM', 'embedding_merge_extension')] -- presumably the
        (component, title, element id) list expected by WebUI's on_ui_tabs
        callback; the registration itself is not visible here.
    """
    with gradio.Blocks(analytics_enabled=False) as block:
        # Suppress ::before/::after pseudo-element content on <p> and <code> inside this tab.
        gradio.HTML('<style>#tab_embedding_merge_extension p::before,#tab_embedding_merge_extension p::after,#tab_embedding_merge_extension code::before,#tab_embedding_merge_extension code::after{display:none!important}</style>')
        # --- Usage documentation (everything collapsed by default) ---
        with gradio.Row():
            with gradio.Accordion('Embedding Merge extension! (Click here for usage instructions)', open=False):
                with gradio.Accordion('Introduction...', open=False):
                    gradio.Markdown('''
## Purpose:

Did you know that StableDiffusion reads your prompt by so-called tokens? They are multidimensional numerical vectors that construct together words and phrases.

It is actually possible to create new words by simple merging (adding) different vectors together, resulting in something that could mean both things simultaneously!

However, it is not always working, and sometimes it won't give what you would expect, but it is definitely worth experimenting.

Basically, this extension will create Textual Inversion embeddings purely by token merging (without any training on actual images!) either automatically during generation, or manually on its tab.

## Usage:

The tab `EM` can be used to:
- inspect your prompt or specific words
- create TI embeddings from text fragments with or without merging
- check correctness of your merge expressions
''')
                gradio.Markdown('''
### TL;DR:

Use syntax `<'one thing'+'another thing'>` to merge terms "one thing" and "another thing" together in one single embedding in your positive or negative prompts at runtime.

Also use `<'your words'*0.5>` (or any number, default is 1.0) to increase or decrease the essence of "your words" (which can be even zero to disable that part of the prompt).

To use attention with round brackets ( ), put them around < >, like `(<'one'+'two'>:0.9)`
Use as many <> in one prompt, as you want; also you can put your existing TI embedding names inside `' '`.

~~When you need literal <' for some reason, put a space between.~~ You cannot have literal <' anywhere in your prompts; but with a space between (`< '`) it will be ignored by this extension.

If some other extension interferes with this syntax, change angular brackets to curly: `{'also works'*4}`

## View text or embeddings vectors

You can paste your vanilla prompt (without any other special syntax) into the textbox in EM tab to see how it is parsed by WebUI. All of detected Textual Inversion embeddings will be extracted and presented to you along with literal text tokens. For example:

>intergalactic train, masterpiece, by Danh Víµ
''')
                with gradio.Accordion('More about table columns and grouping of its rows...', open=False):
                    gradio.Markdown('''
### Rows:

- `By none` = interpret the prompt as a whole, extracting all characters from real tokens
- `By comma` = split the prompt by tags on commas, removing commas but keeping source space characters
- `By parts` (default) = split at TI embeddings, joining text parts together, keeping spaces
- `By words` = split only after tokens that actually produce space character at the end
- `By tokens` = split at everything except characters that are represented with more than one vector
- `By vectors` = show all vectors separated, even for TI embeddings

### Columns:

- `Index` = index of one vector or index range (inclusive) for this row
- `Vectors` = number of final vectors for this row (to clearly see it)
- `Text` = original or recreated from tokens text, enclosed in quotes for clarity
- `Token` = list of CLIP token numbers that represent this row; for TI embeddings \* or \*_X where X is the index of current embedding vector
- `Min` = lowest (negative) value of the vector or grouped vectors values
- `Max` = largest value
- `Sum` = sum of all values with sign
- `Abs` = sum of modulus of each value, without sign (always positive)
- `Len` = vector length in L2 norm, square root of sum of squared values (computed approximate)
- `Std` = standard deviation for vector values.

### Why do you need it:

To make sure your prompt is interpreted the way you expect (for example, that existing TI embeddings are detected). Also you can explore CLIP tokens this way.

If you type a new name into the textbox on the bottom, your whole current prompt will be converted into a single Textual Inversion embedding with that name (and stored inside `/embeddings/embedding_merge/` subdirectory). You can use this for:

- Creating a shortened part to quickly use in prompts (not recommended though, since you will lose the original text later), but with no other benefits;
- Prepare TI embedding for actual training by using existing embeddings for its initialization.
''')
                gradio.Markdown('''
## Test merge expression:

In EM tab you can enter a "merge expression" that starts with a single quote, to see how it will be parsed and combined by this extension. It should contain single quotes around literal texts or TI embeddings, and special operators between them. For example:

>'greg rutkowski'/4+'gustav dore'*0.75
''')
                with gradio.Accordion('More about merge expression syntax...', open=False):
                    gradio.Markdown('''
### Expression syntax:

- `'one' + 'two'` = blend vectors together by simple sum of all values. If length is different, smallest part will be right-padded with zeroes.
- `'one' - 'two'` = as above, but subtraction. Note that + and - can be put only between textual parts and will have lowest priority.
- `'text' * NUM` = multiply all vectors of quoted literal by numeric value. You can use floating point (0.85) and negative numbers (-1), but not arithmetic expressions.
- `'text' / NUM` = division by number, just as multiplication above. Applies to previous text literal but after previous similar operations, so you can multiply and divide together (\*3/5)
- `'text' : NUM` = change vector count of literal, to shrink or enlarge (padded with zeros). Only integer without sign!
- `'text' :+ NUM` and `'text' :- NUM` = circular rotate vectors in this token, for example +1 will shift index of each vector by one forward, wrapping on last.
- `'text',NUM` (chainable as `'a',B,'c','d',E,F…`) = concatenate text with a token by its numerical index (so, to get any pure token – use empty left string: `'',256`). Special tokens: `0000` = "start token" (index 49406), `000` = "end token" (index 49407), `00` = "padding token" (also 49407 for SD1, but 0 for SD2). Token number `0` is not zero-vector, but for some reason counts as symbol "!" without a space after it, which is impossible to normally enter anyway.

To apply multiplication (or division), cropping or shifting **to the result** of addition (or subtraction), you cannot use parenthesis; instead, try this syntax:

- `'one' + 'two' =* NUM` = will multiply the sum of 'one' and 'two', but not 'two' alone
- `'one' + 'two' =/ NUM` = divide the sum (or any number of sums to the left), effectively the "result" of everything
- `'one' + 'two' =: NUM` = crop or enlarge the results
- `'one' + 'two' =:+ NUM` or `'one' + 'two' =:- NUM` = rotate the result

Thus, the following operations are doing the same:

>`'a'/2 + 'b'/2 + '':1 - 'd'`
`'a'+'b' =* 0.5 + 'c'*0 + 'd'*-1`

There is no true "concatenation" operator (since you will be able to concatenate several separate merge expressions later), but you may replicate it with addition of the same text enlarged and shifted, if you need.
Operation "," has the highest priority (it will directly construct the string before doing anything else), so you cannot concatenate anything to the result of addition or multiplication. Use it only to add tokens by index in your text.

For example, repeating a two-vector word, resulting in 4 vectors of two equal pairs:

> 'artstation' + 'artstation' :4 :+2
> 'artstation','artstation'

You can use shifting to join several vectors of the same text together. For example, given a 4-vectors word you may merge those vectors in one:

> 'kuvshinov' + 'kuvshinov':-1 + 'kuvshinov':-2 + 'kuvshinov':-3 =: 1
> '',1836 + '',85 + '',43074 + '',341

Note that those indices are referring to "ku|v|shino|v[space]" and cannot be entered from raw text, since it would be parsed as "ku[space]", "v[space]" and "shino[space]", which are different tokens!

When you merge strings of unequal length, shortest one is padded with zero vectors; if you want to pad it with something else, you should check the vector count and concatenate accordingly:

> 'close-up',00,00 + 'out-of-frame' + 'cropped',00,00,00,00
> 'up',00,00+'of-frame'+'',00,00,00 =:5:+2 + 'close-'+'out-'+'cropped',00

### Why do you need it:

To prepare your expression and fix any errors. You can evaluate its correctness by roughly comparing numbers in table (for example, adding vectors will generally result in higher `Abs` value; while multiplication is directly changing all numbers straightforwardly).

If for some reason you couldn't use the syntax for merging prompts at runtime, at least you will be able to enter a name and create a regular TI embedding from your merge expression. Then you may use it even without this extension installed!

Also you can check numerical parameters of your trained textual embedding and compare it with "normal" vectors. For example, very large `Len` or `Std` will mean that something is wrong and at least you may divide it in attempt to fix.
''')
                gradio.Markdown('''
## Several merge expressions in prompt:

If you put a valid merge expression enclosed in angular <'…' …> or curly {'…' …} brackets anywhere in your prompt (with no space between `<` or `{` and `'`) on EM tab, it will be parsed and merged into one temporary Textual Inversion embedding, which replaces the expression itself. The resulting prompt will be joined from those embeddings and anything between expressions. For example:

>A photo of <'cat'+'dog'>, {'4k'+'dynamic lighting'+'science fiction'=/3} masterpiece
''')
                with gradio.Accordion('More examples of using angular/curly brackets...', open=False):
                    gradio.Markdown('''
### More examples:


Combining different subjects or styles together, resulting in joined concepts:

> A realistic photo of the <'girl'+'doll'> in rainbow dress standing on a shore.
Art by <'greg rutkowski'*X+'hayao miyazaki'*Y> style.

Notes:
- Works best when all of your subjects have the same number of vectors (also can be roughly simulated by BREAK statement: `… photo of the girl in rainbow … BREAK … photo of the doll in rainbow …`);
- You don't have to divide on the number of added parts, especially if your subjects are very different (e.g. not contain same tokens);
- By multiplying each part in second example (where X and Y are numbers between 0.0 and 1.0) you may get a weighed combination or interpolation.

Changing weight of individual words in prompt:

> A <'peacock'*X> is standing on a top of <'giraffe'*Y>.
worst quality, ugly, <'bad anatomy,':0> blurry, cropped

Where X and Y will be numbers from 0.0 to 1.0 or even higher, up to 5. This way you can directly change relative affection between subjects.

Notes:
- Often values between 0.5 and 1.5 don't really change anything, looking like plain 1.0
- Values lower than 0.5 and near to 0.0 are greatly reducing subject weight indeed! Up to its complete absence (which is not possible otherwise, for example even zero attention `(word:0)` does not eliminate "word" from the prompt)
- High numbers might increase the presence of an object, not in quantity but in essence. Very high multipliers (above 10) corrupt the subject, but still don't destroy the image itself.

Eliminating a part of the negative prompt by zeroing its vectors can be used to understand the effect of the part in question, without shifting the rest of the text otherwise. Since WebUI is splitting long prompts at arbitrary commas (and then merging resulting parts together), simple deletion of a part might change things severely.
''')
                gradio.Markdown('''
## Using merge expressions in prompts at runtime!

You can actually put merge expressions in angular or curly brackets into your txt2img or img2img prompt in WebUI. This extension will intercept both main and negative prompts, parse and merge expressions creating temporary TI embeddings that WebUI will "see" instead of your original text. In generation info there will be internal meaningless names like <'EM_1'>, but extra parameter "EmbeddingMerge" will contain original merge expressions. To quickly restore your prompts, just paste your complete generation information (from .txt or PNG Info) into the textbox on EM tab (also it should work for the official "paste" toolbar button too) – its temporary embeddings will be replaced back with expressions, for example:

> a photo of <'EM_1'>
Negative prompt: {'EM_2'}
Steps: 8, Sampler: DPM++ 2M Karras, CFG scale: 7, Seed: 1374372309, Size: 512x512, Model hash: c6bbc15e32, Model: sd-v1-5-inpainting, EmbeddingMerge: "<'EM_1'>=<'sky' * 2/4 + 'forest' * 3/4>, {'EM_2'}={'blurry'+'cropped'}", Conditional mask weight: 1

For your information replicating start tokens of the syntax itself:
- `<'` = `<'',27,6>` or `<'',27,262>`
- `{'` = `<'',90,6>` or `<'',90,262>`

''')
                with gradio.Accordion('Limitations...', open=False):
                    gradio.Markdown('''
### What is not working:

#### Binding properties to objects:

> Photo of a <'blonde'+'boy'> in <'red'+'shirt'> wearing <'green'+'pants'> and <'blue'+'shoes'>

– results in anything but not what was requested.

#### Collapsing artists to single token:

> Painting by <'William' + '-' + 'Adolphe'+'Adolphe':+1 + 'Bouguereau'+'Bouguereau':+1+'Bouguereau':+2 =:1>. A girl, masterpiece

– results in something barely distinct from zeroing the term altogether.

#### Subtracting concepts as in word2vec:

> Full-body photo of a <'king'-'man'+'woman'>
Detailed photo of <'yellow'-'red'> car

– generally results in totally ruined composition.

#### Simulating negative prompt via negation of words:

> A portrait of the princess. <'frame, black-white'*-1>
A cat is chasing a dog. <''-'road'-'grass'>

– will still add those concepts to positive prompt, but with weird presence. You could find more luck with small values `-0.1-0.0` though.
''')
        # --- Input: prompt, merge expression, or pasted generation info ---
        with gradio.Row():
            gr_text = gradio.Textbox(value='', lines=4, max_lines=16, interactive=True, label='Your prompt (no weight/attention, do not escape parenthesis/brackets); or your merge expression (if the first character is a single quote); or a generation info to restore prompts')
        with gradio.Row():
            with gradio.Column(scale=1):
                gr_button = gradio.Button('Parse!',variant='primary')
            with gradio.Column(scale=3):
                # type='index': gr_func receives the position of the chosen option (0..5), not its label.
                gr_radio = gradio.Radio(choices=('By none','By comma','By parts','By words','By tokens','By vectors'), value='By parts', type='index', interactive=True, label='Group/split table by: (when not started with single quote - so only for prompts, not for merge)')
        # --- Output: HTML produced by gr_func ---
        with gradio.Box():
            gr_html = gradio.HTML(label='out')
        with gradio.Row():
            # Hidden constant inputs, passed as gr_func's last argument:
            # True when triggered by the button, False when triggered by the radio.
            gr_true = gradio.Checkbox(value=True,visible=False,show_label=False)
            gr_false = gradio.Checkbox(value=False,visible=False,show_label=False)
            gr_name = gradio.Textbox(value='', lines=1, max_lines=1, interactive=True, label='Type here a name for your new embedding that will store the result of next parsing/merging by the button above: (optional; cleared on success)')
        with gradio.Row():
            gr_tensors = gradio.Checkbox(value=True,label="Save as .safetensors AND create sd1/sdxl converted embeddings accordingly (so you can load them as a separate L part for different architecture)")
        # gr_func updates the output HTML, the name box and the prompt box.
        gr_button.click(fn=gr_func, inputs=[gr_name,gr_text,gr_radio,gr_tensors,gr_true], outputs=[gr_html,gr_name,gr_text], show_progress=False)
        gr_radio.change(fn=gr_func, inputs=[gr_name,gr_text,gr_radio,gr_tensors,gr_false], outputs=[gr_html,gr_name,gr_text], show_progress=False)
    return [(block,'EM','embedding_merge_extension')]
|
||
|
||
def tokens_to_text():
    """Build a converter from CLIP token ids back to readable text fragments.

    Adapted from https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer.

    Returns:
        A function ``_tokens_to_text(tokens)`` returning a list of
        ``(ids, text)`` pairs, where ``ids`` is the list of token ids that
        decode together to ``text`` (``'</w>'`` end-of-word markers become a
        space). Returns None when no text encoder is available or if anything
        fails while setting up; in that case the traceback is printed.
    """
    try:
        # https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer
        # Adapter for wrappers that carry their own tokenizer (.tokenizer.get_vocab()).
        class VanillaClip:
            def __init__(self, clip):
                self.clip = clip
            def vocab(self):
                return self.clip.tokenizer.get_vocab()
            def byte_decoder(self):
                return self.clip.tokenizer.byte_decoder
        # Adapter for OpenCLIP: uses open_clip's module-level tokenizer instance.
        class OpenClip:
            def __init__(self, clip):
                self.clip = clip
                self.tokenizer = open_clip.tokenizer._tokenizer
            def vocab(self):
                return self.tokenizer.encoder
            def byte_decoder(self):
                return self.tokenizer.byte_decoder
        clip = shared.sd_model.cond_stage_model
        # Multi-encoder containers (those with .embedders): only the first encoder's vocabulary is used.
        if hasattr(clip,'embedders'):
            clip = clip.embedders[0]
        if clip is None:
            return None
        clip = clip.wrapped
        typename = type(clip).__name__.split('.')[-1]
        if typename=='FrozenOpenCLIPEmbedder':
            clip = OpenClip(clip)
        else:
            clip = VanillaClip(clip)
        # Invert token-string -> id into id -> token-string.
        vocab = {v: k for k, v in clip.vocab().items()}
        byte_decoder = clip.byte_decoder()
        def _tokens_to_text(tokens):
            nonlocal vocab, byte_decoder
            code = []  # output: list of (ids, text) pairs
            ids = []  # NOTE(review): accumulated below but never returned or read
            current_ids = []  # ids not yet decodable into valid UTF-8
            class_index = 0  # NOTE(review): unused
            def dump(last=False):
                """Try to flush current_ids into ``code`` as one decoded fragment.

                On a UnicodeDecodeError (a character split across several
                tokens) the ids are kept pending, unless this is the final
                flush (emit '<ERR>' per id) or more than 4 ids are pending
                (emit the first id as '<ERR>' and re-feed the rest one by one).
                """
                nonlocal code, ids, current_ids
                words = [vocab.get(x, '') for x in current_ids]
                try:
                    word = bytearray([byte_decoder[x] for x in ''.join(words)]).decode('utf-8')
                except UnicodeDecodeError:
                    if last:
                        word = '<ERR>' * len(current_ids)
                    elif len(current_ids) > 4:
                        id = current_ids[0]
                        ids += [id]
                        local_ids = current_ids[1:]
                        code += [([id], '<ERR>')]
                        current_ids = []
                        # Re-feed the remaining ids; the recursive dump() calls share current_ids.
                        for id in local_ids:
                            current_ids.append(id)
                            dump()
                        return
                    else:
                        # Incomplete multi-byte sequence: wait for the next token.
                        return
                word = word.replace('</w>', ' ')
                code += [(current_ids, word)]
                ids += current_ids
                current_ids = []
            for token in tokens:
                token = int(token)
                current_ids.append(token)
                dump()
            dump(last=True)
            # Drop the empty entry the final flush produces when nothing was pending.
            return [c for c in code if len(c[0])!=0]
        return _tokens_to_text
    except:
        traceback.print_exc()
        return None
|
||
|
||
def str_to_escape(line):
    """Return *line* with every ``(``, ``)``, ``[``, ``]`` and ``\\`` backslash-escaped.

    Used before WebUI's tokenize_line() so that brackets in the text are not
    interpreted as prompt syntax.
    """
    escape_table = {ord(char): '\\' + char for char in '()[]\\'}
    return line.translate(escape_table)
|
||
|
||
def get_model_clips():
    """Return the text-encoder wrapper(s) of the currently loaded model.

    Looks at ``shared.sd_model.cond_stage_model`` first. When that is None, it
    falls back to the ``text_processing_engine`` attribute, and then to the
    ``text_processing_engine_l`` / ``text_processing_engine_g`` pair.

    Returns:
        A 1-tuple ``(clip,)`` for single-encoder models (SD1/SD2), or a
        2-tuple ``(clip_l, clip_g)`` when two encoders are exposed (SDXL).

    Raises:
        Exception_From_EmbeddingMergeExtension: if no text encoder is found.
    """
    sd_model = shared.sd_model
    clip = sd_model.cond_stage_model
    if clip is None:
        clip = getattr(sd_model, 'text_processing_engine', None)
        if clip is None:
            clip_l = getattr(sd_model, 'text_processing_engine_l', None)
            clip_g = getattr(sd_model, 'text_processing_engine_g', None)
            if clip_l is not None:
                if clip_g is not None:
                    return (clip_l, clip_g)
                return (clip_l,)
            raise Exception_From_EmbeddingMergeExtension('Could not find CLIP model!')
    if hasattr(clip, 'embedders'):
        try:
            return (clip.embedders[0], clip.embedders[1])  # SDXL
        except (IndexError, KeyError, TypeError):
            # Fewer than two (or non-indexable) embedders: treat the container
            # itself as a single encoder. Only indexing failures are caught, so
            # unrelated errors (and KeyboardInterrupt) are no longer swallowed.
            pass
    return (clip,)  # SD1 or SD2
|
||
|
||
def get_embedding_db():
    """Return the embedding database(s) to work with.

    Prefers WebUI's global ``modules.sd_hijack.model_hijack.embedding_db``,
    returned as a 1-tuple. If that lookup fails for any reason or yields None,
    returns a list with the ``.embeddings`` attribute of every text encoder
    from get_model_clips().
    """
    try:
        database = modules.sd_hijack.model_hijack.embedding_db
    except:
        database = None
    if database is not None:
        return (database,)
    return [encoder.embeddings for encoder in get_model_clips()]
|
||
|
||
def tokenize_line(clip,text):
    """Run the wrapper's tokenize_line() on *text* without prompt-emphasis markup.

    Wrappers that provide ``encode_embedding_init_text`` receive the text
    backslash-escaped via str_to_escape(). Any other wrapper instead has
    ``clip.emphasis.name`` set to ``'None'`` for the duration of the call;
    the previous value is restored afterwards, even if tokenization raises.
    """
    if not hasattr(clip,'encode_embedding_init_text'):
        previous_mode = clip.emphasis.name
        clip.emphasis.name = 'None'
        try:
            result = clip.tokenize_line(text)
        finally:
            clip.emphasis.name = previous_mode
        return result
    escaped = str_to_escape(text)
    return clip.tokenize_line(escaped)
|
||
|
||
def encode_embedding_init_text(clip,text,length=999):
    """Encode *text* into raw token-embedding vectors (no transformer pass).

    If the wrapper has its own ``encode_embedding_init_text``, that method is
    called with ``(text, length)`` and its result is returned unchanged.
    Otherwise the text is tokenized with tokenize_line(). The ids of the first
    chunk (after its leading token, up to the reported count) are then looked
    up in ``clip.text_encoder.transformer.text_model.embeddings.token_embedding.wrapped``.

    Args:
        clip: a text-encoder wrapper from get_model_clips().
        text: the text to encode.
        length: maximum number of vectors to return. The fallback path
            previously ignored this; it now truncates to it.

    Returns:
        One embedding vector per token. This presumes an nn.Embedding-like
        lookup returning ``(tokens, dim)``; confirm against the backend.
    """
    if hasattr(clip,'encode_embedding_init_text'):
        return clip.encode_embedding_init_text(text,length)
    part = tokenize_line(clip,text)
    tokens = part[0][0].tokens
    embedding = clip.text_encoder.transformer.text_model.embeddings.token_embedding.wrapped
    # Create the ids on the same device as the embedding table (as tokens_to_vectors
    # does); a CPU tensor fails against a table that lives on the GPU.
    weight = getattr(embedding, 'weight', None)
    device = weight.device if weight is not None else None
    ids = torch.tensor(tokens[1:part[1]+1], device=device)
    return embedding(ids)[:length]
|
||
|
||
def text_to_vectors(orig_text):
    """Split a prompt fragment into per-encoder segments of vectors.

    For every encoder from get_model_clips(), the lower-cased text is
    tokenized. Each Textual Inversion embedding found by WebUI (``fixes``)
    becomes its own segment ``(embedding_tensor, name, None)``, and the
    literal text around it becomes segments ``(vectors, text, token_ids)``.

    Returns:
        A list (one entry per encoder) of lists of such tuples, with all
        tensors moved to ``devices.device`` as float32. Returns None when the
        text does not fit in the first chunk, when the embeddings cannot be
        re-aligned with the text, or on any exception (traceback printed).
    """
    try:
        both = []
        # lg selects the per-encoder tensor from dict-valued embeddings (see below).
        for clip,lg in zip(get_model_clips(),('clip_l','clip_g')):
            res = []
            text = orig_text.lstrip().lower()
            tokens = tokenize_line(clip,text)
            count = tokens[1]
            tokens = tokens[0][0]
            fixes = tokens.fixes
            # Reject text that does not fit into the first chunk.
            if count>=len(tokens.tokens):
                return None
            # Skip the leading token of the chunk; keep exactly `count` tokens.
            tokens = tokens.tokens[1:count+1]
            start = 0  # index in `tokens` up to which segments have been emitted
            for fix in fixes:
                name = fix.embedding.name.lower()
                tensor = fix.embedding.vec
                if type(tensor)==dict:
                    tensor = tensor[lg]
                num = fix.embedding.vectors
                off = fix.offset
                if num!=tensor.size(0):
                    return None
                lenname = len(name)
                if off!=start:
                    # There is literal text before this embedding: find the occurrence of
                    # the name whose prefix (up to and including it) tokenizes to exactly
                    # the expected tokens.
                    test = 0
                    while True:
                        pos = text.find(name,test)
                        if pos<0:
                            return None
                        test = pos+lenname
                        sub = text[0:test]
                        part = tokenize_line(clip,sub)
                        cnt = part[1]
                        part = part[0][0]
                        vec = off-start  # number of literal-text vectors before the embedding
                        need = tokens[start:off+num]
                        if part.tokens[1:cnt+1]==need:
                            # NOTE(review): encodes the whole remaining text and keeps the first
                            # `vec` vectors, assuming they equal the encoding of sub[:pos].
                            trans = encode_embedding_init_text(clip,text,vec)
                            t = trans[:vec].to(device=devices.device,dtype=torch.float32)
                            res.append((t,sub[:pos],need[:vec]))
                            text = text[pos:]
                            start = off
                            break
                # The embedding name must now be at the very start of the remaining text.
                if text[0:lenname]!=name:
                    return None
                tensor = tensor.to(device=devices.device,dtype=torch.float32)
                res.append((tensor,name,None))
                start += num
                text = text[lenname:].lstrip()
            # Trailing literal text after the last embedding (or the whole text if none).
            if text!='':
                part = tokenize_line(clip,text)
                cnt = part[1]
                part = part[0][0]
                need = tokens[start:]
                if part.tokens[1:cnt+1]!=need:
                    return None
                trans = encode_embedding_init_text(clip,text,999)
                trans = trans.to(device=devices.device,dtype=torch.float32)
                res.append((trans,text,need))
            both.append(res)
        return both
    except:
        traceback.print_exc()
        return None
|
||
|
||
def text_to_tokens(text, clips=None):
    """Tokenize *text* with every text encoder and return a single token sequence.

    Args:
        text: the text to tokenize via each encoder's ``tokenize([text])``.
        clips: the encoders to use; defaults to get_model_clips(). The
            parameter is optional, so existing callers are unaffected.

    Returns:
        The first encoder's tokens (same type ``tokenize`` produced). Returns
        None when tokenization fails or when two encoders disagree; a
        disagreement is also printed.
    """
    try:
        if clips is None:
            clips = get_model_clips()
        both = [clip.tokenize([text])[0] for clip in clips]
        if len(both)>1:
            # Compare as plain ints so this works whether tokenize() returns lists
            # or tensors. Subtracting two lists raised TypeError, which made every
            # two-encoder (SDXL) call return None.
            if [int(t) for t in both[0]] != [int(t) for t in both[1]]:
                print('EM: text_to_tokens',both)
                return None
        return both[0]
    except Exception:
        return None
|
||
|
||
def tokens_to_vectors(pair):
    """Look up raw token-embedding vectors for explicit token ids.

    Args:
        pair: one sequence of token ids per text encoder, in the order of
            get_model_clips() (zip stops at the shorter of the two).

    Returns:
        A list with one tensor per encoder, each with a leading batch axis of
        size 1 (ids are wrapped as ``[arr]``) and moved to ``devices.device``.
        Returns None if two encoders produce different token counts, or on
        any exception (traceback printed).
    """
    try:
        res = []
        for clip,arr in zip(get_model_clips(),pair):
            wrapped = clip.wrapped
            if hasattr(wrapped,'model') and hasattr(wrapped.model,'token_embedding'):
                # SD2-style wrapper: embedding table at .model.token_embedding
                ids = torch.tensor([arr],dtype=torch.int,device=devices.device)
                vectors = wrapped.model.token_embedding.wrapped(ids).to(devices.device)
            else:
                # SD1-style wrapper: embedding table at .transformer.text_model.embeddings
                token_embedding = wrapped.transformer.text_model.embeddings.token_embedding
                ids = torch.tensor([arr],dtype=torch.int,device=token_embedding.wrapped.weight.device)
                vectors = token_embedding.wrapped(ids).to(devices.device)
            res.append(vectors)
        if len(res)>1:
            # Compare every axis except the embedding width (encoders may differ there).
            # len() only looked at the batch axis, which is always 1, so a
            # token-count mismatch was never detected.
            if res[0].shape[:-1] != res[1].shape[:-1]:
                print('EM: tokens_to_vectors',res)
                return None
        return res
    except Exception:
        traceback.print_exc()
        return None
|
||
|
||
def to_float(num):
    """Convert *num* to float.

    Returns None when *num* is None or cannot be converted (e.g. ``''``,
    ``'.'``, ``'abc'``, or an int too large for a float). Only conversion
    errors are caught; a bare ``except:`` would also swallow KeyboardInterrupt
    and SystemExit.
    """
    if num is None:
        return None
    try:
        return float(num)
    except (TypeError, ValueError, OverflowError):
        return None
|
||
|
||
def to_int(num):
    """Convert *num* to int.

    Returns None when *num* is None or cannot be converted (e.g. ``''``,
    ``'1.5'``, ``'abc'``, or an infinite float). Only conversion errors are
    caught; a bare ``except:`` would also swallow KeyboardInterrupt and
    SystemExit.
    """
    if num is None:
        return None
    try:
        return int(num)
    except (TypeError, ValueError, OverflowError):
        return None
|
||
|
||
def grab_vectors(text):
    """Return the vectors of *text* for each encoder, concatenated into one tensor each.

    Uses text_to_vectors() and concatenates the tensors of all its segments
    per encoder. If an encoder yields no segments (e.g. empty text), an empty
    tensor is returned for it instead. That tensor is sliced from the encoding
    of ``','`` so that it keeps the correct vector width.

    Returns:
        A list with one tensor per encoder. Returns None if the text cannot
        be converted or the encoders disagree on the vector count; an
        unexpected exception is printed and also yields None.
    """
    try:
        parts = text_to_vectors(text)
        # Handle failure explicitly instead of relying on iterating None raising TypeError.
        if parts is None:
            return None
        both = []
        for res in parts:
            if len(res)==0:
                sample = text_to_vectors(',')
                if sample is None:
                    return None
                res = sample[len(both)][0][0][0:0]
            else:
                res = torch.cat([ten[0] for ten in res])
            both.append(res)
        if len(both)>1:
            if len(both[0]) != len(both[1]):
                print('EM: grab_vectors',both)
                return None
        return both
    except Exception:
        traceback.print_exc()
        return None
|
||
|
||
# Any run of whitespace. merge_parser() uses it to remove all spaces from the
# operator text between quoted literals before matching it with reg_oper.
reg_clean = re.compile(r'\s+')

# One operator of a merge expression; merge_parser() matches it repeatedly
# from the start of the remaining operator text.
#   group 1: optional '=' prefix (the operator then applies to the accumulated result)
#   group 2: '*', '/' or ',' (multiply / divide / concatenate a token by index)
#   group 3: its argument: an optionally signed decimal number (may be empty),
#            optionally followed by 'L' or 'G' (SDXL-only encoder selector)
#   group 4: after ':', an optional '+' or '-' (rotate instead of resize)
#   group 5: the integer size, or the rotation amount
reg_oper = re.compile(r'(=?)(?:([*/,])([+-]?[0-9]*(?:\.[0-9]*)?(?:L|G)?)|:([+-]?)(-?[0-9]+))')

# NOTE(review): presumably the vector widths of SDXL's two text encoders
# (CLIP-L 768, OpenCLIP-G 1280), keyed like the L/G suffix above; the code
# that reads this table is not visible here.
sdxl_sizes = {
    'L': 768,
    'G': 1280,
}
|
||
def merge_parser(text,only_count):
|
||
clips = get_model_clips()
|
||
vocab = None
|
||
def check_vocab(token2):
|
||
nonlocal vocab
|
||
if vocab is None:
|
||
vocab = []
|
||
for clip in clips:
|
||
wrapped = clip.wrapped
|
||
typename = type(wrapped).__name__.split('.')[-1]
|
||
if typename=='FrozenCLIPEmbedder':
|
||
voc = wrapped.tokenizer.get_vocab()
|
||
elif typename=='FrozenOpenCLIPEmbedder':
|
||
voc = open_clip.tokenizer._tokenizer.encoder
|
||
else:
|
||
return True
|
||
vocab.append({v: k for k, v in voc.items()})
|
||
t = token2[0]
|
||
if len(vocab)>1:
|
||
if len(token2)>1:
|
||
return (t in vocab[0]) and (token2[1] in vocab[1])
|
||
return (t in vocab[0]) and (t in vocab[1])
|
||
return t in vocab[0]
|
||
orig = '"'+text+'"'
|
||
text = text.replace('\0',' ')+' '
|
||
length = len(text)
|
||
arr = []
|
||
left = 0
|
||
quot = False
|
||
join = False
|
||
while left<length:
|
||
pos = text.find("'",left)
|
||
if pos<0:
|
||
pos = length
|
||
take = text[left:pos]
|
||
if left>0:
|
||
if take=='' and not quot:
|
||
join = True
|
||
elif quot:
|
||
if join:
|
||
arr[-1] = (arr[-1][0]+"'"+take,True)
|
||
join = False
|
||
else:
|
||
arr.append((take,True))
|
||
else:
|
||
arr.append((take.strip(),False))
|
||
quot = not quot
|
||
left = pos+1
|
||
if not quot:
|
||
return (None,'Last quote not closed in '+orig)
|
||
if len(arr)>0 and arr[-1][0]=='':
|
||
arr.pop()
|
||
actions = []
|
||
combine = False
|
||
for param, quot in arr:
|
||
one = param
|
||
if quot:
|
||
if combine:
|
||
actions[-1]['V'] = param
|
||
combine = False
|
||
else:
|
||
actions.append({
|
||
'A': None,
|
||
'V': param,
|
||
'O': one,
|
||
})
|
||
continue
|
||
elif combine:
|
||
return (None,'Wrong concatenation "'+param+'" in '+orig)
|
||
param = reg_clean.sub('',param)
|
||
while param!='':
|
||
m = reg_oper.match(param)
|
||
if not m:
|
||
if param=='+' or param=='-':
|
||
actions.append({
|
||
'A': False,
|
||
'V': param=='+',
|
||
'O': one,
|
||
})
|
||
break
|
||
return (None,'Wrong expression "'+param+'" in '+orig)
|
||
m_flag = m.group(1)=='='
|
||
m_mul = m.group(2)
|
||
m_val = m.group(3)
|
||
m_shift = m.group(4)
|
||
m_size = m.group(5)
|
||
m_tok = -1
|
||
m_clip = None
|
||
if m_val is not None:
|
||
if len(m_val)>0:
|
||
m_clip = m_val[-1]
|
||
if (m_clip=='L') or (m_clip=='G'):
|
||
m_val = m_val[:-1]
|
||
if len(clips)<2:
|
||
return (None,'Suffix L or G can be used with SDXL models only: "'+param+'" in '+orig)
|
||
else:
|
||
m_clip = None
|
||
if m_mul==',':
|
||
if m_flag:
|
||
return (None,'Concatenation doesn\'t support \'=\' prefix: "'+param+'" in '+orig)
|
||
if m_clip is not None:
|
||
return (None,'Concatenation doesn\'t support L or G suffix: "'+param+'" in '+orig)
|
||
if (len(m_val)>0) and (m_val[0]=='0'):
|
||
if m_val=='0':
|
||
m_tok = 0
|
||
elif m_val=='00':
|
||
m_tok = -2
|
||
elif m_val=='000':
|
||
m_tok = -3
|
||
elif m_val=='0000':
|
||
m_tok = -4
|
||
else:
|
||
m_tok = None
|
||
elif m_val=='':
|
||
m_tok = -5
|
||
combine = True
|
||
m_val = None
|
||
else:
|
||
m_tok = to_int(m_val)
|
||
if (m_tok is not None) and not (m_tok>=0):
|
||
m_tok = None
|
||
if m_tok is None:
|
||
return (None,'Bad param for concatenation "'+param+'" in '+orig)
|
||
else:
|
||
m_val = to_float(m_val)
|
||
if m_val is None:
|
||
return (None,'Bad param for multiplication "'+param+'" in '+orig)
|
||
m_mul = m_mul=='*'
|
||
m_size = -1
|
||
m_shift = 0
|
||
else:
|
||
m_size = int(m_size)
|
||
if m_shift=='+':
|
||
m_shift = m_size
|
||
m_size = -1
|
||
elif m_shift=='-':
|
||
m_shift = -m_size
|
||
m_size = -1
|
||
else:
|
||
m_shift = 0
|
||
m_val = 1
|
||
m_mul = None
|
||
actions.append({
|
||
'A': True,
|
||
'V': m_val,
|
||
'W': m_mul,
|
||
'S': m_size,
|
||
'R': m_shift,
|
||
'F': m_flag,
|
||
'T': m_tok,
|
||
'C': m_clip,
|
||
'O': one,
|
||
})
|
||
param = param[len(m.group(0)):]
|
||
if combine:
|
||
return (None,'Unfinished concatenation in '+orig)
|
||
actions.append({
|
||
'A': None,
|
||
'V': None,
|
||
})
|
||
can_file = True
|
||
can_add = False
|
||
can_mul = False
|
||
for act in actions:
|
||
act['M'] = False
|
||
A = act['A']
|
||
if A==None:
|
||
if act['V']==None:
|
||
if can_file:
|
||
return (None,'Need quoted string after last + or - in '+orig)
|
||
act['M'] = True
|
||
break
|
||
if can_file:
|
||
can_add = True
|
||
can_mul = True
|
||
can_file = False
|
||
else:
|
||
return (None,'Quoted string without preceding + or - at \''+act['O']+'\' in '+orig)
|
||
elif A==True:
|
||
if can_mul:
|
||
can_file = False
|
||
can_add = True
|
||
can_mul = True
|
||
if act['F']:
|
||
act['M'] = True
|
||
else:
|
||
return (None,'Cannot multiply or modify at "'+act['O']+'" in '+orig)
|
||
else:
|
||
if can_add:
|
||
can_file = True
|
||
can_mul = False
|
||
can_add = False
|
||
act['M'] = True
|
||
else:
|
||
return (None,'Cannot merge at "'+act['O']+'" in '+orig)
|
||
left = None
|
||
right = None
|
||
add = 0
|
||
for act in actions:
|
||
if act['M'] and (left is not None):
|
||
if add!=0:
|
||
if only_count:
|
||
if left>right:
|
||
right = left
|
||
else:
|
||
(vectors1_0,length1_0) = left[0].size()
|
||
(vectors2_0,length2_0) = right[0].size()
|
||
(vectors1_1,length1_1) = left[1].size() if len(left)>1 else (vectors1_0,length1_0)
|
||
(vectors2_1,length2_1) = right[1].size() if len(right)>1 else (vectors2_0,length2_0)
|
||
if (length1_0!=length2_0) or (length1_1!=length2_1) or (vectors1_0!=vectors1_1) or (vectors2_0!=vectors2_1) or (len(left)!=len(right)):
|
||
return (None,'Cannot merge different embeddings in '+orig)
|
||
if vectors1_0!=vectors2_0:
|
||
if vectors1_0<vectors2_0:
|
||
target = [torch.zeros(vectors2_0,length1_0).to(device=devices.device,dtype=torch.float32)]
|
||
target[0][0:vectors1_0] = left[0]
|
||
if len(left)>1:
|
||
target.append(torch.zeros(vectors2_1,length1_1).to(device=devices.device,dtype=torch.float32))
|
||
target[1][0:vectors1_1] = left[1]
|
||
left = target
|
||
else:
|
||
target = [torch.zeros(vectors1_0,length2_0).to(device=devices.device,dtype=torch.float32)]
|
||
target[0][0:vectors2_0] = right[0]
|
||
if len(right)>1:
|
||
target.append(torch.zeros(vectors1_1,length2_1).to(device=devices.device,dtype=torch.float32))
|
||
target[1][0:vectors2_1] = right[1]
|
||
right = target
|
||
if add>0:
|
||
right[0] = left[0]+right[0]
|
||
if len(left)>1 and len(right)>1:
|
||
right[1] = left[1]+right[1]
|
||
else:
|
||
right[0] = left[0]-right[0]
|
||
if len(left)>1 and len(right)>1:
|
||
right[1] = left[1]-right[1]
|
||
left = None
|
||
A = act['A']
|
||
if A==None:
|
||
line = act['V']
|
||
if line==None:
|
||
return (right,None)
|
||
right = grab_vectors(line)
|
||
if right==None:
|
||
return (None,'Failed to parse \''+line+'\' in '+orig)
|
||
if only_count:
|
||
right = right[0].size(0)
|
||
elif A==False:
|
||
if act['V']:
|
||
add = 1
|
||
else:
|
||
add = -1
|
||
left = right
|
||
right = None
|
||
else:
|
||
s = act['S']
|
||
r = act['R']
|
||
t = act['T']
|
||
if only_count:
|
||
if t!=-1:
|
||
right += 1
|
||
elif (r==0)and(s>=0):
|
||
right = s
|
||
else:
|
||
if t!=-1:
|
||
if t<0:
|
||
if t==-2:
|
||
t = [clip.id_pad for clip in clips]
|
||
elif t==-3:
|
||
t = [clip.id_end for clip in clips]
|
||
elif t==-4:
|
||
t = [clip.id_start for clip in clips]
|
||
else:
|
||
res = grab_vectors(act['V'])
|
||
t = None
|
||
if res is None:
|
||
return (None,'Failed to parse \''+act['V']+'\' in '+orig)
|
||
else:
|
||
if len(clips)>1:
|
||
t = [t,t]
|
||
else:
|
||
t = [t]
|
||
if t is not None:
|
||
if not check_vocab(t):
|
||
return (None,'Unknown token value \''+str(t[0])+'\' in '+orig)
|
||
res = tokens_to_vectors(t)
|
||
if res is None:
|
||
return (None,'Failed to convert token \''+str(t)+'\' in '+orig)
|
||
if right is None:
|
||
right = res
|
||
else:
|
||
if len(right)>1 and len(res)>1:
|
||
right = [torch.cat([right[0],res[0]]),torch.cat([right[1],res[1]])]
|
||
else:
|
||
right = [torch.cat([right[0],res[0]])]
|
||
elif r!=0:
|
||
right[0] = right[0].roll(r,dims=0)
|
||
if len(right)>1:
|
||
right[1] = right[1].roll(r,dims=0)
|
||
else:
|
||
if s>=0:
|
||
(vectors,length) = right[0].size()
|
||
if vectors>s:
|
||
if len(right)>1:
|
||
right = [right[0][0:s],right[1][0:s]]
|
||
else:
|
||
right[0] = right[0][0:s]
|
||
elif vectors<s:
|
||
target = [torch.zeros(s,length).to(device=devices.device,dtype=torch.float32)]
|
||
target[0][0:vectors] = right[0]
|
||
if len(right)>1:
|
||
(vectors,length) = right[1].size()
|
||
target.append(torch.zeros(s,length).to(device=devices.device,dtype=torch.float32))
|
||
target[1][0:vectors] = right[1]
|
||
right = target
|
||
elif act['W']==True:
|
||
if act['C']==None:
|
||
right = [r*act['V'] for r in right]
|
||
else:
|
||
s = sdxl_sizes[act['C']]
|
||
right = [(r*act['V'] if r.shape[-1]==s else r) for r in right]
|
||
elif act['W']==False:
|
||
if act['C']==None:
|
||
right = [r/act['V'] for r in right]
|
||
else:
|
||
s = sdxl_sizes[act['C']]
|
||
right = [(r/act['V'] if r.shape[-1]==s else r) for r in right]
|
||
return (right,None)
|
||
|
||
def grab_embedding_cache():
    """Return the dict tracking ephemeral EM embeddings, creating it on first use.

    The dict lives as the '__embedding_merge_cache_' attribute of the first database
    returned by get_embedding_db(). Besides name -> Embedding entries, it holds the
    counters '_', '-' and '/' (all starting at 0).
    """
    database = get_embedding_db()[0]
    attribute = '__embedding_merge_cache_'
    if not hasattr(database,attribute):
        setattr(database,attribute,{'_':0,'-':0,'/':0})
    return getattr(database,attribute)
|
||
|
||
def register_embedding(name,embedding):
    """Register *embedding* under *name*, or unregister *name* when *embedding* is None.

    Prefers the upstream ``register_embedding_by_name`` API. Older builds take
    (embedding, name) without the model, hence the TypeError retry. Databases without
    that API get the ``word_embeddings``/``ids_lookup`` bookkeeping replicated from
    modules/textual_inversion/textual_inversion.py.

    NOTE(review): every path returns during the first loop iteration, so only the first
    database yielded by get_embedding_db() is updated. Confirm this is intended.

    Returns the registered embedding, or None when nothing was registered.
    """
    for self in get_embedding_db():
        model = shared.sd_model
        if hasattr(self,'register_embedding_by_name'):
            try:
                return self.register_embedding_by_name(embedding,model,name)
            except TypeError:
                return self.register_embedding_by_name(embedding,name)
        # /modules/textual_inversion/textual_inversion.py
        try:
            ids = model.cond_stage_model.tokenize([name])[0]
            first_id = ids[0]
        except Exception:
            # The name cannot be tokenized right now, so there is nothing to key it by.
            return
        if embedding is None:
            # .get(): unregistering a name that was never registered must not raise KeyError.
            if self.word_embeddings.get(name) is None:
                return
            del self.word_embeddings[name]
        else:
            self.word_embeddings[name] = embedding
        if first_id not in self.ids_lookup:
            if embedding is None:
                return
            self.ids_lookup[first_id] = []
        save = [(ids, embedding)] if embedding is not None else []
        old = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
        # Longest token sequences first, matching the upstream lookup order.
        self.ids_lookup[first_id] = sorted(old + save, key=lambda x: len(x[0]), reverse=True)
        return embedding
|
||
|
||
def make_temp_embedding(name,vectors,cache,fake):
    """Create or refresh the ephemeral embedding *name* and register it.

    vectors -- per-encoder tensors, [clip_l] or [clip_l, clip_g]; ignored when *fake* > 0.
    cache   -- name -> Embedding cache; a cached object is updated in place.
    fake    -- when > 0, use zero-filled (fake, 16) placeholders, one per text encoder.
               An already cached placeholder is left untouched.
    """
    cached = name in cache
    if cached and fake>0:
        return
    embed = cache[name] if cached else None
    if not cached and fake>0:
        encoders = 2 if len(get_model_clips())>1 else 1
        vectors = [torch.zeros((fake,16)) for _ in range(encoders)]
    # Shape comes from the last tensor (clip_g when two encoders are present).
    shape = vectors[-1].size()
    if len(vectors)>1:
        payload = {'clip_g':vectors[1],'clip_l':vectors[0]}
    else:
        payload = vectors[0]
    if embed is None:
        embed = Embedding(payload,name)
        cache[name] = embed
    embed.vec = payload
    embed.step = None
    embed.vectors = shape[0]
    embed.shape = shape[-1]
    embed.cached_checksum = None
    embed.filename = ''
    register_embedding(name,embed)
|
||
|
||
def reset_temp_embeddings(prod,unregister):
    """Forget every ephemeral embedding of counter group *prod* ('_', '-' or '/').

    The group's counter is reset to 0. Then, for N from the old counter down to 1, each
    cached <'EM<prod>N'> and {'EM<prod>N'} embedding has its vectors replaced by empty
    (0-row) tensors of the same width and device, and is removed from the cache. With
    *unregister*, it is also unregistered from the embedding database.

    Returns the shared cache dict.
    """
    cache = grab_embedding_cache()
    highest = cache[prod]
    cache[prod] = 0
    for opener, closer in (('<','>'),('{','}')):
        for number in range(highest,0,-1):
            key = opener+"'EM"+prod+str(number)+"'"+closer
            if key not in cache:
                continue
            embed = cache[key]
            if type(embed.vec) is dict:
                for clip_name, tensor in list(embed.vec.items()):
                    embed.vec[clip_name] = torch.zeros((0,tensor.shape[-1]),device=tensor.device)
            else:
                embed.vec = torch.zeros((0,embed.vec.shape[-1]),device=embed.vec.device)
            embed.vectors = 0
            embed.cached_checksum = None
            del cache[key]
            if unregister:
                register_embedding(key,None)
    return cache
|
||
|
||
def add_temp_embedding(vectors,cache,prod,curly,fake):
    """Create an ephemeral embedding and return the name to put into the prompt.

    vectors -- per-encoder tensors of a merge result (unused when *fake* > 0).
    cache   -- dict from grab_embedding_cache().
    prod    -- truthy names it "'EM_N'", falsy "'EM-N'" (N is that group's next counter).
    curly   -- wrap the name in {...} instead of <...>.
    fake    -- when > 0, only a zero-filled placeholder with *fake* vectors is needed
               (token counting). It is named "'EM/<fake>'" and shared between calls.
    """
    if fake>0:
        prod = '/'
        # The '/' counter only records the largest placeholder ever made, so that
        # reset_temp_embeddings('/') visits every '/'-name. The name itself must encode
        # *fake*: make_temp_embedding() reuses a cached placeholder as-is. Using the
        # running maximum would make smaller expressions count as the larger one.
        if fake>(cache[prod] or 0):
            cache[prod] = fake
        num = fake
    else:
        prod = '_' if prod else '-'
        num = 1+(cache[prod] or 0)
        cache[prod] = num
    name = "'EM"+prod+str(num)+"'"
    if curly:
        name = '{'+name+'}'
    else:
        name = '<'+name+'>'
    make_temp_embedding(name,vectors,cache,fake)
    return name
|
||
|
||
def parse_infotext(text):
    """Split an infotext into its "EmbeddingMerge" value and the remaining text.

    Looks for  EmbeddingMerge: <'EM_...  or  {'EM_...  optionally wrapped in a JSON
    string. A quoted value is JSON-decoded. An unquoted value runs up to the first
    '>,' / '>\\n' (or '}' equivalents) and keeps its closing bracket.

    Returns (value, text_without_the_parameter), or (None, text) when the parameter is
    absent or malformed. The returned remainder ends with the newline appended here.
    """
    original = text
    text = text+'\n'
    found = re.search(r"\bEmbeddingMerge:\s*(\"?[<{])'EM_",text)
    if found is None:
        return (None,original)
    head = text[:found.start(0)].rstrip()
    if head.endswith(','):
        head = head[:-1]
    text = text[found.start(1):]
    if len(text)<2:
        return (None,original)
    closers = {'<':'>','{':'}'}
    opener = text[0]
    if opener!='"':
        if opener not in closers:
            return (None,original)
        closer = closers[opener]
        stop = min_or_all(text.find(closer+','),text.find(closer+'\n'),-1)
        if stop<0:
            return (None,original)
        stop += 1
        return (text[:stop],head+text[stop:])
    stop = (text+'\n').find('\n')
    part = text[:stop]
    left = 0
    while True:
        # Grow the candidate up to the next '"' until it decodes as a JSON string.
        right = part.find('"',left+1)
        if right<0:
            return (None,original)
        try:
            line = json.loads('['+part[:right+1].strip()+']')[0]
        except Exception:
            left = right
            continue
        return (line,head+part[right+1:]+text[stop:])
|
||
|
||
def parse_mergeseq(seq):
    """Parse an "EmbeddingMerge" value into {ephemeral name: original expression}.

    The value has the form  <'EM_1'>=<'a'+'b'>, {'EM_2'}={'c'}, ...  (as written by
    embedding_merge_extension). Parsing stops at the first malformed entry. Whatever was
    parsed so far is returned, or None if nothing was.
    """
    name_ends = {"<'EM_":"'>=","{'EM_":"'}="}
    value_ends = {"<'":'>, ',"{'":'}, '}
    found = None
    seq = seq.lstrip()
    while True:
        name_end = name_ends.get(seq[:5])
        if name_end is None:
            return found
        stop = seq.find(name_end)
        if stop<1:
            return found
        key = seq[:stop+2]
        seq = seq[stop+3:]
        value_end = value_ends.get(seq[:2])
        if value_end is None:
            return found
        # The value ends right before the next "<'" / "{'" entry, or at the end of seq.
        stop = min_or_all(seq.find(value_end+"<'"),seq.find(value_end+"{'"),len(seq))
        value = seq[:stop+1]
        seq = seq[stop+3:]
        if found is None:
            found = {}
        found[key] = value
|
||
|
||
def min_or_all(a,b,n):
    """Return the smaller of two str.find()-style indices, ignoring negative ones.

    When both *a* and *b* are negative ("not found"), return the fallback *n*.
    """
    found = [index for index in (a,b) if index>=0]
    return min(found) if found else n
|
||
|
||
def dict_replace(di,text):
    """Apply every old -> new substring replacement from *di* to *text*.

    Replacements run in the dict's iteration order, each on the result of the previous.
    """
    for old, new in di.items():
        text = text.replace(old,new)
    return text
|
||
|
||
# Serializes UI requests: each one may reload the model and resets the shared '-'
# group of ephemeral embeddings.
gr_lock = threading.Lock()


def gr_func(gr_name,gr_text,gr_radio,gr_tensors,store):
    """Back-end of the Embedding Merge tab: inspect, merge and optionally save embeddings.

    gr_name    -- embedding name to save under (used only when *store* is truthy).
    gr_text    -- user input; handled as the first matching case:
                  * an infotext with an "EmbeddingMerge:" parameter -> the prompt with its
                    merge expressions restored is put back into the text field;
                  * text starting with a single quote -> one merge expression, parsed by
                    merge_parser() and summarized per vector;
                  * anything else -> <'...'>/{'...'} expressions are expanded, then the
                    text is tokenized and displayed according to *gr_radio*.
    gr_radio   -- view mode: 0 none, 1 by comma, 2 by parts, 3 by words, 4 by tokens,
                  5 by vectors.
    gr_tensors -- forwarded to need_save_embed().
    store      -- forwarded to need_save_embed(); when falsy nothing is written.

    Returns (html, name_field_value, text_field_value).
    """
    with gr_lock:
        # Best-effort model refresh before tokenizing. Either call may be missing or fail
        # depending on the WebUI flavour (forge_model_reload presumably exists only on
        # Forge); that is deliberately ignored.
        try:
            sd_models.reload_model_weights()
        except Exception:
            pass
        try:
            sd_models.forge_model_reload()
        except Exception:
            pass
        gr_orig = gr_text
        font = 'font-family:Consolas,Courier New,Courier,monospace;'
        table = '<style>.webui_embedding_merge_table,.webui_embedding_merge_table td,.webui_embedding_merge_table th{border:1px solid gray;border-collapse:collapse}.webui_embedding_merge_table td,.webui_embedding_merge_table th{padding:2px 5px !important;text-align:center !important;vertical-align:middle;'+font+'font-weight:bold;}.webui_embedding_merge_table{margin:6px auto !important;}</style>'
        table_open = '<table class="webui_embedding_merge_table">'
        clip_separator = '<strong>↑ CLIP (L) / OpenClip (G) ↓</strong>'

        # Case 1: a pasted infotext -> restore the original prompt.
        (reparse,request) = parse_infotext(gr_text)
        if reparse is not None:
            reparse = parse_mergeseq(reparse)
            if reparse is None:
                return ('<center><b>Prompt restore failed!</b></center>',gr_name,gr_orig)
            request = dict_replace(reparse,request)
            return ('<center><b>Prompt restored.</b></center>',gr_name,request)

        # Case 2: a single merge expression -> per-vector statistics.
        if gr_text[:1]=="'":
            (two,err) = merge_parser(gr_text,False)
            if (two is not None) and two[0].numel()==0:
                err = 'Result is ZERO vectors!'
            if err is not None:
                txt = '<b style="'+font+'">'+html.escape(err)+'</b>'
            else:
                txt = table
                both = False
                for res in two:
                    if res is None:
                        continue
                    if both:
                        txt += clip_separator
                    txt += table_open+'<tr><th>Index</th><th>Min</th><th>Max</th><th>Sum</th><th>Abs</th><th>Len</th><th>Std</th></tr>'
                    for i, one in enumerate(res,start=1):
                        txt += '<tr><td>{}</td>{}</tr>'.format(i,tensor_info(one))
                    txt += '<tr><td colspan="7"> </td></tr>'
                    txt += '<tr><td>ALL:</td>{}</tr>'.format(tensor_info(res))
                    txt += '</table>'
                    both = True
            return ('<center>'+txt+'</center>',need_save_embed(store,gr_name,two,gr_tensors),gr_orig)

        # Case 3: a prompt -> expand merge expressions first, then tokenize.
        if gr_text.find("<'")>=0 or gr_text.find("{'")>=0:
            cache = reset_temp_embeddings('-',False)
            used = {}
            (mer,err) = merge_one_prompt(cache,None,{},used,gr_text,False,False)
            if err is not None:
                txt = '<b style="'+font+'">Embedding Merge failed - '+html.escape(err)+'</b>'
                return ('<center>'+txt+'</center>',gr_name,gr_orig)
            gr_text = mer
        by_none = 0
        by_comma = 1
        by_parts = 2
        by_words = 3
        by_tokens = 4
        by_vectors = 5
        tok2txt = tokens_to_text()
        if gr_radio!=by_comma:
            two = text_to_vectors(gr_text)
            if (gr_radio==by_none) and (two is not None) and (len(two[0])!=0):
                # one group per encoder, so the "none" view merges everything into one row
                two = [[r] for r in two]
        else:
            # one group per comma-separated part, per encoder
            two = [[],[]]
            split = gr_text.split(',')
            for part in split:
                one = text_to_vectors(part.strip())
                if one:
                    two[0].append(one[0])
                    if len(one)>1:
                        two[1].append(one[1])
                    else:
                        two[1] = None
                else:
                    two = None
                    break
        if (two is None) or (len(two[0])==0):
            if gr_text.strip()=='':
                return ('',gr_name,gr_orig)
            txt = '<b>Failed to parse! (Possibly there are more than 75 tokens; or extra spaces inside embed names). Embeddings are not shown now:</b><br/><br/>'
            tokens = text_to_tokens(gr_text)
            if tokens:
                txt += table+table_open+'<tr><th>Index</th><th>Vectors</th><th>Text</th><th>Token</th></tr>'
                if tok2txt:
                    pairs = tok2txt(tokens)
                else:
                    pairs = [([tok],'<ERROR>') for tok in tokens]
                index = 1
                for arr, text in pairs:
                    length = len(arr)
                    if length==0:
                        continue
                    txt += '<tr><td>'+(str(index) if length==1 else str(index)+'-'+str(index+length-1))+'</td><td>'+str(length)+'</td><td>'+html.escape('"'+text+'"')+'</td><td>'+(', '.join([str(a) for a in arr]))+'</td></tr>'
                    index += length
                txt += '</table>'
            return ('<center>'+txt+'</center>',gr_name,gr_orig)

        both = []
        for res in two:
            if res is None:
                continue
            # Per-encoder copy of the view mode: the branches below rewrite it. With two
            # encoders (SDXL), the second pass must start from the user's choice again.
            mode = gr_radio
            txt = table_open+'<tr><th>Index</th><th>Vectors</th><th>Text</th><th>Token</th><th>Min</th><th>Max</th><th>Sum</th><th>Abs</th><th>Len</th><th>Std</th></tr>'
            index = 1
            join = False
            if mode==by_words:
                join = True
                mode = by_tokens
            elif (mode==by_none) or (mode==by_comma):
                # collapse each group into a single (tensor, text, tokens) row
                r_res = []
                for one in res:
                    r_tensor = []
                    r_name = ''
                    r_tokens = []
                    for tensor, name, tokens in one:
                        r_tensor.append(tensor)
                        if tok2txt and tokens and mode==by_none:
                            split = tok2txt(tokens)
                            name = ''
                            tokens = []
                            for s_tokens, s_name in split:
                                name += s_name
                                tokens += s_tokens
                        r_name += name
                        if tokens:
                            r_tokens += tokens
                        else:
                            r_tokens += ['*_'+str(tensor.size(0))]
                        if mode==by_none:
                            r_name += ' '
                    r_res.append((torch.cat(r_tensor),r_name,r_tokens))
                res = r_res
                mode = by_parts
            for tensor, name, tokens in res:
                split = None
                size = tensor.size(0)
                span = '' if mode==by_parts else ' rowspan="'+str(size)+'"'
                if tokens and tok2txt:
                    split = tok2txt(tokens)
                    if join:
                        # glue sub-word pieces: start a new group only after a piece ending with a space
                        comb = []
                        for s_arr, s_text in split:
                            if comb and comb[-1][1][-1:]!=' ':
                                comb[-1] = (comb[-1][0]+s_arr,comb[-1][1]+s_text)
                            else:
                                comb.append((s_arr,s_text))
                        split = comb
                if mode==by_tokens:
                    span = ' rowspan="'+str(len(split))+'"' if split is not None else ''
                if mode==by_vectors:
                    head = '<td'+span+'>'+str(size)+'</td>'
                else:
                    head = '<td'+span+'>'+(str(index) if size==1 else str(index)+'-'+str(index+size-1))+'</td><td'+span+'>'+str(size)+'</td>'
                if split is None:
                    head += '<td'+span+'>'+html.escape('"'+name+'"')+'</td>'
                if (mode==by_vectors) or ((mode==by_tokens) and (tokens is not None)):
                    i = 0
                    part = 0
                    j = 0
                    ten = None
                    column = ''
                    toks = None
                    for one in list(tensor):
                        index += 1
                        i += 1
                        use = one
                        if split is not None:
                            # 'part' counts the vectors still belonging to the current text piece
                            if part==0:
                                pair = split[j]
                                part = len(pair[0])
                                if mode==by_tokens:
                                    column = '<td>'+html.escape('"'+pair[1]+'"')+'</td>'
                                    toks = ', '.join([str(t) for t in pair[0]])
                                else:
                                    column = '<td rowspan="'+str(part)+'">'+html.escape('"'+pair[1]+'"')+'</td>'
                                j += 1
                            part -= 1
                        if mode==by_tokens:
                            # accumulate all vectors of one token piece into a single row
                            if ten is None:
                                ten = []
                            ten.append(one)
                            if part>0:
                                continue
                            use = torch.stack(ten)
                            tok = toks if tokens else '*'
                        else:
                            tok = tokens[i-1] if tokens else '*_'+str(i)
                        txt += '<tr>{}{}<td>{}</td>{}</tr>'.format(('<td>'+str(index-1)+'</td>' if mode==by_vectors else '')+head,column,tok,tensor_info(use))
                        column = ''
                        head = ''
                        ten = None
                else:
                    index += size
                    txt += '<tr>{}<td>{}</td>{}</tr>'.format(head,', '.join([str(t) for t in tokens]) if tokens else '*',tensor_info(tensor))
            txt += '</table>'
            both.append(txt)
        txt = table+clip_separator.join(both)
        return ('<center>'+txt+'</center>',need_save_embed(store,gr_name,two,gr_tensors),gr_orig)
|
||
|
||
def tensor_info(tensor):
    """Render six HTML <td> cells with statistics of *tensor*.

    Cells, in order: min, max, sum, sum of absolute values,
    torch.linalg.norm(tensor, ord=2) and std. These match the Min/Max/Sum/Abs/Len/Std
    headers used by gr_func(). For a 2-D tensor, ord=2 is the matrix 2-norm (largest
    singular value), not the Frobenius norm.
    """
    stats = (
        tensor.min().item(),
        tensor.max().item(),
        tensor.sum().item(),
        tensor.abs().sum().item(),
        torch.linalg.norm(tensor,ord=2),
        tensor.std(),
    )
    cells = '<td>{:>-14.8f}</td><td>{:>+14.8f}</td><td>{:>+14.8f}</td><td>{:>14.8f}</td><td>{:>14.8f}</td><td>{:>14.8f}</td>'.format(*stats)
    return cells.replace(' ',' ')
|
||
|
||
merge_dir = None  # set by embedding_merge_dir() to <embeddings_dir>/embedding_merge


def _collect_vectors(entries):
    """Flatten *entries* into the list of tensors they hold, in order.

    An entry is a tensor, a record whose first item is a tensor (the (tensor, name,
    tokens) rows gr_func() iterates over), or a nested list of such records (gr_func()'s
    "none" and "by comma" views add that extra level).
    """
    found = []
    for entry in entries:
        if torch.is_tensor(entry):
            found.append(entry)
        elif len(entry)>0 and torch.is_tensor(entry[0]):
            found.append(entry[0])
        else:
            found.extend(_collect_vectors(entry))
    return found


def need_save_embed(store,name,pair,tensors):
    """Optionally save a merge result as an embedding file and reload the embedding databases.

    store   -- when falsy, nothing is saved and *name* is returned unchanged.
    name    -- file name; characters other than alphanumerics and '._- ' are dropped.
    pair    -- per-encoder results [clip_l, clip_g]; the second may be missing or None.
               Each is a tensor, or a list of records / record groups as built by gr_func().
    tensors -- when truthy, skip the legacy .pt file and write <name>.safetensors. Also:
               for an SDXL pair, write the 768-wide CLIP-L part to sd1/<name>.safetensors;
               for a single 768-wide result, write an SDXL-shaped copy with a zero
               clip_g to sdxl/<name>.safetensors. All-zero vectors are not exported.

    Without *tensors*, a .pt file is written and read back. If it cannot be loaded, a
    .safetensors file is written instead and the .pt removed.

    Returns '' after a successful save (or when the sanitized name is empty), otherwise
    the name, so the UI field keeps it. Errors are printed, not raised.
    """
    if not store:
        return name
    name = ''.join( x for x in name if (x.isalnum() or x in '._- ')).strip()
    if name=='':
        return name
    try:
        if type(pair[0])==list:
            vectors = [torch.cat(_collect_vectors(pair[0]))]
            if (len(pair)>1) and (pair[1] is not None):
                vectors.append(torch.cat(_collect_vectors(pair[1])))
        else:
            vectors = [pair[0]]
            if (len(pair)>1) and (pair[1] is not None):
                vectors.append(pair[1])
        target = os.path.join(merge_dir,name)
        if len(vectors)>1:
            pt = {
                'clip_g': vectors[1].cpu(),
                'clip_l': vectors[0].cpu(),
            }
        elif not tensors:
            pt = {
                'string_to_token': {
                    '*': 265,
                },
                'string_to_param': {
                    '*': vectors[0].cpu(),
                },
                'name': name,
                'step': 0,
                'sd_checkpoint': None,
                'sd_checkpoint_name': None,
            }
        if tensors:
            res = None
        else:
            torch.save(pt,target+'.pt')
            try:
                res = torch.load(target+'.pt',map_location='cpu')
            except Exception:
                res = None
        if res is None:
            if len(vectors)==1:
                pt = {
                    'emb_params': vectors[0].cpu(),
                }
            from safetensors.torch import save_file
            save_file(pt,target+'.safetensors')
            try:
                os.unlink(target+'.pt')
            except OSError:
                pass
        if tensors:
            from safetensors.torch import save_file
            if len(vectors)>1:
                # export each non-zero 768-wide (CLIP-L) part as an SD1 embedding
                for vector in vectors:
                    if vector.shape[-1]!=768 or vector.abs().max().item()==0:
                        continue
                    folder = os.path.join(merge_dir,'sd1')
                    os.makedirs(folder,exist_ok=True)
                    save_file({
                        'emb_params': vector.cpu(),
                    },os.path.join(folder,name)+'.safetensors')
            else:
                # export a non-zero 768-wide result as SDXL with an all-zero clip_g part
                vector = vectors[0]
                if vector.shape[-1]==768 and vector.abs().max().item()!=0:
                    size = list(vector.size())
                    size[-1] = 1280
                    folder = os.path.join(merge_dir,'sdxl')
                    os.makedirs(folder,exist_ok=True)
                    save_file({
                        'clip_g': torch.zeros(size).cpu(),
                        'clip_l': vector.cpu(),
                    },os.path.join(folder,name)+'.safetensors')
        for db in get_embedding_db():
            try:
                db.load_textual_inversion_embeddings(force_reload=True)
            except Exception:
                # fall back to the call without force_reload
                db.load_textual_inversion_embeddings()
        return ''
    except Exception:
        traceback.print_exc()
        return name
|
||
|
||
def embedding_merge_dir():
|
||
try:
|
||
nonlocal merge_dir
|
||
merge_dir = os.path.join(cmd_opts.embeddings_dir,'embedding_merge')
|
||
# don't actually need this, since it is a subfolder which will be read recursively:
|
||
#modules.sd_hijack.model_hijack.embedding_db.add_embedding_dir(merge_dir)
|
||
os.makedirs(merge_dir)
|
||
except:
|
||
pass
|
||
|
||
def raise_sd_error(p,msg):
    """Poison the processing object *p* so that its next attribute lookup raises.

    p.__class__ is rebound to a class whose __getattribute__ always raises
    Exception_From_EmbeddingMergeExtension(msg). Generation is therefore aborted with
    *msg* as soon as anything reads an attribute of *p*.
    """
    class Exception_From_EmbeddingMergeExtension_():
        def __getattribute__(self,attribute):
            raise Exception_From_EmbeddingMergeExtension(msg)
    p.__class__ = Exception_From_EmbeddingMergeExtension_
|
||
|
||
# Matches an already-expanded ephemeral embedding name such as <'EM_1'> or {'EM-2'}.
em_regexp = re.compile(r"<'EM[_/-]\d+'>|{'EM[_/-]\d+'}")


def merge_one_prompt(cache,texts,parts,used,prompt,prod,only_count):
    """Expand every <'...'> / {'...'} merge expression in *prompt* into an ephemeral embedding name.

    cache      -- dict from grab_embedding_cache() / reset_temp_embeddings().
    texts      -- optional memo {original prompt: expanded prompt}; None disables it.
    parts      -- memo {expression text: embedding name}, reused across calls.
    used       -- optional dict filled with {embedding name: expression text}; None skips it.
    prompt     -- the prompt; None or '' is returned as-is.
    prod       -- forwarded to add_temp_embedding() (selects the '_' or '-' name group).
    only_count -- create zero-filled placeholders sized by the expression's vector count
                  (for token counting) instead of real merged vectors.

    Already-expanded names such as <'EM_1'> are skipped. A closing bracket inside a quoted
    string does not end an expression. An expression with an empty result is removed.

    Returns (expanded_prompt, None) or (None, error_message).
    """
    try:
        if (prompt is None) or (prompt==''):
            return (prompt,None)
        if (texts is not None) and (prompt in texts):
            return (texts[prompt],None)
        orig = prompt
        left = 0
        while True:
            # find the next expression opener, whichever of "{'" / "<'" comes first
            curly = prompt.find("{'",left)
            left = prompt.find("<'",left)
            if (curly>=0 and curly<left) or (left<0):
                left = curly
                curly = True
            else:
                curly = False
            if left<0:
                if texts is not None:
                    texts[orig] = prompt
                return (prompt,None)
            eph = em_regexp.match(prompt[left:])
            if eph is not None:
                left += len(eph.group(0))
                continue
            right = left
            while True:
                right = prompt.find('}' if curly else '>',right+1)
                if right<0:
                    if curly:
                        return (None,'Not found closing "}" after "{\'"')
                    return (None,'Not found closing ">" after "<\'"')
                # an odd number of quotes means the bracket is inside a quoted string
                if (prompt.count("'",left,right)&1)==0:
                    break
            part = prompt[left+1:right].strip()
            if part in parts:
                embed = parts[part]
            else:
                (res,err) = merge_parser(part,only_count)
                if err is not None:
                    return (None,err)
                if only_count:
                    if (res is None) or (res==0):
                        embed = ''
                    else:
                        embed = add_temp_embedding(None,cache,prod,curly,res)
                else:
                    if (res is None) or (res[0].numel()==0):
                        embed = ''
                    else:
                        embed = add_temp_embedding(res,cache,prod,curly,0)
            if used is not None:
                used[embed] = part
            parts[part] = embed
            prefix = prompt[:left].rstrip()+' '+embed
            left = len(prefix)
            prompt = prefix+' '+(prompt[right+1:].lstrip())
    except Exception:
        traceback.print_exc()
        return (None,'Fatal error?')
|
||
|
||
fake_cached_params_counter = time.time()
|
||
def fake_cached_params(self,*ar,**kw):
|
||
nonlocal fake_cached_params_counter
|
||
fake_cached_params_counter += 1
|
||
return (*(self.em_orig_cached_params(*ar,**kw)),id(_webui_embedding_merge_),fake_cached_params_counter)
|
||
|
||
cached_state = None
|
||
|
||
# A disabled debugging snippet used to sit here as an unused string literal: a
# `pretty_print` object dumper traced with the third-party `hunter` package, followed by
# an interactive `code.interact(local=locals())` console. It never executed, so it has
# been removed; recover it from version control if needed.
|
||
|
||
def hook_infotext(hook):
    """Install (hook=True) or remove (hook=False) a wrapper around processing.create_infotext.

    The wrapper makes each image's infotext list only those ephemeral embeddings
    (<'EM_n'> / {'EM_n'}) from the 'EmbeddingMerge' generation parameter that actually
    occur in that infotext. The wrapped original is stored on the wrapper under the
    '__embedding_merge_wrapper' attribute, so repeated hooking never stacks wrappers.
    """
    if not hasattr(processing,'create_infotext'):
        return
    field = '__embedding_merge_wrapper'
    old = getattr(processing,'create_infotext')
    if hasattr(old,field):
        old = getattr(old,field)
        if not hook:
            setattr(processing,'create_infotext',old)
    if not hook:
        return

    def create_infotext(p,*ar,**kw):
        res = old(p,*ar,**kw)
        if 'EmbeddingMerge' not in p.extra_generation_params:
            return res
        (reparse,request) = parse_infotext(res)
        if reparse is None:
            return res
        parse = parse_mergeseq(reparse)
        if parse is None:
            # unparsable value: keep the infotext as generated instead of crashing
            return res
        matches = em_regexp.findall(request)
        if len(matches)==0:
            return res
        used = set(matches)
        gen = ', '.join([embed+'='+text for embed,text in parse.items() if embed in used])
        orig = p.extra_generation_params['EmbeddingMerge']
        if gen!=orig:
            # regenerate with the filtered value, then always restore the original
            p.extra_generation_params['EmbeddingMerge'] = gen
            try:
                res = old(p,*ar,**kw)
            finally:
                p.extra_generation_params['EmbeddingMerge'] = orig
        return res

    setattr(create_infotext,field,old)
    setattr(processing,'create_infotext',create_infotext)
|
||
|
||
def embedding_merge_extension(p,processed):
|
||
if processed is not None:
|
||
hook_infotext(False)
|
||
return
|
||
hook_infotext(True)
|
||
nonlocal cached_state
|
||
use_hr = hasattr(p,'hr_prompt')
|
||
arr = [
|
||
p.all_prompts,
|
||
p.prompt if type(p.prompt)==list else [p.prompt],
|
||
p.all_negative_prompts,
|
||
p.negative_prompt if type(p.negative_prompt)==list else [p.negative_prompt],
|
||
]
|
||
if use_hr:
|
||
arr += [
|
||
p.all_hr_prompts,
|
||
p.hr_prompt if type(p.hr_prompt)==list else [p.hr_prompt],
|
||
p.all_hr_negative_prompts,
|
||
p.hr_negative_prompt if type(p.hr_negative_prompt)==list else [p.hr_negative_prompt],
|
||
]
|
||
restart = True
|
||
if 'EmbeddingMerge' in p.extra_generation_params:
|
||
restart = False
|
||
elif em_regexp.search(' '.join([' '.join(one) for one in arr if one is not None])) is not None:
|
||
restart = False
|
||
print("[EmbeddingMerge] WARNING: ephemeral embeddings (like <'EM_1'>) are detected!")
|
||
if restart or (cached_state is None):
|
||
cached_state = {
|
||
'cache': reset_temp_embeddings('_',False),
|
||
'texts': {},
|
||
'parts': {},
|
||
'used': {},
|
||
}
|
||
cache = cached_state['cache']
|
||
texts = cached_state['texts']
|
||
parts = cached_state['parts']
|
||
used = cached_state['used']
|
||
for one in arr:
|
||
ok = False
|
||
fail = None
|
||
if one is not None:
|
||
for i in range(len(one)):
|
||
(res,err) = merge_one_prompt(cache,texts,parts,used,one[i],True,False)
|
||
if err is not None:
|
||
if fail is None:
|
||
fail = err
|
||
else:
|
||
one[i] = res
|
||
ok = True
|
||
if not ok and fail is not None:
|
||
raise_sd_error(p,'\n\nEmbedding Merge failed - '+err+'\n')
|
||
return
|
||
p.all_prompts = arr[0]
|
||
p.all_negative_prompts = arr[2]
|
||
p.prompt = arr[1] if type(p.prompt)==list else arr[1][0]
|
||
p.negative_prompt = arr[3] if type(p.negative_prompt)==list else arr[3][0]
|
||
if use_hr:
|
||
p.all_hr_prompts = arr[4]
|
||
p.all_hr_negative_prompts = arr[6]
|
||
p.hr_prompt = arr[5] if type(p.hr_prompt)==list else arr[5][0]
|
||
p.hr_negative_prompt = arr[7] if type(p.hr_negative_prompt)==list else arr[7][0]
|
||
gen = ''
|
||
was_used = False
|
||
for embed in used:
|
||
was_used = True
|
||
if embed!='':
|
||
if embed[0]=='<':
|
||
gen += embed+'=<'+used[embed]+'>, '
|
||
else:
|
||
gen += embed+'={'+used[embed]+'}, '
|
||
if gen!='':
|
||
p.extra_generation_params['EmbeddingMerge'] = gen[:-2]
|
||
if was_used:
|
||
orig = getattr(p,'cached_params',None)
|
||
if orig is not None:
|
||
setattr(p,'em_orig_cached_params',orig)
|
||
setattr(p,'cached_params',types.MethodType(fake_cached_params,p))
|
||
|
||
try:
    cls = modules.sd_hijack.StableDiffusionModelHijack
    field = '__embedding_merge_wrapper'
    get_prompt_lengths = cls.get_prompt_lengths
    if hasattr(get_prompt_lengths,field):
        # Already hooked (script reloaded): wrap the original, not our previous wrapper.
        get_prompt_lengths = getattr(get_prompt_lengths,field)

    def hook_prompt_lengths(self,text,*ar,**kw):
        """Token-counter wrapper that understands merge expressions.

        Text without <' or {' goes straight to the original. Otherwise every expression
        is replaced by a placeholder embedding (merge_one_prompt in only_count mode) and
        the rewritten text is measured. (-1, -1) is returned when an expression fails.
        """
        if text.find("<'")<0 and text.find("{'")<0:
            return get_prompt_lengths(self,text,*ar,**kw)
        (res,err) = merge_one_prompt(grab_embedding_cache(),None,{},None,text,True,True)
        if err is not None:
            return -1,-1
        return get_prompt_lengths(self,res,*ar,**kw)

    setattr(hook_prompt_lengths,field,get_prompt_lengths)
    cls.get_prompt_lengths = hook_prompt_lengths
except Exception:
    traceback.print_exc()
|
||
|
||
def on_infotext_pasted(infotext,result):
    """Infotext-paste callback: put the original merge expressions back into the prompts.

    The mapping comes from the parsed 'EmbeddingMerge' field, which may be JSON-quoted.
    If that field is absent, it is looked up in the raw *infotext*; in that case a
    'Prompt' equal to the whole infotext is first replaced by the infotext without the
    parameter. Every expanded name in the prompt fields of *result* is then replaced in
    place.
    """
    request = None
    if 'EmbeddingMerge' in result:
        reparse = result['EmbeddingMerge']
        if reparse[:1]=='"':
            try:
                reparse = parse_mergeseq(json.loads('['+reparse.strip()+']')[0])
            except Exception:
                reparse = None
        else:
            reparse = parse_mergeseq(reparse)
    else:
        (reparse,request) = parse_infotext(infotext)
        if reparse is not None:
            reparse = parse_mergeseq(reparse)
    if reparse is None:
        return
    if 'Prompt' in result:
        if (request is not None) and (result['Prompt']==infotext):
            result['Prompt'] = request
        result['Prompt'] = dict_replace(reparse,result['Prompt'])
    for key in ('Negative prompt','Hires prompt','Hires negative prompt'):
        if key in result:
            result[key] = dict_replace(reparse,result[key])

setattr(_webui_embedding_merge_,'on_infotext_pasted',on_infotext_pasted)
|
||
def on_model_loaded(*ar,**kw):
    """Model-load callback: drop and unregister the '/' placeholder embeddings.

    Any callback arguments are ignored.
    """
    reset_temp_embeddings('/',True)

_webui_embedding_merge_.on_model_loaded = on_model_loaded
|
||
|
||
def on_script_unloaded():
    """Undo this script's patches when the WebUI unloads it.

    Removes the create_infotext hook and drops/unregisters all ephemeral embeddings of
    the '_', '-' and '/' groups. It then restores the original get_prompt_lengths and
    deletes the cache attribute from the embedding database. Failures in the last two
    steps are printed, not raised.
    """
    hook_infotext(False)
    for group in ('_','-','/'):
        reset_temp_embeddings(group,True)
    wrapper_attr = '__embedding_merge_wrapper'
    try:
        hijack = modules.sd_hijack.StableDiffusionModelHijack
        current = hijack.get_prompt_lengths
        if hasattr(current,wrapper_attr):
            hijack.get_prompt_lengths = getattr(current,wrapper_attr)
    except Exception:
        traceback.print_exc()
    cache_attr = '__embedding_merge_cache_'
    try:
        database = get_embedding_db()[0]
        if hasattr(database,cache_attr):
            delattr(database,cache_attr)
    except Exception:
        traceback.print_exc()

_webui_embedding_merge_.on_script_unloaded = on_script_unloaded
|
||
# Expose the per-generation hook; EmbeddingMergeExtension.process/postprocess look it up here.
setattr(_webui_embedding_merge_,'embedding_merge_extension',embedding_merge_extension)
# Resolve merge_dir (and create the folder) once, while this factory runs.
embedding_merge_dir()
# gr_tab (defined earlier in this factory) is what script_callbacks.on_ui_tabs() receives.
return gr_tab
|
||
|
||
class EmbeddingMergeExtension(scripts.Script):
    """WebUI script shim forwarding process()/postprocess() to embedding_merge_extension.

    The hook is looked up on _webui_embedding_merge_ at call time, so nothing happens
    when it has not been installed.
    """

    def title(self):
        return 'Embedding Merge'

    def show(self,is_img2img):
        # same answer for txt2img and img2img
        return scripts.AlwaysVisible

    def process(self,p):
        # before generation: expand merge expressions in p's prompts
        if hasattr(_webui_embedding_merge_,'embedding_merge_extension'):
            _webui_embedding_merge_.embedding_merge_extension(p,None)

    def postprocess(self,p,processed):
        # after generation: the hook only removes its create_infotext wrapper
        if hasattr(_webui_embedding_merge_,'embedding_merge_extension'):
            _webui_embedding_merge_.embedding_merge_extension(p,processed)
|
||
|
||
# Calling the factory performs the setattr() registrations read on the following lines.
script_callbacks.on_ui_tabs(_webui_embedding_merge_())
script_callbacks.on_infotext_pasted(_webui_embedding_merge_.on_infotext_pasted)
script_callbacks.on_script_unloaded(_webui_embedding_merge_.on_script_unloaded)
# NOTE(review): presumably guarded because some WebUI builds lack on_model_loaded.
try:
    script_callbacks.on_model_loaded(_webui_embedding_merge_.on_model_loaded)
except:
    pass
#EOF
|