support LORA models in subfolders.

pull/33/head
ChunKoo Park 2023-02-25 23:13:02 +09:00
parent a1f920b8ad
commit 071bd3fc29
1 changed files with 14 additions and 10 deletions

View File

@ -7,6 +7,10 @@ from collections import defaultdict
import modules.shared as shared
import difflib
import random
import glob
import hashlib
import shutil
import fnmatch
scripts_dir = scripts.basedir()
kw_idx = 0
@ -26,7 +30,6 @@ def get_old_model_hash(filename):
return model_hash_dict[filename]
try:
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()
file.seek(0x100000)
@ -37,6 +40,13 @@ def get_old_model_hash(filename):
except FileNotFoundError:
return 'NOFILE'
def find_files(directory, exts):
for root, dirs, files in os.walk(directory):
for ext in exts:
pattern = f'*.{ext}'
for filename in fnmatch.filter(files, pattern):
yield os.path.relpath(os.path.join(root, filename), directory)
def load_hash_dict():
global hash_dict, hash_dict_modified, scripts_dir
@ -143,7 +153,6 @@ def _get_keywords_for_lora(lora_model, return_entry=False):
lora_hash_dict = load_lora_hash_dict()
lora_model_hash = get_old_model_hash(lora_model_path)
# print(hash_dict)
if lora_model_hash in lora_hash_dict:
lst = lora_hash_dict[lora_model_hash]
@ -154,7 +163,7 @@ def _get_keywords_for_lora(lora_model, return_entry=False):
max_sim = 0.0
found = lst[0]
for kw_ckpt in lst:
sim = str_simularity(kw_ckpt[1], model_ckpt)
sim = str_simularity(kw_ckpt[1], lora_model)
if sim >= max_sim:
max_sim = sim
found = kw_ckpt
@ -236,11 +245,10 @@ class Script(scripts.Script):
def ui(self, is_img2img):
def get_embeddings():
import glob
return [os.path.basename(x) for x in glob.glob(f'{shared.cmd_opts.embeddings_dir}/*.pt')]
def get_loras():
import glob
return [os.path.basename(x) for x in glob.glob(f'{shared.cmd_opts.lora_dir}/*.safetensors')]
return list(find_files(shared.cmd_opts.lora_dir,['safetensors','ckpt','pt']))
# return [os.path.basename(x) for x in glob.glob(f'{shared.cmd_opts.lora_dir}/*.safetensors')]
def update_keywords():
model_ckpt = os.path.basename(shared.sd_model.sd_checkpoint_info.filename)
@ -306,7 +314,6 @@ class Script(scripts.Script):
if found:
csvtxt = '\n'.join(lines) + '\n'
import shutil
try:
shutil.copy(user_file, user_backup_file)
except:
@ -349,7 +356,6 @@ class Script(scripts.Script):
pass
lines.append(insert_line)
csvtxt = '\n'.join(lines) + '\n'
import shutil
try:
shutil.copy(user_file, user_backup_file)
except:
@ -388,7 +394,6 @@ class Script(scripts.Script):
outline = ''
if found:
csvtxt = '\n'.join(lines) + '\n'
import shutil
try:
shutil.copy(user_file, user_backup_file)
except:
@ -434,7 +439,6 @@ class Script(scripts.Script):
pass
lines.append(insert_line)
csvtxt = '\n'.join(lines) + '\n'
import shutil
try:
shutil.copy(user_file, user_backup_file)
except: