diff --git a/scripts/wib/wib_db.py b/scripts/wib/wib_db.py index 4ba4263..f3023c8 100644 --- a/scripts/wib/wib_db.py +++ b/scripts/wib/wib_db.py @@ -4,7 +4,9 @@ import os import sqlite3 from shutil import copy2 from modules import scripts, shared +from tempfile import gettempdir from PIL import Image +from contextlib import contextmanager version = 7 @@ -14,7 +16,7 @@ exif_cache_file = os.path.join(scripts.basedir(), "exif_data.json") ranking_file = os.path.join(scripts.basedir(), "ranking.json") archive = os.path.join(scripts.basedir(), "archive") source_db_file = os.path.join(scripts.basedir(), "wib.sqlite3") -tmp_db_file = "/tmp/sd-images-browser.sqlite3" +tmp_db_file = os.path.join(gettempdir(), "sd-images-browser.sqlite3") db_file = source_db_file if getattr(shared.cmd_opts, "image_browser_tmp_db", False): @@ -32,6 +34,19 @@ np = "Negative prompt: " st = "Steps: " timeout = 30 +@contextmanager +def transaction(db = db_file): + conn = sqlite3.connect(db, timeout=timeout) + try: + conn.isolation_level = None + cursor = conn.cursor() + cursor.execute("BEGIN") + yield cursor + cursor.execute("COMMIT") + finally: + conn.close() + backup_tmp_db() + def create_filehash(cursor): cursor.execute(''' CREATE TABLE IF NOT EXISTS filehash ( @@ -307,8 +322,7 @@ def update_db_data(cursor, key, value): return def get_version(): - with sqlite3.connect(db_file, timeout=timeout) as conn: - cursor = conn.cursor() + with transaction() as cursor: cursor.execute(''' SELECT value FROM db_data @@ -319,8 +333,7 @@ def get_version(): return db_version def get_last_default_tab(): - with sqlite3.connect(db_file, timeout=timeout) as conn: - cursor = conn.cursor() + with transaction() as cursor: cursor.execute(''' SELECT value FROM db_data @@ -439,42 +452,40 @@ def migrate_ranking_dirs(cursor, db_version): def check(): if not os.path.exists(db_file): - conn, cursor = transaction_begin() print("Image Browser: Creating database") - create_db(cursor) - update_db_data(cursor, "version", version) - update_db_data(cursor, "last_default_tab", "Maintenance") - migrate_path_recorder(cursor) - migrate_exif_data(cursor) - migrate_ranking(cursor) - migrate_filehash(cursor, str(version)) - transaction_end(conn, cursor) + with transaction() as cursor: + create_db(cursor) + update_db_data(cursor, "version", version) + update_db_data(cursor, "last_default_tab", "Maintenance") + migrate_path_recorder(cursor) + migrate_exif_data(cursor) + migrate_ranking(cursor) + migrate_filehash(cursor, str(version)) print("Image Browser: Database created") db_version = get_version() - conn, cursor = transaction_begin() - if db_version[0] <= "2": - # version 1 database had mixed path notations, changed them all to abspath - # version 2 database still had mixed path notations, because of windows short name, changed them all to realpath - print(f"Image Browser: Upgrading database from version {db_version[0]} to version {version}") - migrate_path_recorder_dirs(cursor) - migrate_exif_data_dirs(cursor) - migrate_ranking_dirs(cursor, db_version[0]) - if db_version[0] <= "4": - migrate_filehash(cursor, db_version[0]) - if db_version[0] <= "5": - migrate_work_files(cursor) - if db_version[0] <= "6": - update_db_data(cursor, "last_default_tab", "Others") + + with transaction() as cursor: + if db_version[0] <= "2": + # version 1 database had mixed path notations, changed them all to abspath + # version 2 database still had mixed path notations, because of windows short name, changed them all to realpath + print(f"Image Browser: Upgrading database from version {db_version[0]} to version {version}") + migrate_path_recorder_dirs(cursor) + migrate_exif_data_dirs(cursor) + migrate_ranking_dirs(cursor, db_version[0]) + if db_version[0] <= "4": + migrate_filehash(cursor, db_version[0]) + if db_version[0] <= "5": + migrate_work_files(cursor) + if db_version[0] <= "6": + update_db_data(cursor, "last_default_tab", "Others") - update_db_data(cursor, "version", version) - print(f"Image Browser: Database upgraded from version {db_version[0]} to version {version}") - transaction_end(conn, cursor) + update_db_data(cursor, "version", version) + print(f"Image Browser: Database upgraded from version {db_version[0]} to version {version}") return version def load_path_recorder(): - with sqlite3.connect(db_file, timeout=timeout) as conn: - cursor = conn.cursor() + with transaction() as cursor: cursor.execute(''' SELECT path, depth, path_display FROM path_recorder @@ -484,8 +495,7 @@ def load_path_recorder(): return path_recorder def select_ranking(file): - with sqlite3.connect(db_file, timeout=timeout) as conn: - cursor = conn.cursor() + with transaction() as cursor: cursor.execute(''' SELECT ranking FROM ranking @@ -502,8 +512,7 @@ def select_ranking(file): def update_ranking(file, ranking): name = os.path.basename(file) - with sqlite3.connect(db_file, timeout=timeout) as conn: - cursor = conn.cursor() + with transaction() as cursor: if ranking == "None": cursor.execute(''' DELETE FROM ranking @@ -526,8 +535,7 @@ def update_ranking(file, ranking): return def update_path_recorder(path, depth, path_display): - with sqlite3.connect(db_file, timeout=timeout) as conn: - cursor = conn.cursor() + with transaction() as cursor: cursor.execute(''' INSERT OR REPLACE INTO path_recorder (path, depth, path_display) @@ -537,8 +545,7 @@ def update_path_recorder(path, depth, path_display): return def update_path_recorder(path, depth, path_display): - with sqlite3.connect(db_file, timeout=timeout) as conn: - cursor = conn.cursor() + with transaction() as cursor: cursor.execute(''' INSERT OR REPLACE INTO path_recorder (path, depth, path_display) @@ -548,8 +555,7 @@ def update_path_recorder(path, depth, path_display): return def delete_path_recorder(path): - with sqlite3.connect(db_file, timeout=timeout) as conn: - cursor = conn.cursor() + with transaction() as cursor: cursor.execute(''' DELETE FROM path_recorder WHERE path = ? @@ -664,19 +670,6 @@ def replace_ranking(cursor, file, alternate_file, hash): return -def transaction_begin(): - conn = sqlite3.connect(db_file, timeout=timeout) - conn.isolation_level = None - cursor = conn.cursor() - cursor.execute("BEGIN") - return conn, cursor - -def transaction_end(conn, cursor): - cursor.execute("COMMIT") - conn.close() - backup_tmp_db() - return - def update_exif_data_by_key(cursor, file, key, value): cursor.execute(''' INSERT OR REPLACE @@ -687,8 +680,7 @@ def update_exif_data_by_key(cursor, file, key, value): return def select_prompts(file): - with sqlite3.connect(db_file, timeout=timeout) as conn: - cursor = conn.cursor() + with transaction() as cursor: cursor.execute(''' SELECT key, value FROM exif_data @@ -709,8 +701,7 @@ def select_prompts(file): return prompt, neg_prompt def load_exif_data(exif_cache): - with sqlite3.connect(db_file, timeout=timeout) as conn: - cursor = conn.cursor() + with transaction() as cursor: cursor.execute(''' SELECT file, group_concat( case when key = 'prompt' or key = 'negative_prompt' then key || ': ' || value || '\n' @@ -735,8 +726,7 @@ def load_exif_data(exif_cache): return exif_cache def load_exif_data_by_key(cache, key1, key2): - with sqlite3.connect(db_file, timeout=timeout) as conn: - cursor = conn.cursor() + with transaction() as cursor: cursor.execute(''' SELECT file, value FROM exif_data @@ -750,8 +740,7 @@ def load_exif_data_by_key(cache, key1, key2): return cache def get_exif_dirs(): - with sqlite3.connect(db_file, timeout=timeout) as conn: - cursor = conn.cursor() + with transaction() as cursor: cursor.execute(''' SELECT file FROM exif_data