Merge pull request #259 from hinablue/main

Rewrite split_exif_data for more complex infotext
pull/265/head
AlUlkesh 2024-08-15 08:49:29 +02:00 committed by GitHub
commit 2b84c0fbc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 94 additions and 44 deletions

View File

@ -2,6 +2,7 @@ import hashlib
import json import json
import os import os
import sqlite3 import sqlite3
import re
from shutil import copy2 from shutil import copy2
from modules import scripts, shared from modules import scripts, shared
from tempfile import gettempdir from tempfile import gettempdir
@ -58,7 +59,7 @@ def create_filehash(cursor):
''') ''')
cursor.execute(''' cursor.execute('''
CREATE TRIGGER filehash_tr CREATE TRIGGER filehash_tr
AFTER UPDATE ON filehash AFTER UPDATE ON filehash
BEGIN BEGIN
UPDATE filehash SET updated = CURRENT_TIMESTAMP WHERE file = OLD.file; UPDATE filehash SET updated = CURRENT_TIMESTAMP WHERE file = OLD.file;
@ -95,7 +96,7 @@ def create_db(cursor):
''') ''')
cursor.execute(''' cursor.execute('''
CREATE TRIGGER path_recorder_tr CREATE TRIGGER path_recorder_tr
AFTER UPDATE ON path_recorder AFTER UPDATE ON path_recorder
BEGIN BEGIN
UPDATE path_recorder SET updated = CURRENT_TIMESTAMP WHERE path = OLD.path; UPDATE path_recorder SET updated = CURRENT_TIMESTAMP WHERE path = OLD.path;
@ -118,7 +119,7 @@ def create_db(cursor):
''') ''')
cursor.execute(''' cursor.execute('''
CREATE TRIGGER exif_data_tr CREATE TRIGGER exif_data_tr
AFTER UPDATE ON exif_data AFTER UPDATE ON exif_data
BEGIN BEGIN
UPDATE exif_data SET updated = CURRENT_TIMESTAMP WHERE file = OLD.file AND key = OLD.key; UPDATE exif_data SET updated = CURRENT_TIMESTAMP WHERE file = OLD.file AND key = OLD.key;
@ -140,7 +141,7 @@ def create_db(cursor):
''') ''')
cursor.execute(''' cursor.execute('''
CREATE TRIGGER ranking_tr CREATE TRIGGER ranking_tr
AFTER UPDATE ON ranking AFTER UPDATE ON ranking
BEGIN BEGIN
UPDATE ranking SET updated = CURRENT_TIMESTAMP WHERE file = OLD.file; UPDATE ranking SET updated = CURRENT_TIMESTAMP WHERE file = OLD.file;
@ -185,6 +186,57 @@ def split_exif_data(info):
negative_prompt = "0" negative_prompt = "0"
key_values = "0: 0" key_values = "0: 0"
key_value_pairs = [] key_value_pairs = []
def parse_value_pairs(kv_str, key_prefix=''):
# Regular expression pattern to match key-value pairs, including multiline prompts
pattern = r'((?:\w+ )?(?:Prompt|Negative Prompt)|[^:]+):\s*((?:[^,]+(?:,(?![^:]+:))?)+)'
# Find all matches
matches = re.findall(pattern, kv_str, re.IGNORECASE | re.DOTALL)
result = {}
current_prompt = None
def process_prompt(key, value, current_prompt):
if current_prompt is None:
result[key] = value
current_prompt = key
else:
pk_values = [v.strip() for v in key.split(',') if v.strip()]
result[current_prompt] += f",{','.join(pk_values[:-1])}"
current_prompt = pk_values[-1]
result[current_prompt] = ','.join([v.strip() for v in value.split(',') if v.strip()])
return current_prompt
def process_regular_key(key, value, current_prompt):
values = [v.strip() for v in value.split(',') if v.strip()]
if current_prompt is not None:
pk_values = [v.strip() for v in key.split(',') if v.strip()]
result[current_prompt] += f",{','.join(pk_values[:-1])}"
current_prompt = None
key = pk_values[-1]
result[key] = values[0] if len(values) == 1 else ','.join(values)
return current_prompt
for key, value in matches:
key = key.strip(' ,')
value = value.strip()
if "prompt" in key.lower() or "prompt" in value.lower():
current_prompt = process_prompt(key, value, current_prompt)
else:
current_prompt = process_regular_key(key, value, current_prompt)
# Print the resulting key-value pairs
for key, value in result.items():
value = value.strip(' ,')
if value.startswith('"') and value.endswith('"'):
value = value[1:-1]
parse_value_pairs(value, f"{key_prefix} - {key}" if key_prefix != '' else key)
key_value_pairs.append((f"{key_prefix} - {key}" if key_prefix != '' else key, value))
if info != "0": if info != "0":
info_list = info.split("\n") info_list = info.split("\n")
prompt = "" prompt = ""
@ -202,21 +254,19 @@ def split_exif_data(info):
# multiline prompts # multiline prompts
prompt = f"{prompt}\n{info_item}" prompt = f"{prompt}\n{info_item}"
if key_values != "": if key_values != "":
key_value = "" pattern = r'(\w+(?:\s+\w+)*?):\s*((?:"[^"]*"|[^,])+)(?:,\s*|$)'
quote_open = False matches = re.findall(pattern, key_values)
for char in key_values + ",": result = {key.strip(): value.strip() for key, value in matches}
key_value += char
if char == '"': # Save resulting key-value pairs
quote_open = not quote_open for key, value in result.items():
if char == "," and not quote_open: value = value.strip(' ,')
try: if value.startswith('"') and value.endswith('"'):
k, v = key_value.strip(" ,").split(": ") value = value[1:-1]
except ValueError: parse_value_pairs(value, key)
k = key_value.strip(" ,").split(": ")[0]
v = "" key_value_pairs.append((key, value))
key_value_pairs.append((k, v))
key_value = ""
return prompt, negative_prompt, key_value_pairs return prompt, negative_prompt, key_value_pairs
def update_exif_data(cursor, file, info): def update_exif_data(cursor, file, info):
@ -243,7 +293,7 @@ def update_exif_data(cursor, file, info):
INSERT INTO exif_data (file, key, value) INSERT INTO exif_data (file, key, value)
VALUES (?, ?, ?) VALUES (?, ?, ?)
''', (file, "negative_prompt", negative_prompt)) ''', (file, "negative_prompt", negative_prompt))
for (key, value) in key_value_pairs: for (key, value) in key_value_pairs:
try: try:
cursor.execute(''' cursor.execute('''
@ -252,18 +302,18 @@ def update_exif_data(cursor, file, info):
''', (file, key, value)) ''', (file, key, value))
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
pass pass
return return
def migrate_exif_data(cursor): def migrate_exif_data(cursor):
if os.path.exists(exif_cache_file): if os.path.exists(exif_cache_file):
with open(exif_cache_file, 'r') as file: with open(exif_cache_file, 'r') as file:
exif_cache = json.load(file) exif_cache = json.load(file)
for file, info in exif_cache.items(): for file, info in exif_cache.items():
file = os.path.realpath(file) file = os.path.realpath(file)
update_exif_data(cursor, file, info) update_exif_data(cursor, file, info)
return return
def migrate_ranking(cursor): def migrate_ranking(cursor):
@ -290,13 +340,13 @@ def get_hash(file):
hash = hashlib.sha512(image.tobytes()).hexdigest() hash = hashlib.sha512(image.tobytes()).hexdigest()
image.close() image.close()
return hash return hash
def migrate_filehash(cursor, version): def migrate_filehash(cursor, version):
if version <= "4": if version <= "4":
create_filehash(cursor) create_filehash(cursor)
cursor.execute(''' cursor.execute('''
SELECT file SELECT file
FROM ranking FROM ranking
@ -323,7 +373,7 @@ def update_db_data(cursor, key, value):
INTO db_data (key, value) INTO db_data (key, value)
VALUES (?, ?) VALUES (?, ?)
''', (key, value)) ''', (key, value))
return return
def get_version(): def get_version():
@ -334,7 +384,7 @@ def get_version():
WHERE key = 'version' WHERE key = 'version'
''',) ''',)
db_version = cursor.fetchone() db_version = cursor.fetchone()
return db_version return db_version
def get_last_default_tab(): def get_last_default_tab():
@ -345,7 +395,7 @@ def get_last_default_tab():
WHERE key = 'last_default_tab' WHERE key = 'last_default_tab'
''',) ''',)
last_default_tab = cursor.fetchone() last_default_tab = cursor.fetchone()
return last_default_tab return last_default_tab
def migrate_path_recorder_dirs(cursor): def migrate_path_recorder_dirs(cursor):
@ -420,7 +470,7 @@ def migrate_ranking_dirs(cursor, db_version):
''') ''')
cursor.execute(''' cursor.execute('''
SELECT file, ranking SELECT file, ranking
FROM ranking FROM ranking
''') ''')
for (filepath, ranking) in cursor.fetchall(): for (filepath, ranking) in cursor.fetchall():
@ -468,7 +518,7 @@ def check():
migrate_filehash(cursor, str(version)) migrate_filehash(cursor, str(version))
print("Image Browser: Database created") print("Image Browser: Database created")
db_version = get_version() db_version = get_version()
with transaction() as cursor: with transaction() as cursor:
if db_version[0] <= "2": if db_version[0] <= "2":
# version 1 database had mixed path notations, changed them all to abspath # version 1 database had mixed path notations, changed them all to abspath
@ -486,7 +536,7 @@ def check():
update_db_data(cursor, "version", version) update_db_data(cursor, "version", version)
print(f"Image Browser: Database upgraded from version {db_version[0]} to version {version}") print(f"Image Browser: Database upgraded from version {db_version[0]} to version {version}")
return version return version
def load_path_recorder(): def load_path_recorder():
@ -512,7 +562,7 @@ def select_ranking(file):
return_ranking = "None" return_ranking = "None"
else: else:
(return_ranking,) = ranking_value (return_ranking,) = ranking_value
return return_ranking return return_ranking
def update_ranking(file, ranking): def update_ranking(file, ranking):
@ -521,7 +571,7 @@ def update_ranking(file, ranking):
if ranking == "None": if ranking == "None":
cursor.execute(''' cursor.execute('''
DELETE FROM ranking DELETE FROM ranking
WHERE file = ? WHERE file = ?
''', (file,)) ''', (file,))
else: else:
cursor.execute(''' cursor.execute('''
@ -529,14 +579,14 @@ def update_ranking(file, ranking):
INTO ranking (file, name, ranking) INTO ranking (file, name, ranking)
VALUES (?, ?, ?) VALUES (?, ?, ?)
''', (file, name, ranking)) ''', (file, name, ranking))
hash = get_hash(file) hash = get_hash(file)
cursor.execute(''' cursor.execute('''
INSERT OR REPLACE INSERT OR REPLACE
INTO filehash (file, hash) INTO filehash (file, hash)
VALUES (?, ?) VALUES (?, ?)
''', (file, hash)) ''', (file, hash))
return return
def update_path_recorder(path, depth, path_display): def update_path_recorder(path, depth, path_display):
@ -546,7 +596,7 @@ def update_path_recorder(path, depth, path_display):
INTO path_recorder (path, depth, path_display) INTO path_recorder (path, depth, path_display)
VALUES (?, ?, ?) VALUES (?, ?, ?)
''', (path, depth, path_display)) ''', (path, depth, path_display))
return return
def update_path_recorder(path, depth, path_display): def update_path_recorder(path, depth, path_display):
@ -556,7 +606,7 @@ def update_path_recorder(path, depth, path_display):
INTO path_recorder (path, depth, path_display) INTO path_recorder (path, depth, path_display)
VALUES (?, ?, ?) VALUES (?, ?, ?)
''', (path, depth, path_display)) ''', (path, depth, path_display))
return return
def delete_path_recorder(path): def delete_path_recorder(path):
@ -565,7 +615,7 @@ def delete_path_recorder(path):
DELETE FROM path_recorder DELETE FROM path_recorder
WHERE path = ? WHERE path = ?
''', (path,)) ''', (path,))
return return
def update_path_recorder_mult(cursor, update_from, update_to): def update_path_recorder_mult(cursor, update_from, update_to):
@ -636,7 +686,7 @@ def get_ranking_by_name(cursor, name):
cursor.execute(''' cursor.execute('''
SELECT hash SELECT hash
FROM filehash FROM filehash
WHERE file = ? WHERE file = ?
''', (file,)) ''', (file,))
hash_value = cursor.fetchone() hash_value = cursor.fetchone()
else: else:
@ -650,7 +700,7 @@ def insert_ranking(cursor, file, ranking, hash):
INSERT INTO ranking (file, name, ranking) INSERT INTO ranking (file, name, ranking)
VALUES (?, ?, ?) VALUES (?, ?, ?)
''', (file, name, ranking)) ''', (file, name, ranking))
cursor.execute(''' cursor.execute('''
INSERT OR REPLACE INSERT OR REPLACE
INTO filehash (file, hash) INTO filehash (file, hash)
@ -762,7 +812,7 @@ def get_exif_dirs():
def fill_work_files(cursor, fileinfos): def fill_work_files(cursor, fileinfos):
filenames = [x[0] for x in fileinfos] filenames = [x[0] for x in fileinfos]
cursor.execute(''' cursor.execute('''
DELETE DELETE
FROM work_files FROM work_files
@ -847,13 +897,13 @@ def filter_ranking(cursor, fileinfos, ranking_filter, ranking_filter_min_num, ra
''') ''')
rows = cursor.fetchall() rows = cursor.fetchall()
fileinfos_dict = {pair[0]: pair[1] for pair in fileinfos} fileinfos_dict = {pair[0]: pair[1] for pair in fileinfos}
fileinfos_new = [] fileinfos_new = []
for (file,) in rows: for (file,) in rows:
if fileinfos_dict.get(file) is not None: if fileinfos_dict.get(file) is not None:
fileinfos_new.append((file, fileinfos_dict[file])) fileinfos_new.append((file, fileinfos_dict[file]))
return fileinfos_new return fileinfos_new
def select_x_y(cursor, file): def select_x_y(cursor, file):
@ -874,4 +924,4 @@ def select_x_y(cursor, file):
x = parts[0] x = parts[0]
y = parts[1] y = parts[1]
return x, y return x, y