sd-model-organizer/scripts/mo/data/firebase_storage.py

101 lines
3.5 KiB
Python

import os.path
from typing import List
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
from google.cloud.firestore_v1 import CollectionReference
from scripts.mo.data.storage import Storage, map_dict_to_record, map_record_to_dict
from scripts.mo.environment import env
from scripts.mo.models import Record
FIREBASE_APP_NAME = "sd-model-organizer-app"
def _filter_download(record: Record, show_downloaded, show_not_downloaded):
is_downloaded = bool(record.location) and os.path.exists(record.location)
return (show_downloaded and is_downloaded) or (show_not_downloaded and not is_downloaded)
class FirebaseStorage(Storage):
def __init__(self):
if not firebase_admin._apps:
cred = credentials.Certificate(os.path.join(env.script_dir, "service-account-file.json"))
self.app = firebase_admin.initialize_app(cred, name=FIREBASE_APP_NAME)
else:
self.app = firebase_admin.get_app(name=FIREBASE_APP_NAME)
self.firestore_client = firestore.client(app=self.app)
def _records(self) -> CollectionReference:
return self.firestore_client.collection('records')
def get_all_records(self) -> List:
record_refs = self._records().stream()
records = []
for ref in record_refs:
records.append(map_dict_to_record(ref.id, ref.to_dict()))
return records
def query_records(self, name_query=None, groups=None, model_types=None, show_downloaded=None,
show_not_downloaded=None) -> List:
query_ref = self._records()
if model_types is not None and model_types:
query_ref = query_ref.where('model_type', 'in', model_types)
records = []
for ref in query_ref.stream():
records.append(map_dict_to_record(ref.id, ref.to_dict()))
if name_query is not None and name_query:
records = [record for record in records if name_query.lower() in record.name.lower()]
if groups is not None and len(groups) > 0:
records = [item for item in records if all(val in item.groups for val in groups)]
records = list(filter(lambda r: _filter_download(r, show_downloaded, show_not_downloaded), records))
return records
def get_record_by_id(self, _id) -> Record:
doc = self._records().document(_id).get()
return map_dict_to_record(doc.id, doc.to_dict())
def add_record(self, record: Record):
self._records().add(map_record_to_dict(record))
def update_record(self, record: Record):
ref = self._records().document(record.id_)
ref.update(map_record_to_dict(record))
def remove_record(self, _id):
self._records().document(_id).delete()
def get_available_groups(self) -> List:
records = self.get_all_records()
groups = []
for record in records:
if len(record.groups) > 0:
groups.extend(record.groups)
return list(set(groups))
def get_records_by_group(self, group: str) -> List:
col_ref = self._records()
query_ref = col_ref.where('group', 'array_contains', f'%{group}%')
records = []
for ref in query_ref.stream():
records.append(map_dict_to_record(ref.id, ref.to_dict()))
return records
def get_all_records_locations(self) -> List:
records = self.get_all_records()
locations = []
for record in records:
if record.location:
locations.append(record.location)
return list(set(locations))