357 lines
12 KiB
Python
357 lines
12 KiB
Python
import os
|
|
import sqlite3
|
|
import threading
|
|
from typing import List
|
|
from modules import shared
|
|
|
|
from scripts.mo.data.storage import Storage
|
|
from scripts.mo.environment import env, logger
|
|
from scripts.mo.models import Record, ModelType
|
|
|
|
# File name of the SQLite database; joined onto the configured database
# directory when a connection is opened.
_DB_FILE = 'database.sqlite'
# Schema version this module expects. It is stored in the Version table and
# compared against the stored value to decide whether migrations must run.
_DB_VERSION = 6
# Passed as the second positional argument to sqlite3.connect(), i.e. the
# connection timeout (seconds, per the sqlite3 documentation).
_DB_TIMEOUT = 30
|
|
|
|
|
|
def map_row_to_record(row) -> Record:
    """Convert one row of the ``Record`` table into a ``Record`` object.

    Columns are read by position, in the order the table declares them
    (id, _name, model_type, ..., location, weight).

    :param row: an indexable row holding at least 18 columns.
    :return: the ``Record`` built from the row's values.
    """
    # The dict literal is evaluated top to bottom, so columns are read in
    # the same order as the keyword arguments are listed.
    fields = {
        'id_': row[0],
        'name': row[1],
        'model_type': ModelType.by_value(row[2]),
        'download_url': row[3],
        'url': row[4],
        'download_path': row[5],
        'download_filename': row[6],
        'preview_url': row[7],
        'description': row[8],
        'positive_prompts': row[9],
        'negative_prompts': row[10],
        'sha256_hash': row[11],
        'md5_hash': row[12],
        'created_at': row[13],
        # Groups are stored as one comma-separated string; an empty or
        # NULL value means "no groups".
        'groups': [] if not row[14] else row[14].split(','),
        'subdir': row[15],
        'location': row[16],
        'weight': row[17],
    }
    return Record(**fields)
|
|
|
|
|
|
class SQLiteStorage(Storage):
    """``Storage`` implementation backed by a local SQLite database file.

    Each thread gets its own ``sqlite3`` connection, kept in a
    ``threading.local`` and opened lazily on first use. The schema is created
    on construction and upgraded step by step to ``_DB_VERSION`` when an
    older version number is found in the ``Version`` table.
    """

    def __init__(self):
        self.local = threading.local()
        self._initialize()

    def _connection(self):
        """Return the calling thread's connection, opening it on first use."""
        if not hasattr(self.local, "connection"):
            # Fall back to the script directory when the command-line option
            # is unset (None) or not registered at all.
            mo_database_dir = getattr(shared.cmd_opts, "mo_database_dir", None)
            database_dir = mo_database_dir if mo_database_dir is not None else env.script_dir
            db_file_path = os.path.join(database_dir, _DB_FILE)
            self.local.connection = sqlite3.connect(db_file_path, _DB_TIMEOUT)
        return self.local.connection

    def _initialize(self):
        """Create the ``Record`` and ``Version`` tables if missing, then migrate."""
        cursor = self._connection().cursor()

        cursor.execute('''CREATE TABLE IF NOT EXISTS Record
                                    (id INTEGER PRIMARY KEY,
                                    _name TEXT,
                                    model_type TEXT,
                                    download_url TEXT,
                                    url TEXT DEFAULT '',
                                    download_path TEXT DEFAULT '',
                                    download_filename TEXT DEFAULT '',
                                    preview_url TEXT DEFAULT '',
                                    description TEXT DEFAULT '',
                                    positive_prompts TEXT DEFAULT '',
                                    negative_prompts TEXT DEFAULT '',
                                    sha256_hash TEXT DEFAULT '',
                                    md5_hash TEXT DEFAULT '',
                                    created_at INTEGER DEFAULT 0,
                                    groups TEXT DEFAULT '',
                                    subdir TEXT DEFAULT '',
                                    location TEXT DEFAULT '',
                                    weight REAL DEFAULT 1)
                                ''')

        cursor.execute(f'''CREATE TABLE IF NOT EXISTS Version
                                (version INTEGER DEFAULT {_DB_VERSION})''')
        self._connection().commit()
        self._check_database_version()

    def _check_database_version(self):
        """Record the schema version on a fresh database, or migrate an old one."""
        cursor = self._connection().cursor()
        cursor.execute('SELECT version FROM Version')
        row = cursor.fetchone()

        if row is None:
            # Fresh database: the tables were just created with the current
            # schema, so only the version number needs to be stored.
            cursor.execute('INSERT INTO Version VALUES (?)', (_DB_VERSION,))
            self._connection().commit()
            return

        version = row[0]
        if version != _DB_VERSION:
            self._run_migration(version)

    def _run_migration(self, current_version):
        """Apply every migration step from ``current_version`` up to ``_DB_VERSION``.

        :raises Exception: when no migration is defined for an intermediate version.
        """
        steps = {
            1: self._migrate_1_to_2,
            2: self._migrate_2_to_3,
            3: self._migrate_3_to_4,
            4: self._migrate_4_to_5,
            5: self._migrate_5_to_6,
        }
        for ver in range(current_version, _DB_VERSION):
            step = steps.get(ver)
            if step is None:
                raise Exception(f'Missing SQLite migration from {ver} to {_DB_VERSION}')
            step()

    def _apply_migration(self, statements, new_version):
        """Run the schema ``statements``, store ``new_version`` and commit."""
        cursor = self._connection().cursor()
        for statement in statements:
            cursor.execute(statement)
        # The Version table holds a single row: replace it with the new number.
        cursor.execute("DELETE FROM Version")
        cursor.execute('INSERT INTO Version VALUES (?)', (new_version,))
        self._connection().commit()

    def _migrate_1_to_2(self):
        """Version 2 adds the ``created_at`` column."""
        self._apply_migration(
            ['ALTER TABLE Record ADD COLUMN created_at INTEGER DEFAULT 0;'], 2)

    def _migrate_2_to_3(self):
        """Version 3 adds the ``groups`` column."""
        self._apply_migration(
            ["ALTER TABLE Record ADD COLUMN groups TEXT DEFAULT '';"], 3)

    def _migrate_3_to_4(self):
        """Version 4 renames ``model_hash`` to ``sha256_hash`` and adds ``subdir``."""
        self._apply_migration(
            [
                "ALTER TABLE Record RENAME COLUMN model_hash TO sha256_hash;",
                "ALTER TABLE Record ADD COLUMN subdir TEXT DEFAULT '';",
            ],
            4)

    def _migrate_4_to_5(self):
        """Version 5 adds the ``location`` column."""
        self._apply_migration(
            ["ALTER TABLE Record ADD COLUMN location TEXT DEFAULT '';"], 5)

    def _migrate_5_to_6(self):
        """Version 6 adds the ``weight`` column."""
        self._apply_migration(
            ["ALTER TABLE Record ADD COLUMN weight REAL DEFAULT 1;"], 6)

    # Backward-compatible alias for the original (misspelled) method name.
    _migrage_5_to_6 = _migrate_5_to_6

    def get_all_records(self) -> List:
        """Return every stored record."""
        cursor = self._connection().cursor()
        cursor.execute('SELECT * FROM Record')
        return [map_row_to_record(row) for row in cursor.fetchall()]

    def query_records(self, name_query: str = None, groups=None, model_types=None, show_downloaded=True,
                      show_not_downloaded=True) -> List:
        """Return the records matching all of the given filters.

        :param name_query: substring to look for in the record name; ignored
            when ``None`` or empty.
        :param groups: group names; a record must match every one of them.
        :param model_types: model type values; a record must match any of them.
        :param show_downloaded: include records whose ``location`` is set and
            exists on disk.
        :param show_not_downloaded: include records whose ``location`` is
            empty or does not exist on disk.
        :return: list of matching ``Record`` objects.
        """
        conditions = []
        params = []

        if name_query:
            conditions.append('LOWER(_name) LIKE ?')
            params.append(f'%{name_query}%')

        if model_types is not None and len(model_types) > 0:
            any_type = ' OR '.join('model_type=?' for _ in model_types)
            conditions.append(f'({any_type})')
            # Format exactly as the former inline f-string did, so the same
            # text is compared against the column.
            params.extend(f'{model_type}' for model_type in model_types)

        if groups is not None and len(groups) > 0:
            for group in groups:
                conditions.append('LOWER(groups) LIKE ?')
                params.append(f'%{group}%')

        query = 'SELECT * FROM Record'
        if conditions:
            query += ' WHERE ' + ' AND '.join(conditions)

        # Values are bound as parameters rather than interpolated, so quotes
        # in user-supplied text can neither break nor alter the statement.
        logger.debug(f'query: {query} params: {params}')
        cursor = self._connection().cursor()
        cursor.execute(query, params)

        result = []
        for row in cursor.fetchall():
            record = map_row_to_record(row)
            is_downloaded = bool(record.location) and os.path.exists(record.location)

            if show_downloaded and is_downloaded:
                result.append(record)
            elif show_not_downloaded and not is_downloaded:
                result.append(record)

        return result

    def get_record_by_id(self, id_) -> Record:
        """Return the record with the given id, or ``None`` when there is none."""
        cursor = self._connection().cursor()
        cursor.execute('SELECT * FROM Record WHERE id=?', (id_,))
        row = cursor.fetchone()
        return None if row is None else map_row_to_record(row)

    def get_records_by_group(self, group: str) -> List:
        """Return the records whose ``groups`` column contains ``group``."""
        cursor = self._connection().cursor()
        # Bound parameter instead of string interpolation (see query_records).
        cursor.execute("SELECT * FROM Record WHERE LOWER(groups) LIKE ?", (f'%{group}%',))
        return [map_row_to_record(row) for row in cursor.fetchall()]

    def get_records_by_query(self, query: str) -> List:
        """Execute ``query`` as-is and map every returned row to a record.

        NOTE(review): the SQL text is executed verbatim, so it must never be
        assembled from untrusted input; rows must have the ``Record`` table's
        column layout for the mapping to work.
        """
        cursor = self._connection().cursor()
        cursor.execute(query)
        return [map_row_to_record(row) for row in cursor.fetchall()]

    def add_record(self, record: Record):
        """Insert ``record`` as a new row; the id is assigned by SQLite."""
        cursor = self._connection().cursor()
        data = (
            record.name,
            record.model_type.value,
            record.download_url,
            record.url,
            record.download_path,
            record.download_filename,
            record.preview_url,
            record.description,
            record.positive_prompts,
            record.negative_prompts,
            record.sha256_hash,
            record.md5_hash,
            record.created_at,
            ",".join(record.groups),
            record.subdir,
            record.location,
            record.weight
        )
        cursor.execute(
            """INSERT INTO Record(
                    _name,
                    model_type,
                    download_url,
                    url,
                    download_path,
                    download_filename,
                    preview_url,
                    description,
                    positive_prompts,
                    negative_prompts,
                    sha256_hash,
                    md5_hash,
                    created_at,
                    groups,
                    subdir,
                    location,
                    weight) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            data)
        self._connection().commit()

    def update_record(self, record: Record):
        """Overwrite the row identified by ``record.id_``.

        Every column except ``id`` and ``created_at`` is rewritten.
        """
        cursor = self._connection().cursor()
        data = (
            record.name,
            record.model_type.value,
            record.download_url,
            record.url,
            record.download_path,
            record.download_filename,
            record.preview_url,
            record.description,
            record.positive_prompts,
            record.negative_prompts,
            record.sha256_hash,
            record.md5_hash,
            ",".join(record.groups),
            record.subdir,
            record.location,
            record.weight,
            record.id_
        )
        cursor.execute(
            """UPDATE Record SET
                    _name=?,
                    model_type=?,
                    download_url=?,
                    url=?,
                    download_path=?,
                    download_filename=?,
                    preview_url=?,
                    description=?,
                    positive_prompts=?,
                    negative_prompts=?,
                    sha256_hash=?,
                    md5_hash=?,
                    groups=?,
                    subdir=?,
                    location=?,
                    weight=?
                WHERE id=?
            """, data
        )

        self._connection().commit()

    def remove_record(self, _id):
        """Delete the row with the given id, if present."""
        cursor = self._connection().cursor()
        cursor.execute("DELETE FROM Record WHERE id=?", (_id,))
        self._connection().commit()

    def get_available_groups(self) -> List:
        """Return the distinct, non-empty group names used by any record.

        The order of the returned list is unspecified.
        """
        cursor = self._connection().cursor()
        cursor.execute('SELECT groups FROM Record')

        names = set()
        for row in cursor.fetchall():
            if row[0]:
                names.update(row[0].split(","))

        # Stray commas in the stored string produce empty names; drop them.
        names.discard('')
        return list(names)

    def get_all_records_locations(self) -> List:
        """Return the ``location`` of every record that has a non-empty one."""
        cursor = self._connection().cursor()
        cursor.execute('SELECT location FROM Record')
        return [row[0] for row in cursor.fetchall() if row[0]]
|