add api "/model_keyword/get_keywords".

pull/33/head
ChunKoo Park 2023-01-20 17:50:41 +09:00
parent fac289cc52
commit aad7369165
1 changed files with 85 additions and 64 deletions

View File

@ -33,6 +33,61 @@ def get_old_model_hash(filename):
except FileNotFoundError:
return 'NOFILE'
def load_hash_dict():
global hash_dict, hash_dict_modified, scripts_dir
default_file = f'{scripts_dir}/model-keyword.txt'
user_file = f'{scripts_dir}/custom-mappings.txt'
modified = str(os.stat(default_file).st_mtime) + '_' + str(os.stat(user_file).st_mtime)
if hash_dict is None or hash_dict_modified != modified:
hash_dict = defaultdict(list)
def parse_file(path, idx):
if os.path.exists(path):
with open(path, newline='') as csvfile:
csvreader = csv.reader(csvfile)
for row in csvreader:
try:
mhash = row[0].strip(' ')
kw = row[1].strip(' ')
if mhash.startswith('#'):
continue
ckptname = 'default' if len(row)<=2 else row[2].strip(' ')
hash_dict[mhash].append((kw, ckptname,idx))
except:
pass
parse_file(default_file, 0) # 0 for default_file
parse_file(user_file, 1) # 1 for user_file
hash_dict_modified = modified
return hash_dict
def get_keyword_for_model(model_hash, model_ckpt, return_entry=False):
found = None
# hash -> [ (keyword, ckptname, idx) ]
hash_dict = load_hash_dict()
# print(hash_dict)
if model_hash in hash_dict:
lst = hash_dict[model_hash]
if len(lst) == 1:
found = lst[0]
elif len(lst) > 1:
max_sim = 0.0
found = lst[0]
for kw_ckpt in lst:
sim = str_simularity(kw_ckpt[1], model_ckpt)
if sim >= max_sim:
max_sim = sim
found = kw_ckpt
if return_entry:
return found
return found[0] if found else None
class Script(scripts.Script):
def title(self):
@ -49,7 +104,7 @@ class Script(scripts.Script):
def update_keywords():
model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
kws = self.get_keyword(model_hash, model_ckpt)
kws = get_keyword_for_model(model_hash, model_ckpt)
mk_choices = ["keyword1, keyword2", "random", "iterate"]
if kws:
mk_choices.extend([x.strip() for x in kws.split('|')])
@ -64,8 +119,7 @@ class Script(scripts.Script):
def check_keyword():
model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
# hash_dict = self.load_hash_dict()
entry = self.get_keyword(model_hash, model_ckpt, return_entry=True)
entry = get_keyword_for_model(model_hash, model_ckpt, return_entry=True)
if entry:
kw = entry[0]
@ -77,7 +131,6 @@ class Script(scripts.Script):
def delete_keyword():
model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
# hash_dict = self.load_hash_dict()
user_file = f'{scripts_dir}/custom-mappings.txt'
user_backup_file = f'{scripts_dir}/custom-mappings-backup.txt'
lines = []
@ -205,65 +258,6 @@ class Script(scripts.Script):
return [is_enabled, keyword_placement, multiple_keywords, ti_keywords, keyword_order]
def load_hash_dict(self):
global hash_dict, hash_dict_modified, scripts_dir
default_file = f'{scripts_dir}/model-keyword.txt'
user_file = f'{scripts_dir}/custom-mappings.txt'
if not os.path.exists(user_file):
open(user_file, 'w').write('\n')
modified = str(os.stat(default_file).st_mtime) + '_' + str(os.stat(user_file).st_mtime)
if hash_dict is None or hash_dict_modified != modified:
hash_dict = defaultdict(list)
def parse_file(path, idx):
if os.path.exists(path):
with open(path, newline='') as csvfile:
csvreader = csv.reader(csvfile)
for row in csvreader:
try:
mhash = row[0].strip(' ')
kw = row[1].strip(' ')
if mhash.startswith('#'):
continue
ckptname = 'default' if len(row)<=2 else row[2].strip(' ')
hash_dict[mhash].append((kw, ckptname,idx))
except:
pass
parse_file(default_file, 0) # 0 for default_file
parse_file(user_file, 1) # 1 for user_file
hash_dict_modified = modified
return hash_dict
def get_keyword(self, model_hash, model_ckpt, return_entry=False):
found = None
# hash -> [ (keyword, ckptname, idx) ]
hash_dict = self.load_hash_dict()
# print(hash_dict)
if model_hash in hash_dict:
lst = hash_dict[model_hash]
if len(lst) == 1:
found = lst[0]
elif len(lst) > 1:
max_sim = 0.0
found = lst[0]
for kw_ckpt in lst:
sim = str_simularity(kw_ckpt[1], model_ckpt)
if sim >= max_sim:
max_sim = sim
found = kw_ckpt
if return_entry:
return found
return found[0] if found else None
def process(self, p, is_enabled, keyword_placement, multiple_keywords, ti_keywords, keyword_order):
if not is_enabled:
@ -326,8 +320,35 @@ class Script(scripts.Script):
return ' '.join(arr)
kw = self.get_keyword(model_hash, model_ckpt)
kw = get_keyword_for_model(model_hash, model_ckpt)
if kw is not None or ti_keywords != 'None':
p.prompt = new_prompt(p.prompt, kw, no_iter=True)
p.all_prompts = [new_prompt(prompt, kw) for prompt in p.all_prompts]
from fastapi import FastAPI, Response, Query, Body
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
def model_keyword_api(_: gr.Blocks, app: FastAPI):
@app.get("/model_keyword/get_keywords")
async def get_keywords():
model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
kws = [x.strip() for x in get_keyword_for_model(model_hash, model_ckpt).split('|')]
return {"keywords": kws, "model": model_ckpt, "hash": model_hash}
# @app.get("/model_keyword/get_raw_keywords")
# async def get_raw_keywords():
# model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
# model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
# kw = get_keyword_for_model(model_hash, model_ckpt)
# return {"keywords": kw, "model": model_ckpt, "hash": model_hash}
try:
import modules.script_callbacks as script_callbacks
script_callbacks.on_app_started(model_keyword_api)
except:
pass