"""Model Keyword extension for stable-diffusion-webui.

Looks up the trigger keyword(s) for the currently loaded checkpoint (by the
legacy 8-char model hash) in a bundled database plus a user-editable custom
mapping file, and injects them into generation prompts.
"""

import modules.scripts as scripts
import gradio as gr
import csv
import os
from collections import defaultdict
import modules.shared as shared
import difflib
import random

# Directory of this extension; mapping files and settings.txt live here.
scripts_dir = scripts.basedir()

# Rotating index used by the "iterate" multiple-keywords mode.
kw_idx = 0

# Cached keyword database: hash -> [(keyword, ckptname, idx)], see load_hash_dict().
hash_dict = None
# mtime fingerprint of the files backing hash_dict, used for cache invalidation.
hash_dict_modified = None
# checkpoint filename -> short hash, so each file is hashed only once.
model_hash_dict = {}


def str_simularity(a, b):
    """Return a 0..1 similarity ratio between two strings."""
    return difflib.SequenceMatcher(None, a, b).ratio()


def get_old_model_hash(filename):
    """Return the legacy 8-hex-digit webui hash of a checkpoint file.

    SHA-256 over 0x10000 bytes starting at offset 0x100000 (the old
    A1111 "model hash" scheme), truncated to 8 hex characters.  Results
    are memoized in model_hash_dict.  Returns 'NOFILE' when the file
    does not exist.
    """
    if filename in model_hash_dict:
        return model_hash_dict[filename]
    try:
        with open(filename, "rb") as file:
            import hashlib
            m = hashlib.sha256()
            file.seek(0x100000)
            m.update(file.read(0x10000))
            short_hash = m.hexdigest()[0:8]
            model_hash_dict[filename] = short_hash
            return short_hash
    except FileNotFoundError:
        return 'NOFILE'


def load_hash_dict():
    """Load (and cache) the model-hash -> keyword mapping.

    Merges the bundled model-keyword.txt (idx 0) with the user's
    custom-mappings.txt (idx 1).  The cache is rebuilt whenever either
    file's mtime changes.
    """
    global hash_dict, hash_dict_modified
    default_file = f'{scripts_dir}/model-keyword.txt'
    user_file = f'{scripts_dir}/custom-mappings.txt'
    if not os.path.exists(user_file):
        with open(user_file, 'w') as f:
            f.write('\n')
    modified = str(os.stat(default_file).st_mtime) + '_' + str(os.stat(user_file).st_mtime)
    if hash_dict is None or hash_dict_modified != modified:
        hash_dict = defaultdict(list)

        def parse_file(path, idx):
            # Best-effort CSV parse: rows are "hash, keyword[, ckptname]".
            if not os.path.exists(path):
                return
            with open(path, newline='') as csvfile:
                for row in csv.reader(csvfile):
                    try:
                        mhash = row[0].strip(' ')
                        kw = row[1].strip(' ')
                        if mhash.startswith('#'):  # comment line
                            continue
                        ckptname = 'default' if len(row) <= 2 else row[2].strip(' ')
                        hash_dict[mhash].append((kw, ckptname, idx))
                    except Exception:
                        pass  # skip malformed rows, keep the rest

        parse_file(default_file, 0)  # 0 for default_file
        parse_file(user_file, 1)     # 1 for user_file
        hash_dict_modified = modified
    return hash_dict


def get_keyword_for_model(model_hash, model_ckpt, return_entry=False):
    """Look up the trigger keyword(s) for a model.

    When several entries share the same hash, the entry whose stored
    checkpoint name is most similar to `model_ckpt` wins.  Returns the
    full (keyword, ckptname, idx) entry when `return_entry` is True,
    otherwise just the keyword string (None when there is no match).
    """
    found = None
    # hash -> [ (keyword, ckptname, idx) ]
    mapping = load_hash_dict()
    if model_hash in mapping:
        lst = mapping[model_hash]
        if len(lst) == 1:
            found = lst[0]
        elif len(lst) > 1:
            max_sim = 0.0
            found = lst[0]
            for kw_ckpt in lst:
                sim = str_simularity(kw_ckpt[1], model_ckpt)
                if sim >= max_sim:
                    max_sim = sim
                    found = kw_ckpt
    if return_entry:
        return found
    return found[0] if found else None


# Lazily-loaded settings dict; populated by get_settings().
settings = None


def save_settings(m):
    """Merge mapping `m` into the settings and persist them to settings.txt."""
    global settings
    if settings is None:
        # BUGFIX: original assigned to a typo'd local ("settigns"), leaving the
        # global None and crashing on the assignment loop below.
        settings = get_settings()
    for k in m.keys():
        settings[k] = m[k]
    settings_file = f'{scripts_dir}/settings.txt'
    lines = [f'{k}={settings[k]}' for k in settings.keys()]
    with open(settings_file, 'w') as f:
        f.write('\n'.join(lines) + '\n')


def get_settings():
    """Return the settings dict: built-in defaults overridden by settings.txt."""
    global settings
    if settings:
        return settings
    settings = {
        'is_enabled': 'True',
        'keyword_placement': 'keyword prompt',
        'multiple_keywords': 'keyword1, keyword2',
        'ti_keywords': 'None',
        'keyword_order': 'textual inversion first',
    }
    settings_file = f'{scripts_dir}/settings.txt'
    if os.path.exists(settings_file):
        with open(settings_file, newline='') as file:
            # One "key=value" pair per line; lines without '=' are ignored.
            for line in file.read().split('\n'):
                pos = line.find('=')
                if pos == -1:
                    continue
                k = line[:pos]
                v = line[pos + 1:].strip()
                settings[k] = v
    return settings


class Script(scripts.Script):
    """webui script: UI accordion plus prompt-rewriting in process()."""

    def title(self):
        return "Model keyword"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the 'Model Keyword' accordion and its custom-mapping editor."""

        def get_embeddings():
            # Textual-inversion embedding filenames (*.pt) from the embeddings dir.
            import glob
            return [os.path.basename(x) for x in glob.glob(f'{shared.cmd_opts.embeddings_dir}/*.pt')]

        def update_keywords():
            # Refresh the multiple-keywords dropdown with the current model's keywords.
            model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
            model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
            kws = get_keyword_for_model(model_hash, model_ckpt)
            mk_choices = ["keyword1, keyword2", "random", "iterate"]
            if kws:
                mk_choices.extend([x.strip() for x in kws.split('|')])
            else:
                mk_choices.extend(["keyword1", "keyword2"])
            return gr.Dropdown.update(choices=mk_choices)

        def update_embeddings():
            # Refresh the embeddings dropdown from disk.
            ti_choices = ["None"]
            ti_choices.extend(get_embeddings())
            return gr.Dropdown.update(choices=ti_choices)

        def check_keyword():
            # Report which mapping (if any) matches the current model.
            model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
            model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
            entry = get_keyword_for_model(model_hash, model_ckpt, return_entry=True)
            if entry:
                kw = entry[0]
                src = 'custom-mappings.txt' if entry[2] == 1 else 'model-keyword.txt (default database)'
                return f"filename={model_ckpt}\nhash={model_hash}\nkeyword={kw}\nmatch from {src}"
            else:
                return f"filename={model_ckpt}\nhash={model_hash}\nno match"

        def delete_keyword():
            # Remove the current model's entry from custom-mappings.txt
            # (a backup copy is written first, best-effort).
            model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
            model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
            user_file = f'{scripts_dir}/custom-mappings.txt'
            user_backup_file = f'{scripts_dir}/custom-mappings-backup.txt'
            lines = []
            found = None
            if os.path.exists(user_file):
                with open(user_file, newline='') as csvfile:
                    for row in csv.reader(csvfile):
                        try:
                            mhash = row[0]
                            if mhash.startswith('#'):
                                lines.append(','.join(row))
                                continue
                            ckptname = None if len(row) <= 2 else row[2].strip(' ')
                            if mhash == model_hash and ckptname == model_ckpt:
                                found = row  # matched: drop this row
                                continue
                            lines.append(','.join(row))
                        except Exception:
                            pass  # keep going past malformed rows
            if found:
                csvtxt = '\n'.join(lines) + '\n'
                import shutil
                try:
                    shutil.copy(user_file, user_backup_file)
                except Exception:
                    pass  # backup is best-effort
                with open(user_file, 'w') as f:
                    f.write(csvtxt)
                return f'deleted entry: {found}'
            else:
                return f'no custom mapping found'

        def add_custom(txt):
            # Add/replace the current model's entry in custom-mappings.txt.
            txt = txt.strip()
            model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
            model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
            if len(txt) == 0:
                return "Fill keyword(trigger word) or keywords separated by | (pipe character)"
            insert_line = f'{model_hash}, {txt}, {model_ckpt}'
            user_file = f'{scripts_dir}/custom-mappings.txt'
            user_backup_file = f'{scripts_dir}/custom-mappings-backup.txt'
            lines = []
            if os.path.exists(user_file):
                with open(user_file, newline='') as csvfile:
                    for row in csv.reader(csvfile):
                        try:
                            mhash = row[0]
                            if mhash.startswith('#'):
                                lines.append(','.join(row))
                                continue
                            ckptname = None if len(row) <= 2 else row[2].strip(' ')
                            if mhash == model_hash and ckptname == model_ckpt:
                                continue  # superseded by the new entry below
                            lines.append(','.join(row))
                        except Exception:
                            pass
            lines.append(insert_line)
            csvtxt = '\n'.join(lines) + '\n'
            import shutil
            try:
                shutil.copy(user_file, user_backup_file)
            except Exception:
                pass  # backup is best-effort
            with open(user_file, 'w') as f:
                f.write(csvtxt)
            return 'added: ' + insert_line

        settings = get_settings()

        # Initial-value callbacks so gradio reads persisted settings lazily.
        def cb_enabled():
            return True if settings['is_enabled'] == 'True' else False

        def cb_keyword_placement():
            return settings['keyword_placement']

        def cb_multiple_keywords():
            return settings['multiple_keywords']

        def cb_ti_keywords():
            return settings['ti_keywords']

        def cb_keyword_order():
            return settings['keyword_order']

        with gr.Group():
            with gr.Accordion('Model Keyword', open=False):
                is_enabled = gr.Checkbox(label='Model Keyword Enabled ', value=cb_enabled)
                setattr(is_enabled, "do_not_save_to_config", True)
                placement_choices = ["keyword prompt", "prompt keyword", "keyword, prompt", "prompt, keyword"]
                keyword_placement = gr.Dropdown(choices=placement_choices, value=cb_keyword_placement, label='Keyword placement: ')
                setattr(keyword_placement, "do_not_save_to_config", True)
                mk_choices = ["keyword1, keyword2", "random", "iterate"]
                mk_choices.extend(["keyword1", "keyword2"])
                with gr.Row(equal_height=True):
                    multiple_keywords = gr.Dropdown(choices=mk_choices, value=cb_multiple_keywords, label='Multiple keywords: ')
                    setattr(multiple_keywords, "do_not_save_to_config", True)
                    refresh_btn = gr.Button(value='\U0001f504', elem_id='mk_refresh_btn_random_seed')  # XXX _random_seed workaround.
                    refresh_btn.click(update_keywords, inputs=None, outputs=multiple_keywords)
                ti_choices = ["None"]
                ti_choices.extend(get_embeddings())
                with gr.Row(equal_height=True):
                    ti_keywords = gr.Dropdown(choices=ti_choices, value=cb_ti_keywords, label='Textual Inversion (Embedding): ')
                    setattr(ti_keywords, "do_not_save_to_config", True)
                    refresh_btn = gr.Button(value='\U0001f504', elem_id='ti_refresh_btn_random_seed')  # XXX _random_seed workaround.
                    refresh_btn.click(update_embeddings, inputs=None, outputs=ti_keywords)
                keyword_order = gr.Dropdown(choices=["textual inversion first", "model keyword first"], value=cb_keyword_order, label='Keyword order: ')
                setattr(keyword_order, "do_not_save_to_config", True)
                with gr.Accordion('Add Custom Mappings', open=False):
                    info = gr.HTML("Add custom keyword(trigger word) mapping for current model. Custom mappings are saved to extensions/model-keyword/custom-mappings.txt")
                    text_input = gr.Textbox(placeholder="Keyword or keywords separated by |", label="Keyword(trigger word)")
                    with gr.Row():
                        check_mappings = gr.Button(value='Check')
                        add_mappings = gr.Button(value='Save')
                        delete_mappings = gr.Button(value='Delete')
                    text_output = gr.Textbox(interactive=False, label='result')
                    add_mappings.click(add_custom, inputs=text_input, outputs=text_output)
                    check_mappings.click(check_keyword, inputs=None, outputs=text_output)
                    delete_mappings.click(delete_keyword, inputs=None, outputs=text_output)
        return [is_enabled, keyword_placement, multiple_keywords, ti_keywords, keyword_order]

    def process(self, p, is_enabled, keyword_placement, multiple_keywords, ti_keywords, keyword_order):
        """Persist the UI settings, then inject keyword(s) into p's prompts."""
        save_settings({
            'is_enabled': f'{is_enabled}',
            'keyword_placement': keyword_placement,
            'multiple_keywords': multiple_keywords,
            'ti_keywords': ti_keywords,
            'keyword_order': keyword_order,
        })
        settings = get_settings()
        print(f'settings2 = {settings}')
        if not is_enabled:
            # Drop the cache so a later re-enable reloads the mapping files.
            global hash_dict
            hash_dict = None
            return
        model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
        model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)

        def new_prompt(prompt, kw, no_iter=False):
            # Combine model keyword(s), optional TI embedding name, and the
            # user prompt according to the placement/order settings.
            global kw_idx
            if kw:
                kws = kw.split('|')
                if len(kws) > 1:
                    kws = [x.strip(' ') for x in kws]
                    if multiple_keywords == "keyword1, keyword2":
                        kw = ', '.join(kws)
                    elif multiple_keywords == "random":
                        kw = random.choice(kws)
                    elif multiple_keywords == "iterate":
                        kw = kws[kw_idx % len(kws)]
                        if not no_iter:
                            kw_idx += 1
                    elif multiple_keywords == "keyword1":
                        kw = kws[0]
                    elif multiple_keywords == "keyword2":
                        kw = kws[1]
                    elif multiple_keywords in kws:
                        kw = multiple_keywords
                    else:
                        kw = kws[0]
            if ti_keywords == 'None':
                arr = [kw]
            else:
                ti = ti_keywords[:ti_keywords.rfind('.')]  # strip file extension
                if keyword_order == 'model keyword first':
                    arr = [kw, ti]
                else:
                    arr = [ti, kw]
            if None in arr:
                arr.remove(None)  # no model keyword matched; keep only TI part
            if keyword_placement.startswith('keyword'):
                arr.append(prompt)
            else:
                arr.insert(0, prompt)
            if ',' in keyword_placement:
                return ', '.join(arr)
            else:
                return ' '.join(arr)

        kw = get_keyword_for_model(model_hash, model_ckpt)
        if kw is not None or ti_keywords != 'None':
            # no_iter=True: don't advance the iterate index for the display prompt.
            p.prompt = new_prompt(p.prompt, kw, no_iter=True)
            p.all_prompts = [new_prompt(prompt, kw) for prompt in p.all_prompts]


from fastapi import FastAPI, Response, Query, Body
from fastapi.responses import JSONResponse


def model_keyword_api(_: gr.Blocks, app: FastAPI):
    """Register the /model_keyword/get_keywords REST endpoint."""

    @app.get("/model_keyword/get_keywords")
    async def get_keywords():
        model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
        model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
        r = get_keyword_for_model(model_hash, model_ckpt, return_entry=True)
        if r is None:
            return {"keywords": [], "model": model_ckpt, "hash": model_hash, "match_source": "no match"}
        kws = [x.strip() for x in r[0].split('|')]
        match_source = "model-keyword.txt" if r[2] == 0 else "custom-mappings.txt"
        return {"keywords": kws, "model": model_ckpt, "hash": model_hash, "match_source": match_source}


try:
    import modules.script_callbacks as script_callbacks
    script_callbacks.on_app_started(model_keyword_api)
except Exception:
    pass  # older webui without on_app_started: extension works, API endpoint is skipped