support LORA models in subfolders.
parent
a1f920b8ad
commit
071bd3fc29
|
|
@ -7,6 +7,10 @@ from collections import defaultdict
|
|||
import modules.shared as shared
|
||||
import difflib
|
||||
import random
|
||||
import glob
|
||||
import hashlib
|
||||
import shutil
|
||||
import fnmatch
|
||||
|
||||
scripts_dir = scripts.basedir()
|
||||
kw_idx = 0
|
||||
|
|
@ -26,7 +30,6 @@ def get_old_model_hash(filename):
|
|||
return model_hash_dict[filename]
|
||||
try:
|
||||
with open(filename, "rb") as file:
|
||||
import hashlib
|
||||
m = hashlib.sha256()
|
||||
|
||||
file.seek(0x100000)
|
||||
|
|
@ -37,6 +40,13 @@ def get_old_model_hash(filename):
|
|||
except FileNotFoundError:
|
||||
return 'NOFILE'
|
||||
|
||||
def find_files(directory, exts):
|
||||
for root, dirs, files in os.walk(directory):
|
||||
for ext in exts:
|
||||
pattern = f'*.{ext}'
|
||||
for filename in fnmatch.filter(files, pattern):
|
||||
yield os.path.relpath(os.path.join(root, filename), directory)
|
||||
|
||||
def load_hash_dict():
|
||||
global hash_dict, hash_dict_modified, scripts_dir
|
||||
|
||||
|
|
@ -143,7 +153,6 @@ def _get_keywords_for_lora(lora_model, return_entry=False):
|
|||
lora_hash_dict = load_lora_hash_dict()
|
||||
|
||||
lora_model_hash = get_old_model_hash(lora_model_path)
|
||||
# print(hash_dict)
|
||||
|
||||
if lora_model_hash in lora_hash_dict:
|
||||
lst = lora_hash_dict[lora_model_hash]
|
||||
|
|
@ -154,7 +163,7 @@ def _get_keywords_for_lora(lora_model, return_entry=False):
|
|||
max_sim = 0.0
|
||||
found = lst[0]
|
||||
for kw_ckpt in lst:
|
||||
sim = str_simularity(kw_ckpt[1], model_ckpt)
|
||||
sim = str_simularity(kw_ckpt[1], lora_model)
|
||||
if sim >= max_sim:
|
||||
max_sim = sim
|
||||
found = kw_ckpt
|
||||
|
|
@ -236,11 +245,10 @@ class Script(scripts.Script):
|
|||
|
||||
def ui(self, is_img2img):
|
||||
def get_embeddings():
|
||||
import glob
|
||||
return [os.path.basename(x) for x in glob.glob(f'{shared.cmd_opts.embeddings_dir}/*.pt')]
|
||||
def get_loras():
|
||||
import glob
|
||||
return [os.path.basename(x) for x in glob.glob(f'{shared.cmd_opts.lora_dir}/*.safetensors')]
|
||||
return list(find_files(shared.cmd_opts.lora_dir,['safetensors','ckpt','pt']))
|
||||
# return [os.path.basename(x) for x in glob.glob(f'{shared.cmd_opts.lora_dir}/*.safetensors')]
|
||||
|
||||
def update_keywords():
|
||||
model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
|
||||
|
|
@ -306,7 +314,6 @@ class Script(scripts.Script):
|
|||
|
||||
if found:
|
||||
csvtxt = '\n'.join(lines) + '\n'
|
||||
import shutil
|
||||
try:
|
||||
shutil.copy(user_file, user_backup_file)
|
||||
except:
|
||||
|
|
@ -349,7 +356,6 @@ class Script(scripts.Script):
|
|||
pass
|
||||
lines.append(insert_line)
|
||||
csvtxt = '\n'.join(lines) + '\n'
|
||||
import shutil
|
||||
try:
|
||||
shutil.copy(user_file, user_backup_file)
|
||||
except:
|
||||
|
|
@ -388,7 +394,6 @@ class Script(scripts.Script):
|
|||
outline = ''
|
||||
if found:
|
||||
csvtxt = '\n'.join(lines) + '\n'
|
||||
import shutil
|
||||
try:
|
||||
shutil.copy(user_file, user_backup_file)
|
||||
except:
|
||||
|
|
@ -434,7 +439,6 @@ class Script(scripts.Script):
|
|||
pass
|
||||
lines.append(insert_line)
|
||||
csvtxt = '\n'.join(lines) + '\n'
|
||||
import shutil
|
||||
try:
|
||||
shutil.copy(user_file, user_backup_file)
|
||||
except:
|
||||
|
|
|
|||
Loading…
Reference in New Issue