diff --git a/scripts/wib/wib_db.py b/scripts/wib/wib_db.py index 350aead..8809268 100644 --- a/scripts/wib/wib_db.py +++ b/scripts/wib/wib_db.py @@ -2,6 +2,7 @@ import hashlib import json import os import sqlite3 +import re from shutil import copy2 from modules import scripts, shared from tempfile import gettempdir @@ -58,7 +59,7 @@ def create_filehash(cursor): ''') cursor.execute(''' - CREATE TRIGGER filehash_tr + CREATE TRIGGER filehash_tr AFTER UPDATE ON filehash BEGIN UPDATE filehash SET updated = CURRENT_TIMESTAMP WHERE file = OLD.file; @@ -95,7 +96,7 @@ def create_db(cursor): ''') cursor.execute(''' - CREATE TRIGGER path_recorder_tr + CREATE TRIGGER path_recorder_tr AFTER UPDATE ON path_recorder BEGIN UPDATE path_recorder SET updated = CURRENT_TIMESTAMP WHERE path = OLD.path; @@ -118,7 +119,7 @@ def create_db(cursor): ''') cursor.execute(''' - CREATE TRIGGER exif_data_tr + CREATE TRIGGER exif_data_tr AFTER UPDATE ON exif_data BEGIN UPDATE exif_data SET updated = CURRENT_TIMESTAMP WHERE file = OLD.file AND key = OLD.key; @@ -140,7 +141,7 @@ def create_db(cursor): ''') cursor.execute(''' - CREATE TRIGGER ranking_tr + CREATE TRIGGER ranking_tr AFTER UPDATE ON ranking BEGIN UPDATE ranking SET updated = CURRENT_TIMESTAMP WHERE file = OLD.file; @@ -185,6 +186,57 @@ def split_exif_data(info): negative_prompt = "0" key_values = "0: 0" key_value_pairs = [] + + def parse_value_pairs(kv_str, key_prefix=''): + # Regular expression pattern to match key-value pairs, including multiline prompts + pattern = r'((?:\w+ )?(?:Prompt|Negative Prompt)|[^:]+):\s*((?:[^,]+(?:,(?![^:]+:))?)+)' + + # Find all matches + matches = re.findall(pattern, kv_str, re.IGNORECASE | re.DOTALL) + result = {} + current_prompt = None + + def process_prompt(key, value, current_prompt): + if current_prompt is None: + result[key] = value + current_prompt = key + else: + pk_values = [v.strip() for v in key.split(',') if v.strip()] + result[current_prompt] += f",{','.join(pk_values[:-1])}" + current_prompt = pk_values[-1] + result[current_prompt] = ','.join([v.strip() for v in value.split(',') if v.strip()]) + + return current_prompt + + def process_regular_key(key, value, current_prompt): + values = [v.strip() for v in value.split(',') if v.strip()] + if current_prompt is not None: + pk_values = [v.strip() for v in key.split(',') if v.strip()] + result[current_prompt] += f",{','.join(pk_values[:-1])}" + current_prompt = None + key = pk_values[-1] + result[key] = values[0] if len(values) == 1 else ','.join(values) + + return current_prompt + + for key, value in matches: + key = key.strip(' ,') + value = value.strip() + + if "prompt" in key.lower() or "prompt" in value.lower(): + current_prompt = process_prompt(key, value, current_prompt) + else: + current_prompt = process_regular_key(key, value, current_prompt) + + # Print the resulting key-value pairs + for key, value in result.items(): + value = value.strip(' ,') + if value.startswith('"') and value.endswith('"'): + value = value[1:-1] + parse_value_pairs(value, f"{key_prefix} - {key}" if key_prefix != '' else key) + + key_value_pairs.append((f"{key_prefix} - {key}" if key_prefix != '' else key, value)) + if info != "0": info_list = info.split("\n") prompt = "" @@ -202,21 +254,19 @@ def split_exif_data(info): # multiline prompts prompt = f"{prompt}\n{info_item}" if key_values != "": - key_value = "" - quote_open = False - for char in key_values + ",": - key_value += char - if char == '"': - quote_open = not quote_open - if char == "," and not quote_open: - try: - k, v = key_value.strip(" ,").split(": ") - except ValueError: - k = key_value.strip(" ,").split(": ")[0] - v = "" - key_value_pairs.append((k, v)) - key_value = "" - + pattern = r'(\w+(?:\s+\w+)*?):\s*((?:"[^"]*"|[^,])+)(?:,\s*|$)' + matches = re.findall(pattern, key_values) + result = {key.strip(): value.strip() for key, value in matches} + + # Save resulting key-value pairs + for key, value in result.items(): + value = value.strip(' ,') + if value.startswith('"') and value.endswith('"'): + value = value[1:-1] + parse_value_pairs(value, key) + + key_value_pairs.append((key, value)) + return prompt, negative_prompt, key_value_pairs def update_exif_data(cursor, file, info): @@ -243,7 +293,7 @@ def update_exif_data(cursor, file, info): INSERT INTO exif_data (file, key, value) VALUES (?, ?, ?) ''', (file, "negative_prompt", negative_prompt)) - + for (key, value) in key_value_pairs: try: cursor.execute(''' @@ -252,18 +302,18 @@ def update_exif_data(cursor, file, info): ''', (file, key, value)) except sqlite3.IntegrityError: pass - + return def migrate_exif_data(cursor): if os.path.exists(exif_cache_file): with open(exif_cache_file, 'r') as file: exif_cache = json.load(file) - + for file, info in exif_cache.items(): file = os.path.realpath(file) update_exif_data(cursor, file, info) - + return def migrate_ranking(cursor): @@ -290,13 +340,13 @@ def get_hash(file): hash = hashlib.sha512(image.tobytes()).hexdigest() image.close() - + return hash def migrate_filehash(cursor, version): if version <= "4": create_filehash(cursor) - + cursor.execute(''' SELECT file FROM ranking @@ -323,7 +373,7 @@ def update_db_data(cursor, key, value): INTO db_data (key, value) VALUES (?, ?) ''', (key, value)) - + return def get_version(): @@ -334,7 +384,7 @@ def get_version(): WHERE key = 'version' ''',) db_version = cursor.fetchone() - + return db_version def get_last_default_tab(): @@ -345,7 +395,7 @@ def get_last_default_tab(): WHERE key = 'last_default_tab' ''',) last_default_tab = cursor.fetchone() - + return last_default_tab def migrate_path_recorder_dirs(cursor): @@ -420,7 +470,7 @@ def migrate_ranking_dirs(cursor, db_version): ''') cursor.execute(''' - SELECT file, ranking + SELECT file, ranking FROM ranking ''') for (filepath, ranking) in cursor.fetchall(): @@ -468,7 +518,7 @@ def check(): migrate_filehash(cursor, str(version)) print("Image Browser: Database created") db_version = get_version() - + with transaction() as cursor: if db_version[0] <= "2": # version 1 database had mixed path notations, changed them all to abspath @@ -486,7 +536,7 @@ def check(): update_db_data(cursor, "version", version) print(f"Image Browser: Database upgraded from version {db_version[0]} to version {version}") - + return version def load_path_recorder(): @@ -512,7 +562,7 @@ def select_ranking(file): return_ranking = "None" else: (return_ranking,) = ranking_value - + return return_ranking def update_ranking(file, ranking): @@ -521,7 +571,7 @@ def update_ranking(file, ranking): if ranking == "None": cursor.execute(''' DELETE FROM ranking - WHERE file = ? + WHERE file = ? ''', (file,)) else: cursor.execute(''' @@ -529,14 +579,14 @@ def update_ranking(file, ranking): INTO ranking (file, name, ranking) VALUES (?, ?, ?) ''', (file, name, ranking)) - + hash = get_hash(file) cursor.execute(''' INSERT OR REPLACE INTO filehash (file, hash) VALUES (?, ?) ''', (file, hash)) - + return def update_path_recorder(path, depth, path_display): @@ -546,7 +596,7 @@ def update_path_recorder(path, depth, path_display): INTO path_recorder (path, depth, path_display) VALUES (?, ?, ?) ''', (path, depth, path_display)) - + return def update_path_recorder(path, depth, path_display): @@ -556,7 +606,7 @@ def update_path_recorder(path, depth, path_display): INTO path_recorder (path, depth, path_display) VALUES (?, ?, ?) ''', (path, depth, path_display)) - + return def delete_path_recorder(path): @@ -565,7 +615,7 @@ def delete_path_recorder(path): DELETE FROM path_recorder WHERE path = ? ''', (path,)) - + return def update_path_recorder_mult(cursor, update_from, update_to): @@ -636,7 +686,7 @@ def get_ranking_by_name(cursor, name): cursor.execute(''' SELECT hash FROM filehash - WHERE file = ? + WHERE file = ? ''', (file,)) hash_value = cursor.fetchone() else: @@ -650,7 +700,7 @@ def insert_ranking(cursor, file, ranking, hash): INSERT INTO ranking (file, name, ranking) VALUES (?, ?, ?) ''', (file, name, ranking)) - + cursor.execute(''' INSERT OR REPLACE INTO filehash (file, hash) @@ -762,7 +812,7 @@ def get_exif_dirs(): def fill_work_files(cursor, fileinfos): filenames = [x[0] for x in fileinfos] - + cursor.execute(''' DELETE FROM work_files @@ -847,13 +897,13 @@ def filter_ranking(cursor, fileinfos, ranking_filter, ranking_filter_min_num, ra ''') rows = cursor.fetchall() - + fileinfos_dict = {pair[0]: pair[1] for pair in fileinfos} fileinfos_new = [] for (file,) in rows: if fileinfos_dict.get(file) is not None: fileinfos_new.append((file, fileinfos_dict[file])) - + return fileinfos_new def select_x_y(cursor, file): @@ -874,4 +924,4 @@ def select_x_y(cursor, file): x = parts[0] y = parts[1] - return x, y \ No newline at end of file + return x, y