import os
import sqlite3
import threading
from typing import List, Optional

from modules import shared
from scripts.mo.data.storage import Storage
from scripts.mo.environment import env, logger
from scripts.mo.models import Record, ModelType

_DB_FILE = 'database.sqlite'
_DB_VERSION = 6
_DB_TIMEOUT = 30

# Explicit column list, in exactly the order map_row_to_record() indexes.
# Selecting by name (instead of ``*``) keeps the positional mapping correct
# regardless of the physical column order of the table (columns added by the
# migrations below are appended with ALTER TABLE ... ADD COLUMN).
_RECORD_COLUMNS = (
    'id, _name, model_type, download_url, url, download_path, '
    'download_filename, preview_url, description, positive_prompts, '
    'negative_prompts, sha256_hash, md5_hash, created_at, groups, subdir, '
    'location, weight'
)
_SELECT_RECORDS = f'SELECT {_RECORD_COLUMNS} FROM Record'

# Appended after a ``LIKE ?`` placeholder whose pattern was built by
# _like_contains_pattern(); the SQL text produced is: ESCAPE '\'
_LIKE_ESCAPE = " ESCAPE '\\'"


def _like_contains_pattern(value) -> str:
    r"""Build a LIKE pattern that matches ``str(value)`` as a literal substring.

    ``\``, ``%`` and ``_`` are escaped so that they match themselves instead of
    acting as wildcards; the pattern must be used together with ``ESCAPE '\'``.
    """
    text = str(value)
    text = text.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')
    return f'%{text}%'


def _model_type_value(model_type) -> str:
    """Return the value stored in the ``model_type`` column for *model_type*.

    Accepts a ModelType member (its ``.value`` is used, as in add_record) or an
    already-rendered value, which is converted with ``str()``.
    """
    if isinstance(model_type, ModelType):
        return model_type.value
    return str(model_type)


def map_row_to_record(row) -> Record:
    """Convert one Record row (columns in _RECORD_COLUMNS order) to a Record.

    ``groups`` is stored as a comma-joined string and is split back into a
    list; an empty/NULL value becomes ``[]``.
    """
    return Record(
        id_=row[0],
        name=row[1],
        model_type=ModelType.by_value(row[2]),
        download_url=row[3],
        url=row[4],
        download_path=row[5],
        download_filename=row[6],
        preview_url=row[7],
        description=row[8],
        positive_prompts=row[9],
        negative_prompts=row[10],
        sha256_hash=row[11],
        md5_hash=row[12],
        created_at=row[13],
        groups=row[14].split(',') if row[14] else [],
        subdir=row[15],
        location=row[16],
        weight=row[17],
    )


class SQLiteStorage(Storage):
    """Storage implementation that persists Record rows in a SQLite file.

    Each thread lazily opens its own connection, kept on a ``threading.local``.
    The schema is created (and migrated if needed) on construction.
    """

    def __init__(self):
        self.local = threading.local()
        self._initialize()

    def _connection(self) -> sqlite3.Connection:
        """Return the calling thread's connection, opening it on first use.

        The database file is placed in ``shared.cmd_opts.mo_database_dir`` when
        that option is set, otherwise in ``env.script_dir``.
        """
        connection = getattr(self.local, 'connection', None)
        if connection is None:
            # Use a default so a cmd_opts object without the option does not
            # raise AttributeError; fall back to the script directory instead.
            mo_database_dir = getattr(shared.cmd_opts, 'mo_database_dir', None)
            database_dir = mo_database_dir if mo_database_dir is not None else env.script_dir
            db_file_path = os.path.join(database_dir, _DB_FILE)
            connection = sqlite3.connect(db_file_path, timeout=_DB_TIMEOUT)
            self.local.connection = connection
        return connection

    def _initialize(self):
        """Create the Record and Version tables if missing, then check/migrate."""
        cursor = self._connection().cursor()
        cursor.execute('''CREATE TABLE IF NOT EXISTS Record
                            (id INTEGER PRIMARY KEY,
                            _name TEXT,
                            model_type TEXT,
                            download_url TEXT,
                            url TEXT DEFAULT '',
                            download_path TEXT DEFAULT '',
                            download_filename TEXT DEFAULT '',
                            preview_url TEXT DEFAULT '',
                            description TEXT DEFAULT '',
                            positive_prompts TEXT DEFAULT '',
                            negative_prompts TEXT DEFAULT '',
                            sha256_hash TEXT DEFAULT '',
                            md5_hash TEXT DEFAULT '',
                            created_at INTEGER DEFAULT 0,
                            groups TEXT DEFAULT '',
                            subdir TEXT DEFAULT '',
                            location TEXT DEFAULT '',
                            weight REAL DEFAULT 1)
                         ''')

        cursor.execute(f'''CREATE TABLE IF NOT EXISTS Version
                            (version INTEGER DEFAULT {_DB_VERSION})''')
        self._connection().commit()
        self._check_database_version()

    def _check_database_version(self):
        """Read the stored schema version and run migrations if it is outdated.

        When the Version table has no row yet, the current _DB_VERSION is
        written and no migration is run.
        """
        cursor = self._connection().cursor()
        cursor.execute('SELECT * FROM Version')
        row = cursor.fetchone()
        if row is None:
            cursor.execute('INSERT INTO Version VALUES (?)', (_DB_VERSION,))
            self._connection().commit()
            version = _DB_VERSION
        else:
            version = row[0]

        if version != _DB_VERSION:
            self._run_migration(version)

    def _run_migration(self, current_version):
        """Apply migrations one step at a time from *current_version* to _DB_VERSION.

        Raises:
            Exception: if no migration step is defined for a version in range.
        """
        migrations = {
            1: self._migrate_1_to_2,
            2: self._migrate_2_to_3,
            3: self._migrate_3_to_4,
            4: self._migrate_4_to_5,
            5: self._migrate_5_to_6,
        }
        for ver in range(current_version, _DB_VERSION):
            migration = migrations.get(ver)
            if migration is None:
                raise Exception(f'Missing SQLite migration from {ver} to {_DB_VERSION}')
            migration()

    @staticmethod
    def _set_version(cursor, version):
        """Replace the single row of the Version table with *version* (no commit)."""
        cursor.execute('DELETE FROM Version')
        cursor.execute('INSERT INTO Version VALUES (?)', (version,))

    def _migrate_1_to_2(self):
        """v2: add the ``created_at`` column."""
        cursor = self._connection().cursor()
        cursor.execute('ALTER TABLE Record ADD COLUMN created_at INTEGER DEFAULT 0;')
        self._set_version(cursor, 2)
        self._connection().commit()

    def _migrate_2_to_3(self):
        """v3: add the ``groups`` column."""
        cursor = self._connection().cursor()
        cursor.execute("ALTER TABLE Record ADD COLUMN groups TEXT DEFAULT '';")
        self._set_version(cursor, 3)
        self._connection().commit()

    def _migrate_3_to_4(self):
        """v4: rename ``model_hash`` to ``sha256_hash`` and add ``subdir``."""
        cursor = self._connection().cursor()
        cursor.execute('ALTER TABLE Record RENAME COLUMN model_hash TO sha256_hash;')
        cursor.execute("ALTER TABLE Record ADD COLUMN subdir TEXT DEFAULT '';")
        self._set_version(cursor, 4)
        self._connection().commit()

    def _migrate_4_to_5(self):
        """v5: add the ``location`` column."""
        cursor = self._connection().cursor()
        cursor.execute("ALTER TABLE Record ADD COLUMN location TEXT DEFAULT '';")
        self._set_version(cursor, 5)
        self._connection().commit()

    def _migrate_5_to_6(self):
        """v6: add the ``weight`` column."""
        cursor = self._connection().cursor()
        cursor.execute('ALTER TABLE Record ADD COLUMN weight REAL DEFAULT 1;')
        self._set_version(cursor, 6)
        self._connection().commit()

    def get_all_records(self) -> List:
        """Return every stored record."""
        cursor = self._connection().cursor()
        cursor.execute(_SELECT_RECORDS)
        return [map_row_to_record(row) for row in cursor.fetchall()]

    def query_records(self, name_query: str = None, groups=None, model_types=None, show_downloaded=True,
                      show_not_downloaded=True) -> List:
        """Return records matching all of the given filters.

        Args:
            name_query: substring matched (via SQLite LIKE) against the name;
                ignored when None or empty.
            groups: every entry must appear as a substring of the record's
                comma-joined groups string. Note this is a substring match, so
                e.g. "art" also matches a group named "cartoon".
            model_types: record's model_type must equal one of these
                (ModelType members or stored values).
            show_downloaded: include records whose ``location`` exists on disk.
            show_not_downloaded: include records whose ``location`` is empty
                or does not exist.
        """
        conditions = []
        params = []

        if name_query:
            conditions.append(f'LOWER(_name) LIKE ?{_LIKE_ESCAPE}')
            params.append(_like_contains_pattern(name_query))

        if model_types:
            type_values = [_model_type_value(model_type) for model_type in model_types]
            alternatives = ' OR '.join('model_type=?' for _ in type_values)
            conditions.append(f'({alternatives})')
            params.extend(type_values)

        if groups:
            for group in groups:
                conditions.append(f'LOWER(groups) LIKE ?{_LIKE_ESCAPE}')
                params.append(_like_contains_pattern(group))

        query = _SELECT_RECORDS
        if conditions:
            query += ' WHERE ' + ' AND '.join(conditions)

        # The query text now holds placeholders, so log the bound values too.
        logger.debug(f'query: {query}, params: {params}')
        cursor = self._connection().cursor()
        cursor.execute(query, params)

        result = []
        for row in cursor.fetchall():
            record = map_row_to_record(row)
            is_downloaded = bool(record.location) and os.path.exists(record.location)
            if show_downloaded and is_downloaded:
                result.append(record)
            elif show_not_downloaded and not is_downloaded:
                result.append(record)
        return result

    def get_record_by_id(self, id_) -> Optional[Record]:
        """Return the record with primary key *id_*, or None if absent."""
        cursor = self._connection().cursor()
        cursor.execute(f'{_SELECT_RECORDS} WHERE id=?', (id_,))
        row = cursor.fetchone()
        return None if row is None else map_row_to_record(row)

    def get_records_by_group(self, group: str) -> List:
        """Return records whose comma-joined groups string contains *group*.

        This is a substring match (case-insensitive for ASCII, as SQLite LIKE).
        """
        cursor = self._connection().cursor()
        cursor.execute(f'{_SELECT_RECORDS} WHERE LOWER(groups) LIKE ?{_LIKE_ESCAPE}',
                       (_like_contains_pattern(group),))
        return [map_row_to_record(row) for row in cursor.fetchall()]

    def get_records_by_query(self, query: str) -> List:
        """Execute *query* verbatim and map each result row to a Record.

        The SQL is run as-is, so it must come from trusted code, and it must
        select Record columns in the order map_row_to_record() expects.
        """
        cursor = self._connection().cursor()
        cursor.execute(query)
        return [map_row_to_record(row) for row in cursor.fetchall()]

    def add_record(self, record: Record):
        """Insert *record* as a new row (``id`` is assigned by SQLite) and commit."""
        cursor = self._connection().cursor()
        data = (
            record.name,
            record.model_type.value,
            record.download_url,
            record.url,
            record.download_path,
            record.download_filename,
            record.preview_url,
            record.description,
            record.positive_prompts,
            record.negative_prompts,
            record.sha256_hash,
            record.md5_hash,
            record.created_at,
            ",".join(record.groups),
            record.subdir,
            record.location,
            record.weight,
        )
        cursor.execute(
            """INSERT INTO Record(
                    _name,
                    model_type,
                    download_url,
                    url,
                    download_path,
                    download_filename,
                    preview_url,
                    description,
                    positive_prompts,
                    negative_prompts,
                    sha256_hash,
                    md5_hash,
                    created_at,
                    groups,
                    subdir,
                    location,
                    weight)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            data)
        self._connection().commit()

    def update_record(self, record: Record):
        """Overwrite the row with ``record.id_`` and commit.

        ``created_at`` is deliberately left untouched.
        """
        cursor = self._connection().cursor()
        data = (
            record.name,
            record.model_type.value,
            record.download_url,
            record.url,
            record.download_path,
            record.download_filename,
            record.preview_url,
            record.description,
            record.positive_prompts,
            record.negative_prompts,
            record.sha256_hash,
            record.md5_hash,
            ",".join(record.groups),
            record.subdir,
            record.location,
            record.weight,
            record.id_,
        )
        cursor.execute(
            """UPDATE Record SET
                    _name=?,
                    model_type=?,
                    download_url=?,
                    url=?,
                    download_path=?,
                    download_filename=?,
                    preview_url=?,
                    description=?,
                    positive_prompts=?,
                    negative_prompts=?,
                    sha256_hash=?,
                    md5_hash=?,
                    groups=?,
                    subdir=?,
                    location=?,
                    weight=?
                WHERE id=?
            """, data
        )
        self._connection().commit()

    def remove_record(self, _id):
        """Delete the row with primary key *_id* and commit."""
        cursor = self._connection().cursor()
        cursor.execute("DELETE FROM Record WHERE id=?", (_id,))
        self._connection().commit()

    def get_available_groups(self) -> List:
        """Return the distinct non-empty group names across all records (unordered)."""
        cursor = self._connection().cursor()
        cursor.execute('SELECT groups FROM Record')
        unique_groups = set()
        for (joined_groups,) in cursor.fetchall():
            if joined_groups:
                unique_groups.update(joined_groups.split(","))
        return [group for group in unique_groups if group]

    def get_all_records_locations(self) -> List:
        """Return every non-empty ``location`` value stored in Record."""
        cursor = self._connection().cursor()
        cursor.execute('SELECT location FROM Record')
        return [location for (location,) in cursor.fetchall() if location]