support LORA model keywords.

pull/33/head
ChunKoo Park 2023-02-21 16:39:35 +09:00
parent a6ef316d6b
commit 167bcae847
1 changed files with 267 additions and 19 deletions

View File

@ -10,8 +10,12 @@ import random
scripts_dir = scripts.basedir()
kw_idx = 0
lora_idx = 0
hash_dict = None
hash_dict_modified = None
lora_hash_dict = None
lora_hash_dict_modified = None
model_hash_dict = {}
def str_simularity(a, b):
@ -56,6 +60,7 @@ def load_hash_dict():
kw = row[1].strip(' ')
if mhash.startswith('#'):
continue
mhash = mhash.lower()
ckptname = 'default' if len(row)<=2 else row[2].strip(' ')
hash_dict[mhash].append((kw, ckptname,idx))
except:
@ -68,6 +73,42 @@ def load_hash_dict():
return hash_dict
def load_lora_hash_dict():
global lora_hash_dict, lora_hash_dict_modified, scripts_dir
default_file = f'{scripts_dir}/lora-keyword.txt'
user_file = f'{scripts_dir}/lora-keyword-user.txt'
if not os.path.exists(user_file):
open(user_file, 'w').write('\n')
modified = str(os.stat(default_file).st_mtime) + '_' + str(os.stat(user_file).st_mtime)
if lora_hash_dict is None or lora_hash_dict_modified != modified:
lora_hash_dict = defaultdict(list)
def parse_file(path, idx):
if os.path.exists(path):
with open(path, encoding='utf-8', newline='') as csvfile:
csvreader = csv.reader(csvfile)
for row in csvreader:
try:
mhash = row[0].strip(' ')
kw = row[1].strip(' ')
if mhash.startswith('#'):
continue
mhash = mhash.lower()
ckptname = 'default' if len(row)<=2 else row[2].strip(' ')
lora_hash_dict[mhash].append((kw, ckptname,idx))
except:
pass
parse_file(default_file, 0) # 0 for default_file
parse_file(user_file, 1) # 1 for user_file
lora_hash_dict_modified = modified
return lora_hash_dict
def get_keyword_for_model(model_hash, model_ckpt, return_entry=False):
found = None
@ -93,6 +134,49 @@ def get_keyword_for_model(model_hash, model_ckpt, return_entry=False):
return found
return found[0] if found else None
def _get_keywords_for_lora(lora_model, return_entry=False):
found = None
lora_model_path = f'{shared.cmd_opts.lora_dir}/{lora_model}'
# hash -> [ (keyword, ckptname, idx) ]
lora_hash_dict = load_lora_hash_dict()
lora_model_hash = get_old_model_hash(lora_model_path)
# print(hash_dict)
if lora_model_hash in lora_hash_dict:
lst = lora_hash_dict[lora_model_hash]
if len(lst) == 1:
found = lst[0]
elif len(lst) > 1:
max_sim = 0.0
found = lst[0]
for kw_ckpt in lst:
sim = str_simularity(kw_ckpt[1], model_ckpt)
if sim >= max_sim:
max_sim = sim
found = kw_ckpt
if return_entry:
return found
return found[0] if found else None
def get_lora_keywords(lora_model, keyword_only=False):
lora_keywords = ["None"]
if lora_model != 'None':
kws = _get_keywords_for_lora(lora_model)
if kws:
words = [x.strip() for x in kws.split('|')]
if keyword_only:
return words
if len(words) > 1:
words.insert(0, ', '.join(words))
words.append('< iterate >')
words.append('< random >')
lora_keywords.extend(words)
return lora_keywords
settings = None
def save_settings(m):
@ -104,6 +188,8 @@ def save_settings(m):
for k in m.keys():
settings[k] = m[k]
# print(settings)
settings_file = f'{scripts_dir}/settings.txt'
lines = []
@ -124,6 +210,9 @@ def get_settings():
settings['multiple_keywords'] = 'keyword1, keyword2'
settings['ti_keywords'] = 'None'
settings['keyword_order'] = 'textual inversion first'
settings['lora_model'] = 'None'
settings['lora_multiplier'] = 0.7
settings['lora_keywords'] = 'None'
settings_file = f'{scripts_dir}/settings.txt'
@ -149,6 +238,9 @@ class Script(scripts.Script):
def get_embeddings():
import glob
return [os.path.basename(x) for x in glob.glob(f'{shared.cmd_opts.embeddings_dir}/*.pt')]
def get_loras():
import glob
return [os.path.basename(x) for x in glob.glob(f'{shared.cmd_opts.lora_dir}/*.safetensors')]
def update_keywords():
model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
@ -164,6 +256,15 @@ class Script(scripts.Script):
ti_choices = ["None"]
ti_choices.extend(get_embeddings())
return gr.Dropdown.update(choices=ti_choices)
def update_loras():
lora_choices = ["None"]
lora_choices.extend(get_loras())
return gr.Dropdown.update(choices=lora_choices)
def update_lora_keywords(lora_model):
lora_keywords = get_lora_keywords(lora_model)
return gr.Dropdown.update(choices=lora_keywords)
def check_keyword():
model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
@ -257,6 +358,92 @@ class Script(scripts.Script):
return 'added: ' + insert_line
def delete_lora_keyword(lora_model):
model_ckpt = lora_model
lora_model_path = f'{shared.cmd_opts.lora_dir}/{lora_model}'
model_hash = get_old_model_hash(lora_model_path)
user_file = f'{scripts_dir}/lora-keyword-user.txt'
user_backup_file = f'{scripts_dir}/lora-keyword-user-backup.txt'
lines = []
found = None
if os.path.exists(user_file):
with open(user_file, newline='') as csvfile:
csvreader = csv.reader(csvfile)
for row in csvreader:
try:
mhash = row[0]
if mhash.startswith('#'):
lines.append(','.join(row))
continue
# kw = row[1]
ckptname = None if len(row)<=2 else row[2].strip(' ')
if mhash==model_hash and ckptname==model_ckpt:
found = row
continue
lines.append(','.join(row))
except:
pass
outline = ''
if found:
csvtxt = '\n'.join(lines) + '\n'
import shutil
try:
shutil.copy(user_file, user_backup_file)
except:
pass
open(user_file, 'w').write(csvtxt)
outline = f'deleted entry: {found}'
else:
outline = f'no custom mapping found'
lora_keywords = get_lora_keywords(lora_model)
return [outline, gr.Dropdown.update(choices=lora_keywords)]
def add_lora_keyword(txt, lora_model):
txt = txt.strip()
model_ckpt = lora_model
lora_model_path = f'{shared.cmd_opts.lora_dir}/{lora_model}'
model_hash = get_old_model_hash(lora_model_path)
if len(txt) == 0:
return "Fill keyword(trigger word) or keywords separated by | (pipe character)"
insert_line = f'{model_hash}, {txt}, {model_ckpt}'
global scripts_dir
user_file = f'{scripts_dir}/lora-keyword-user.txt'
user_backup_file = f'{scripts_dir}/lora-keyword-user-backup.txt'
lines = []
if os.path.exists(user_file):
with open(user_file, newline='') as csvfile:
csvreader = csv.reader(csvfile)
for row in csvreader:
try:
mhash = row[0]
if mhash.startswith('#'):
lines.append(','.join(row))
continue
# kw = row[1]
ckptname = None if len(row)<=2 else row[2].strip(' ')
if mhash==model_hash and ckptname==model_ckpt:
continue
lines.append(','.join(row))
except:
pass
lines.append(insert_line)
csvtxt = '\n'.join(lines) + '\n'
import shutil
try:
shutil.copy(user_file, user_backup_file)
except:
pass
open(user_file, 'w').write(csvtxt)
lora_keywords = get_lora_keywords(lora_model)
return ['added: ' + insert_line, gr.Dropdown.update(choices=lora_keywords)]
settings = get_settings()
def cb_enabled():
@ -267,9 +454,16 @@ class Script(scripts.Script):
return settings['multiple_keywords']
def cb_ti_keywords():
return settings['ti_keywords']
def cb_lora_model():
return settings['lora_model']
def cb_lora_multiplier():
return settings['lora_multiplier']
def cb_lora_keywords():
return settings['lora_keywords']
def cb_keyword_order():
return settings['keyword_order']
refresh_icon = '\U0001f504'
with gr.Group():
with gr.Accordion('Model Keyword', open=False):
@ -293,7 +487,7 @@ class Script(scripts.Script):
label='Multiple keywords: ')
setattr(multiple_keywords,"do_not_save_to_config",True)
refresh_btn = gr.Button(value='\U0001f504', elem_id='mk_refresh_btn_random_seed') # XXX _random_seed workaround.
refresh_btn = gr.Button(value=refresh_icon, elem_id='mk_refresh_btn_random_seed') # XXX _random_seed workaround.
refresh_btn.click(update_keywords, inputs=None, outputs=multiple_keywords)
ti_choices = ["None"]
@ -303,7 +497,7 @@ class Script(scripts.Script):
value=cb_ti_keywords,
label='Textual Inversion (Embedding): ')
setattr(ti_keywords,"do_not_save_to_config",True)
refresh_btn = gr.Button(value='\U0001f504', elem_id='ti_refresh_btn_random_seed') # XXX _random_seed workaround.
refresh_btn = gr.Button(value=refresh_icon, elem_id='ti_refresh_btn_random_seed') # XXX _random_seed workaround.
refresh_btn.click(update_embeddings, inputs=None, outputs=ti_keywords)
keyword_order = gr.Dropdown(choices=["textual inversion first", "model keyword first"],
@ -311,6 +505,37 @@ class Script(scripts.Script):
label='Keyword order: ')
setattr(keyword_order,"do_not_save_to_config",True)
with gr.Accordion('LORA', open=True):
lora_choices = ["None"]
lora_choices.extend(get_loras())
lora_kw_choices = get_lora_keywords(settings['lora_model'])
with gr.Row(equal_height=True):
lora_model = gr.Dropdown(choices=lora_choices,
value=cb_lora_model,
label='Model: ')
setattr(lora_model,"do_not_save_to_config",True)
lora_refresh_btn = gr.Button(value=refresh_icon, elem_id='lora_m_refresh_btn_random_seed') # XXX _random_seed workaround.
lora_refresh_btn.click(update_loras, inputs=None, outputs=lora_model)
lora_multiplier = gr.Slider(minimum=0,maximum=2, step=0.01, value=cb_lora_multiplier, label="multiplier")
with gr.Row(equal_height=True):
lora_keywords = gr.Dropdown(choices=lora_kw_choices,
value=cb_lora_keywords,
label='keywords: ')
setattr(lora_keywords,"do_not_save_to_config",True)
lora_model.change(fn=update_lora_keywords,inputs=lora_model, outputs=lora_keywords)
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Add custom keyword(trigger word) mapping for selected LORA model.</p>")
lora_text_input = gr.Textbox(placeholder="Keyword or keywords separated by |", label="Keyword(trigger word)")
with gr.Row():
add_mappings = gr.Button(value='Save')
delete_mappings = gr.Button(value='Delete')
lora_text_output = gr.Textbox(interactive=False, label='result')
add_mappings.click(add_lora_keyword, inputs=[lora_text_input, lora_model], outputs=[lora_text_output, lora_keywords])
delete_mappings.click(delete_lora_keyword, inputs=lora_model, outputs=[lora_text_output, lora_keywords])
with gr.Accordion('Add Custom Mappings', open=False):
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Add custom keyword(trigger word) mapping for current model. Custom mappings are saved to extensions/model-keyword/custom-mappings.txt</p>")
text_input = gr.Textbox(placeholder="Keyword or keywords separated by |", label="Keyword(trigger word)")
@ -326,15 +551,22 @@ class Script(scripts.Script):
delete_mappings.click(delete_keyword, inputs=None, outputs=text_output)
return [is_enabled, keyword_placement, multiple_keywords, ti_keywords, keyword_order]
return [is_enabled, keyword_placement, multiple_keywords, ti_keywords, keyword_order, lora_model, lora_multiplier, lora_keywords]
def process(self, p, is_enabled, keyword_placement, multiple_keywords, ti_keywords, keyword_order, lora_model, lora_multiplier, lora_keywords):
if lora_model != 'None':
if lora_keywords not in get_lora_keywords(lora_model):
lora_keywords = 'None'
def process(self, p, is_enabled, keyword_placement, multiple_keywords, ti_keywords, keyword_order):
save_settings({
'is_enabled': f'{is_enabled}',
'keyword_placement': keyword_placement,
'multiple_keywords': multiple_keywords,
'ti_keywords': ti_keywords,
'keyword_order': keyword_order,
'lora_model': lora_model,
'lora_multiplier': lora_multiplier,
'lora_keywords': lora_keywords,
})
if not is_enabled:
@ -347,7 +579,7 @@ class Script(scripts.Script):
# print(f'model_hash = {model_hash}')
def new_prompt(prompt, kw, no_iter=False):
global kw_idx
global kw_idx, lora_idx
if kw:
kws = kw.split('|')
if len(kws) > 1:
@ -369,28 +601,44 @@ class Script(scripts.Script):
else:
kw = kws[0]
if ti_keywords == 'None':
arr = [kw]
else:
arr = [kw]
ti = None
if ti_keywords != 'None':
ti = ti_keywords[:ti_keywords.rfind('.')]
if keyword_order == 'model keyword first':
arr = [kw, ti]
else:
arr = [ti, kw]
if None in arr:
arr.remove(None)
lora = None
if lora_keywords != 'None' and lora_model != 'None':
lora = lora_keywords
try:
if lora == '< iterate >':
loras = get_lora_keywords(lora_model, keyword_only=True)
lora = loras[lora_idx%len(loras)]
if not no_iter:
lora_idx += 1
elif lora == '< random >':
loras = get_lora_keywords(lora_model, keyword_only=True)
lora = random.choice(loras)
except:
pass
if ',' in keyword_placement:
kw = ', '.join(arr)
else:
kw = ' '.join(arr)
if keyword_order == 'model keyword first':
arr = [kw, lora, ti]
else:
arr = [ti, lora, kw]
while None in arr:
arr.remove(None)
if keyword_placement.startswith('keyword'):
arr.append(prompt)
else:
arr.insert(0, prompt)
if lora_model != 'None':
lora_name = lora_model[:lora_model.rfind('.')]
arr.insert(0, f'<lora:{lora_name}:{lora_multiplier}>')
if ',' in keyword_placement:
return ', '.join(arr)
else:
@ -399,7 +647,7 @@ class Script(scripts.Script):
kw = get_keyword_for_model(model_hash, model_ckpt)
if kw is not None or ti_keywords != 'None':
if kw is not None or ti_keywords != 'None' or lora_model != 'None':
p.prompt = new_prompt(p.prompt, kw, no_iter=True)
p.all_prompts = [new_prompt(prompt, kw) for prompt in p.all_prompts]