stable-diffusion-webui-text.../scripts/t2p/prompt_generator/database_loader.py

101 lines
3.3 KiB
Python

import os
import re
import csv
from typing import Dict
import numpy as np
import scripts.t2p.settings as settings
class Database:
    """A single text2prompt database: an ``.npz`` array file plus tag sidecar files.

    The database file name (without extension) is matched against
    ``re_filename``: group 1 becomes ``size_name`` and group 2 becomes
    ``model_name``. The sidecar files ``{size_name}_tags.txt`` (required) and
    ``{size_name}_tagidx.csv`` (optional) are read from the same directory at
    construction time; the ``'db'`` array inside the ``.npz`` is only read by
    :meth:`load`.
    """

    def __init__(self, database_path: str, re_filename: re.Pattern):
        self.read_files(database_path, re_filename)

    def read_files(self, database_path: str, re_filename: re.Pattern):
        """(Re)read the metadata and sidecar files for ``database_path``.

        Any validation failure prints a ``[text2prompt]`` message and resets
        the instance with :meth:`clear`, so :meth:`ready_to_load` is False
        afterwards instead of an exception reaching the caller.

        Args:
            database_path: Path to the ``.npz`` database file.
            re_filename: Pattern matched against the file's base name (without
                extension); it must define at least two groups.
        """
        self.clear()
        self.database_path = database_path
        fn, _ = os.path.splitext(os.path.basename(database_path))
        m = re_filename.match(fn)
        if m is None:
            # Without this guard m.group() raised AttributeError, aborting the
            # scan of the whole database directory because of one file name.
            print(f'[text2prompt] Cannot use database in {database_path}; Unrecognized file name "{fn}"')
            self.clear()
            return
        self.size_name = m.group(1)
        self.model_name = m.group(2)

        if self.model_name not in settings.TOKENIZER_NAMES:
            print(f'[text2prompt] Cannot use database in {database_path}; Incompatible model name "{self.model_name}"')
            self.clear()
            return

        db_dir = os.path.dirname(database_path)
        tag_path = os.path.join(db_dir, f'{self.size_name}_tags.txt')
        if not os.path.isfile(tag_path):
            print(f'[text2prompt] Cannot use database in {database_path}; No tag file exists')
            self.clear()
            return

        with open(tag_path, mode='r', encoding='utf8', newline='\n') as f:
            # One tag per line. Blank lines are kept (as '') so line positions
            # are preserved -- presumably tags are referenced by position
            # elsewhere; verify before filtering them.
            self.tags = [line.strip() for line in f]

        tag_idx_path = os.path.join(db_dir, f'{self.size_name}_tagidx.csv')
        if not os.path.isfile(tag_idx_path):
            print('[text2prompt] Cannot read tag indices file. Tag count filter cannot be used.')
        else:
            with open(tag_idx_path, mode='r', encoding='utf8', newline='') as f:
                for row in csv.reader(f):
                    # csv.reader yields [] for blank lines; row[0] would raise.
                    if not row:
                        continue
                    self.tag_idx.append((int(row[0]), int(row[1])))
            self.tag_idx.sort(key=lambda t: t[0])
        # Entry 0 is always (0, last tag index), so tag_idx is non-empty even
        # without the csv -- presumably the "no count filter" range; confirm.
        self.tag_idx = [(0, len(self.tags) - 1)] + self.tag_idx

    def clear(self):
        """Reset every attribute to its empty, unloaded state."""
        self.database_path = ''
        self.model_name = ''
        self.size_name = ''
        self.tag_idx = []  # (int, int) pairs from *_tagidx.csv, sorted by first item
        self.tags = []
        self.data: np.ndarray = None  # stays None until load() reads the 'db' array

    def ready_to_load(self) -> bool:
        """Return True when path, names, tags and tag indices are all present."""
        return bool(
            self.database_path
            and self.model_name
            and self.size_name
            and self.tags
            and self.tag_idx
        )

    def loaded(self) -> bool:
        """Return True once :meth:`load` has read the array data."""
        return self.data is not None

    def load(self):
        """Read the ``'db'`` array from the ``.npz`` file (once).

        Returns:
            ``self`` on success, or ``None`` if :meth:`ready_to_load` is False.
        """
        if not self.ready_to_load():
            return None
        if not self.loaded():
            # np.load returns an NpzFile holding the archive open; the context
            # manager closes it after the array has been read into memory.
            with np.load(self.database_path) as npz:
                self.data = npz['db']
        return self

    def name(self) -> str:
        """Return the display name ``'<model_name> : <size_name>'``."""
        return f'{self.model_name} : {self.size_name}'
class DatabaseLoader:
    """Discover ``.npz`` databases in a directory and load them by display name."""

    def __init__(self, path: str, re_filename: re.Pattern):
        # Maps Database.name() -> Database; filled by preload().
        self.datas: Dict[str, Database] = dict()
        self.preload(path, re_filename)

    def preload(self, path: str, re_filename: re.Pattern):
        """Scan ``path`` for ``*.npz`` files and register the usable ones.

        Only each database's metadata is read here (``Database`` construction);
        array data is read lazily by :meth:`load`. Databases that fail
        ``ready_to_load()`` are skipped -- previously they were all registered
        under the same empty name ``' : '`` and listed as available.

        Args:
            path: Directory to scan (non-recursive).
            re_filename: Pattern passed through to ``Database``.
        """
        if not os.path.isdir(path):
            print(f'[text2prompt] Database directory not found: {path}')
            return
        # Sorted so that, if two files map to the same name, the winner is
        # deterministic rather than dependent on filesystem listing order.
        for entry in sorted(os.listdir(path)):
            filepath = os.path.join(path, entry)
            if not os.path.isfile(filepath):
                continue
            _, ext = os.path.splitext(filepath)
            if ext != '.npz':
                continue
            ds = Database(filepath, re_filename)
            if not ds.ready_to_load():
                continue
            self.datas[ds.name()] = ds
        print('[text2prompt] Following databases are available:')
        for name in sorted(self.datas):
            print(f' {name}')

    def load(self, database_name: str):
        """Return the loaded database registered as ``database_name``.

        Returns:
            The result of ``Database.load()``, or ``None`` if no database is
            registered under that name.
        """
        database = self.datas.get(database_name)
        if database is None:
            return None
        return database.load()