add api "/model_keyword/get_keywords".
parent
fac289cc52
commit
aad7369165
|
|
@ -33,6 +33,61 @@ def get_old_model_hash(filename):
|
|||
except FileNotFoundError:
|
||||
return 'NOFILE'
|
||||
|
||||
def load_hash_dict():
|
||||
global hash_dict, hash_dict_modified, scripts_dir
|
||||
|
||||
default_file = f'{scripts_dir}/model-keyword.txt'
|
||||
user_file = f'{scripts_dir}/custom-mappings.txt'
|
||||
modified = str(os.stat(default_file).st_mtime) + '_' + str(os.stat(user_file).st_mtime)
|
||||
|
||||
if hash_dict is None or hash_dict_modified != modified:
|
||||
hash_dict = defaultdict(list)
|
||||
def parse_file(path, idx):
|
||||
if os.path.exists(path):
|
||||
with open(path, newline='') as csvfile:
|
||||
csvreader = csv.reader(csvfile)
|
||||
for row in csvreader:
|
||||
try:
|
||||
mhash = row[0].strip(' ')
|
||||
kw = row[1].strip(' ')
|
||||
if mhash.startswith('#'):
|
||||
continue
|
||||
ckptname = 'default' if len(row)<=2 else row[2].strip(' ')
|
||||
hash_dict[mhash].append((kw, ckptname,idx))
|
||||
except:
|
||||
pass
|
||||
|
||||
parse_file(default_file, 0) # 0 for default_file
|
||||
parse_file(user_file, 1) # 1 for user_file
|
||||
|
||||
hash_dict_modified = modified
|
||||
|
||||
return hash_dict
|
||||
|
||||
def get_keyword_for_model(model_hash, model_ckpt, return_entry=False):
|
||||
found = None
|
||||
|
||||
# hash -> [ (keyword, ckptname, idx) ]
|
||||
hash_dict = load_hash_dict()
|
||||
|
||||
# print(hash_dict)
|
||||
|
||||
if model_hash in hash_dict:
|
||||
lst = hash_dict[model_hash]
|
||||
if len(lst) == 1:
|
||||
found = lst[0]
|
||||
|
||||
elif len(lst) > 1:
|
||||
max_sim = 0.0
|
||||
found = lst[0]
|
||||
for kw_ckpt in lst:
|
||||
sim = str_simularity(kw_ckpt[1], model_ckpt)
|
||||
if sim >= max_sim:
|
||||
max_sim = sim
|
||||
found = kw_ckpt
|
||||
if return_entry:
|
||||
return found
|
||||
return found[0] if found else None
|
||||
|
||||
class Script(scripts.Script):
|
||||
def title(self):
|
||||
|
|
@ -49,7 +104,7 @@ class Script(scripts.Script):
|
|||
def update_keywords():
|
||||
model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
|
||||
model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
|
||||
kws = self.get_keyword(model_hash, model_ckpt)
|
||||
kws = get_keyword_for_model(model_hash, model_ckpt)
|
||||
mk_choices = ["keyword1, keyword2", "random", "iterate"]
|
||||
if kws:
|
||||
mk_choices.extend([x.strip() for x in kws.split('|')])
|
||||
|
|
@ -64,8 +119,7 @@ class Script(scripts.Script):
|
|||
def check_keyword():
|
||||
model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
|
||||
model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
|
||||
# hash_dict = self.load_hash_dict()
|
||||
entry = self.get_keyword(model_hash, model_ckpt, return_entry=True)
|
||||
entry = get_keyword_for_model(model_hash, model_ckpt, return_entry=True)
|
||||
|
||||
if entry:
|
||||
kw = entry[0]
|
||||
|
|
@ -77,7 +131,6 @@ class Script(scripts.Script):
|
|||
def delete_keyword():
|
||||
model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
|
||||
model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
|
||||
# hash_dict = self.load_hash_dict()
|
||||
user_file = f'{scripts_dir}/custom-mappings.txt'
|
||||
user_backup_file = f'{scripts_dir}/custom-mappings-backup.txt'
|
||||
lines = []
|
||||
|
|
@ -205,65 +258,6 @@ class Script(scripts.Script):
|
|||
|
||||
return [is_enabled, keyword_placement, multiple_keywords, ti_keywords, keyword_order]
|
||||
|
||||
def load_hash_dict(self):
|
||||
global hash_dict, hash_dict_modified, scripts_dir
|
||||
|
||||
default_file = f'{scripts_dir}/model-keyword.txt'
|
||||
user_file = f'{scripts_dir}/custom-mappings.txt'
|
||||
if not os.path.exists(user_file):
|
||||
open(user_file, 'w').write('\n')
|
||||
modified = str(os.stat(default_file).st_mtime) + '_' + str(os.stat(user_file).st_mtime)
|
||||
|
||||
if hash_dict is None or hash_dict_modified != modified:
|
||||
hash_dict = defaultdict(list)
|
||||
def parse_file(path, idx):
|
||||
if os.path.exists(path):
|
||||
with open(path, newline='') as csvfile:
|
||||
csvreader = csv.reader(csvfile)
|
||||
for row in csvreader:
|
||||
try:
|
||||
mhash = row[0].strip(' ')
|
||||
kw = row[1].strip(' ')
|
||||
if mhash.startswith('#'):
|
||||
continue
|
||||
ckptname = 'default' if len(row)<=2 else row[2].strip(' ')
|
||||
hash_dict[mhash].append((kw, ckptname,idx))
|
||||
except:
|
||||
pass
|
||||
|
||||
parse_file(default_file, 0) # 0 for default_file
|
||||
parse_file(user_file, 1) # 1 for user_file
|
||||
|
||||
hash_dict_modified = modified
|
||||
|
||||
return hash_dict
|
||||
|
||||
def get_keyword(self, model_hash, model_ckpt, return_entry=False):
|
||||
found = None
|
||||
|
||||
# hash -> [ (keyword, ckptname, idx) ]
|
||||
hash_dict = self.load_hash_dict()
|
||||
|
||||
# print(hash_dict)
|
||||
|
||||
if model_hash in hash_dict:
|
||||
lst = hash_dict[model_hash]
|
||||
if len(lst) == 1:
|
||||
found = lst[0]
|
||||
|
||||
elif len(lst) > 1:
|
||||
max_sim = 0.0
|
||||
found = lst[0]
|
||||
for kw_ckpt in lst:
|
||||
sim = str_simularity(kw_ckpt[1], model_ckpt)
|
||||
if sim >= max_sim:
|
||||
max_sim = sim
|
||||
found = kw_ckpt
|
||||
if return_entry:
|
||||
return found
|
||||
return found[0] if found else None
|
||||
|
||||
|
||||
def process(self, p, is_enabled, keyword_placement, multiple_keywords, ti_keywords, keyword_order):
|
||||
|
||||
if not is_enabled:
|
||||
|
|
@ -326,8 +320,35 @@ class Script(scripts.Script):
|
|||
return ' '.join(arr)
|
||||
|
||||
|
||||
kw = self.get_keyword(model_hash, model_ckpt)
|
||||
kw = get_keyword_for_model(model_hash, model_ckpt)
|
||||
|
||||
if kw is not None or ti_keywords != 'None':
|
||||
p.prompt = new_prompt(p.prompt, kw, no_iter=True)
|
||||
p.all_prompts = [new_prompt(prompt, kw) for prompt in p.all_prompts]
|
||||
|
||||
|
||||
from fastapi import FastAPI, Response, Query, Body
|
||||
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
|
||||
|
||||
|
||||
def model_keyword_api(_: gr.Blocks, app: FastAPI):
|
||||
@app.get("/model_keyword/get_keywords")
|
||||
async def get_keywords():
|
||||
model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
|
||||
model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
|
||||
kws = [x.strip() for x in get_keyword_for_model(model_hash, model_ckpt).split('|')]
|
||||
return {"keywords": kws, "model": model_ckpt, "hash": model_hash}
|
||||
|
||||
# @app.get("/model_keyword/get_raw_keywords")
|
||||
# async def get_raw_keywords():
|
||||
# model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
|
||||
# model_hash = get_old_model_hash(shared.sd_model.sd_checkpoint_info.filename)
|
||||
# kw = get_keyword_for_model(model_hash, model_ckpt)
|
||||
# return {"keywords": kw, "model": model_ckpt, "hash": model_hash}
|
||||
|
||||
try:
|
||||
import modules.script_callbacks as script_callbacks
|
||||
|
||||
script_callbacks.on_app_started(model_keyword_api)
|
||||
except:
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in New Issue