Merge pull request #259 from hinablue/main

Rewrite split_exif_data for more complex infotext
pull/265/head
AlUlkesh 2024-08-15 08:49:29 +02:00 committed by GitHub
commit 2b84c0fbc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 94 additions and 44 deletions

View File

@ -2,6 +2,7 @@ import hashlib
import json
import os
import sqlite3
import re
from shutil import copy2
from modules import scripts, shared
from tempfile import gettempdir
@ -58,7 +59,7 @@ def create_filehash(cursor):
''')
cursor.execute('''
CREATE TRIGGER filehash_tr
CREATE TRIGGER filehash_tr
AFTER UPDATE ON filehash
BEGIN
UPDATE filehash SET updated = CURRENT_TIMESTAMP WHERE file = OLD.file;
@ -95,7 +96,7 @@ def create_db(cursor):
''')
cursor.execute('''
CREATE TRIGGER path_recorder_tr
CREATE TRIGGER path_recorder_tr
AFTER UPDATE ON path_recorder
BEGIN
UPDATE path_recorder SET updated = CURRENT_TIMESTAMP WHERE path = OLD.path;
@ -118,7 +119,7 @@ def create_db(cursor):
''')
cursor.execute('''
CREATE TRIGGER exif_data_tr
CREATE TRIGGER exif_data_tr
AFTER UPDATE ON exif_data
BEGIN
UPDATE exif_data SET updated = CURRENT_TIMESTAMP WHERE file = OLD.file AND key = OLD.key;
@ -140,7 +141,7 @@ def create_db(cursor):
''')
cursor.execute('''
CREATE TRIGGER ranking_tr
CREATE TRIGGER ranking_tr
AFTER UPDATE ON ranking
BEGIN
UPDATE ranking SET updated = CURRENT_TIMESTAMP WHERE file = OLD.file;
@ -185,6 +186,57 @@ def split_exif_data(info):
negative_prompt = "0"
key_values = "0: 0"
key_value_pairs = []
def parse_value_pairs(kv_str, key_prefix=''):
# Regular expression pattern to match key-value pairs, including multiline prompts
pattern = r'((?:\w+ )?(?:Prompt|Negative Prompt)|[^:]+):\s*((?:[^,]+(?:,(?![^:]+:))?)+)'
# Find all matches
matches = re.findall(pattern, kv_str, re.IGNORECASE | re.DOTALL)
result = {}
current_prompt = None
def process_prompt(key, value, current_prompt):
if current_prompt is None:
result[key] = value
current_prompt = key
else:
pk_values = [v.strip() for v in key.split(',') if v.strip()]
result[current_prompt] += f",{','.join(pk_values[:-1])}"
current_prompt = pk_values[-1]
result[current_prompt] = ','.join([v.strip() for v in value.split(',') if v.strip()])
return current_prompt
def process_regular_key(key, value, current_prompt):
values = [v.strip() for v in value.split(',') if v.strip()]
if current_prompt is not None:
pk_values = [v.strip() for v in key.split(',') if v.strip()]
result[current_prompt] += f",{','.join(pk_values[:-1])}"
current_prompt = None
key = pk_values[-1]
result[key] = values[0] if len(values) == 1 else ','.join(values)
return current_prompt
for key, value in matches:
key = key.strip(' ,')
value = value.strip()
if "prompt" in key.lower() or "prompt" in value.lower():
current_prompt = process_prompt(key, value, current_prompt)
else:
current_prompt = process_regular_key(key, value, current_prompt)
# Print the resulting key-value pairs
for key, value in result.items():
value = value.strip(' ,')
if value.startswith('"') and value.endswith('"'):
value = value[1:-1]
parse_value_pairs(value, f"{key_prefix} - {key}" if key_prefix != '' else key)
key_value_pairs.append((f"{key_prefix} - {key}" if key_prefix != '' else key, value))
if info != "0":
info_list = info.split("\n")
prompt = ""
@ -202,21 +254,19 @@ def split_exif_data(info):
# multiline prompts
prompt = f"{prompt}\n{info_item}"
if key_values != "":
key_value = ""
quote_open = False
for char in key_values + ",":
key_value += char
if char == '"':
quote_open = not quote_open
if char == "," and not quote_open:
try:
k, v = key_value.strip(" ,").split(": ")
except ValueError:
k = key_value.strip(" ,").split(": ")[0]
v = ""
key_value_pairs.append((k, v))
key_value = ""
pattern = r'(\w+(?:\s+\w+)*?):\s*((?:"[^"]*"|[^,])+)(?:,\s*|$)'
matches = re.findall(pattern, key_values)
result = {key.strip(): value.strip() for key, value in matches}
# Save resulting key-value pairs
for key, value in result.items():
value = value.strip(' ,')
if value.startswith('"') and value.endswith('"'):
value = value[1:-1]
parse_value_pairs(value, key)
key_value_pairs.append((key, value))
return prompt, negative_prompt, key_value_pairs
def update_exif_data(cursor, file, info):
@ -243,7 +293,7 @@ def update_exif_data(cursor, file, info):
INSERT INTO exif_data (file, key, value)
VALUES (?, ?, ?)
''', (file, "negative_prompt", negative_prompt))
for (key, value) in key_value_pairs:
try:
cursor.execute('''
@ -252,18 +302,18 @@ def update_exif_data(cursor, file, info):
''', (file, key, value))
except sqlite3.IntegrityError:
pass
return
def migrate_exif_data(cursor):
if os.path.exists(exif_cache_file):
with open(exif_cache_file, 'r') as file:
exif_cache = json.load(file)
for file, info in exif_cache.items():
file = os.path.realpath(file)
update_exif_data(cursor, file, info)
return
def migrate_ranking(cursor):
@ -290,13 +340,13 @@ def get_hash(file):
hash = hashlib.sha512(image.tobytes()).hexdigest()
image.close()
return hash
def migrate_filehash(cursor, version):
if version <= "4":
create_filehash(cursor)
cursor.execute('''
SELECT file
FROM ranking
@ -323,7 +373,7 @@ def update_db_data(cursor, key, value):
INTO db_data (key, value)
VALUES (?, ?)
''', (key, value))
return
def get_version():
@ -334,7 +384,7 @@ def get_version():
WHERE key = 'version'
''',)
db_version = cursor.fetchone()
return db_version
def get_last_default_tab():
@ -345,7 +395,7 @@ def get_last_default_tab():
WHERE key = 'last_default_tab'
''',)
last_default_tab = cursor.fetchone()
return last_default_tab
def migrate_path_recorder_dirs(cursor):
@ -420,7 +470,7 @@ def migrate_ranking_dirs(cursor, db_version):
''')
cursor.execute('''
SELECT file, ranking
SELECT file, ranking
FROM ranking
''')
for (filepath, ranking) in cursor.fetchall():
@ -468,7 +518,7 @@ def check():
migrate_filehash(cursor, str(version))
print("Image Browser: Database created")
db_version = get_version()
with transaction() as cursor:
if db_version[0] <= "2":
# version 1 database had mixed path notations, changed them all to abspath
@ -486,7 +536,7 @@ def check():
update_db_data(cursor, "version", version)
print(f"Image Browser: Database upgraded from version {db_version[0]} to version {version}")
return version
def load_path_recorder():
@ -512,7 +562,7 @@ def select_ranking(file):
return_ranking = "None"
else:
(return_ranking,) = ranking_value
return return_ranking
def update_ranking(file, ranking):
@ -521,7 +571,7 @@ def update_ranking(file, ranking):
if ranking == "None":
cursor.execute('''
DELETE FROM ranking
WHERE file = ?
WHERE file = ?
''', (file,))
else:
cursor.execute('''
@ -529,14 +579,14 @@ def update_ranking(file, ranking):
INTO ranking (file, name, ranking)
VALUES (?, ?, ?)
''', (file, name, ranking))
hash = get_hash(file)
cursor.execute('''
INSERT OR REPLACE
INTO filehash (file, hash)
VALUES (?, ?)
''', (file, hash))
return
def update_path_recorder(path, depth, path_display):
@ -546,7 +596,7 @@ def update_path_recorder(path, depth, path_display):
INTO path_recorder (path, depth, path_display)
VALUES (?, ?, ?)
''', (path, depth, path_display))
return
def update_path_recorder(path, depth, path_display):
@ -556,7 +606,7 @@ def update_path_recorder(path, depth, path_display):
INTO path_recorder (path, depth, path_display)
VALUES (?, ?, ?)
''', (path, depth, path_display))
return
def delete_path_recorder(path):
@ -565,7 +615,7 @@ def delete_path_recorder(path):
DELETE FROM path_recorder
WHERE path = ?
''', (path,))
return
def update_path_recorder_mult(cursor, update_from, update_to):
@ -636,7 +686,7 @@ def get_ranking_by_name(cursor, name):
cursor.execute('''
SELECT hash
FROM filehash
WHERE file = ?
WHERE file = ?
''', (file,))
hash_value = cursor.fetchone()
else:
@ -650,7 +700,7 @@ def insert_ranking(cursor, file, ranking, hash):
INSERT INTO ranking (file, name, ranking)
VALUES (?, ?, ?)
''', (file, name, ranking))
cursor.execute('''
INSERT OR REPLACE
INTO filehash (file, hash)
@ -762,7 +812,7 @@ def get_exif_dirs():
def fill_work_files(cursor, fileinfos):
filenames = [x[0] for x in fileinfos]
cursor.execute('''
DELETE
FROM work_files
@ -847,13 +897,13 @@ def filter_ranking(cursor, fileinfos, ranking_filter, ranking_filter_min_num, ra
''')
rows = cursor.fetchall()
fileinfos_dict = {pair[0]: pair[1] for pair in fileinfos}
fileinfos_new = []
for (file,) in rows:
if fileinfos_dict.get(file) is not None:
fileinfos_new.append((file, fileinfos_dict[file]))
return fileinfos_new
def select_x_y(cursor, file):
@ -874,4 +924,4 @@ def select_x_y(cursor, file):
x = parts[0]
y = parts[1]
return x, y
return x, y