from enum import Enum
from datetime import datetime
from typing import Optional

from sqlalchemy import Column, String, Text, Integer, DateTime, LargeBinary, text, func
from sqlalchemy.orm import Session

from .base import BaseTableManager, Base
from ..models import TaskModel


def _timestamp_priority() -> int:
    """Return a millisecond-timestamp priority for a newly queued / re-queued task.

    Tasks are ordered by ascending priority, so later timestamps run later.

    NOTE(review): ``datetime.utcnow()`` is naive and ``.timestamp()`` interprets
    naive datetimes as local time, so the value is offset from the real epoch by
    the host's UTC offset. The formula is kept unchanged on purpose so new
    priorities stay comparable with values already stored in the database.
    """
    return int(datetime.utcnow().timestamp() * 1000)


class TaskStatus(str, Enum):
    """Lifecycle states stored in ``TaskTable.status``."""

    PENDING = "pending"
    RUNNING = "running"
    DONE = "done"
    FAILED = "failed"


class Task(TaskModel):
    """In-memory task built on ``TaskModel``; convertible to/from ``TaskTable``.

    ``script_params`` holds raw bytes and is excluded via ``Config.exclude``.
    """

    script_params: bytes = None

    def __init__(
        self,
        id: str = "",
        type: str = "unknown",
        params: str = "",
        script_params: bytes = b"",
        priority: Optional[int] = None,
        status: str = TaskStatus.PENDING.value,
        result: Optional[str] = None,
        created_at: Optional[datetime] = None,
        updated_at: Optional[datetime] = None,
    ):
        # Only synthesize a timestamp when no priority was supplied: 0 is a
        # legitimate stored priority ("move to top" can produce it) and must
        # survive a round trip through from_table().
        if priority is None:
            priority = _timestamp_priority()
        super().__init__(
            id=id,
            type=type,
            params=params,
            status=status,
            priority=priority,
            result=result,
            created_at=created_at,
            updated_at=updated_at,
        )
        self.id: str = id
        self.type: str = type
        self.params: str = params
        self.script_params: bytes = script_params
        self.priority: int = priority
        self.status: str = status
        self.result: str = result
        self.created_at: datetime = created_at
        self.updated_at: datetime = updated_at

    class Config(TaskModel.__config__):
        exclude = ["script_params"]

    @staticmethod
    def from_table(table: "TaskTable"):
        """Build a ``Task`` from a loaded ``TaskTable`` row (including its result)."""
        return Task(
            id=table.id,
            type=table.type,
            params=table.params,
            script_params=table.script_params,
            priority=table.priority,
            status=table.status,
            result=table.result,
            created_at=table.created_at,
            updated_at=table.updated_at,
        )

    def to_table(self):
        """Return a new, unsaved ``TaskTable`` row for this task.

        Timestamps are left out so the database server defaults fill them in.
        """
        return TaskTable(
            id=self.id,
            type=self.type,
            params=self.params,
            script_params=self.script_params,
            priority=self.priority,
            status=self.status,
            result=self.result,
        )


class TaskTable(Base):
    """ORM mapping for the ``task`` table."""

    __tablename__ = "task"

    id = Column(String(64), primary_key=True)
    type = Column(String(20), nullable=False)  # txt2img or img2txt
    params = Column(Text, nullable=False)  # task args
    script_params = Column(LargeBinary, nullable=False)  # script args
    # Integer column, so the Python-side default must produce an int
    # (previously ``datetime.now`` produced a datetime object).
    priority = Column(Integer, nullable=False, default=_timestamp_priority)
    status = Column(
        String(20), nullable=False, default="pending"
    )  # pending, running, done, failed
    result = Column(Text)  # task result
    # ``datetime('now')`` is SQLite syntax.
    created_at = Column(
        DateTime,
        nullable=False,
        server_default=text("(datetime('now'))"),
    )
    updated_at = Column(
        DateTime,
        nullable=False,
        server_default=text("(datetime('now'))"),
        onupdate=text("(datetime('now'))"),
    )

    def __repr__(self):
        return f"Task(id={self.id!r}, type={self.type!r}, params={self.params!r}, status={self.status!r}, created_at={self.created_at!r})"


class TaskManager(BaseTableManager):
    """CRUD and queue-ordering operations on ``TaskTable``.

    Each public method opens its own short-lived ``Session`` on ``self.engine``,
    prints and re-raises any exception, and closes the session on exit.
    """

    @staticmethod
    def _filtered_query(session: Session, type: str = None, status: str = None):
        """Return a ``TaskTable`` query narrowed by ``type``/``status`` when given (truthy)."""
        query = session.query(TaskTable)
        if type:
            query = query.filter(TaskTable.type == type)
        if status:
            query = query.filter(TaskTable.status == status)
        return query

    @staticmethod
    def _min_priority(session: Session) -> int:
        """Return the smallest stored priority, or 0 when the table is empty."""
        min_priority = session.query(func.min(TaskTable.priority)).scalar()
        return min_priority if min_priority else 0

    @staticmethod
    def _shift_priorities_down(session: Session, priority: int) -> None:
        """Increment every priority >= ``priority`` by one (no commit)."""
        session.query(TaskTable).filter(TaskTable.priority >= priority).update(
            {TaskTable.priority: TaskTable.priority + 1}
        )

    def get_task(self, id: str) -> Task | None:
        """Return the task with primary key ``id`` as a ``Task``, or None if absent."""
        with Session(self.engine) as session:
            try:
                row = session.get(TaskTable, id)
                return Task.from_table(row) if row else None
            except Exception as e:
                print(f"Exception getting task from database: {e}")
                raise

    def get_tasks(
        self,
        type: str = None,
        status: str = None,
        limit: int = None,
        offset: int = None,
    ) -> list[Task]:
        """Return tasks ordered by ascending priority, then creation time.

        ``type``/``status`` filter when truthy; ``limit``/``offset`` page the
        result when truthy.
        """
        with Session(self.engine) as session:
            try:
                query = self._filtered_query(session, type, status).order_by(
                    TaskTable.priority.asc(), TaskTable.created_at.asc()
                )
                if limit:
                    query = query.limit(limit)
                if offset:
                    query = query.offset(offset)
                return [Task.from_table(row) for row in query.all()]
            except Exception as e:
                print(f"Exception getting tasks from database: {e}")
                raise

    def count_tasks(
        self,
        type: str = None,
        status: str = None,
    ) -> int:
        """Return how many tasks match the optional ``type``/``status`` filters."""
        with Session(self.engine) as session:
            try:
                return self._filtered_query(session, type, status).count()
            except Exception as e:
                print(f"Exception counting tasks from database: {e}")
                raise

    def add_task(self, task: Task) -> TaskTable:
        """Insert ``task`` and return the persisted (detached) row."""
        with Session(self.engine) as session:
            try:
                row = task.to_table()
                session.add(row)
                session.commit()
                # commit() expires attributes and the timestamps come from server
                # defaults; reload them so the returned row stays readable after
                # the session closes (otherwise DetachedInstanceError).
                session.refresh(row)
                return row
            except Exception as e:
                print(f"Exception adding task to database: {e}")
                raise

    def update_task(self, id: str, status: str, result=None) -> TaskTable:
        """Set ``status`` and ``result`` of task ``id`` and return the updated row.

        Raises:
            Exception: if no task with ``id`` exists.
        """
        with Session(self.engine) as session:
            try:
                row = session.get(TaskTable, id)
                if row is None:
                    raise Exception(f"Task with id {id} not found")
                row.status = status
                row.result = result
                session.commit()
                # Reload expired attributes (incl. onupdate updated_at) before
                # the session closes so the returned row is usable.
                session.refresh(row)
                return row
            except Exception as e:
                print(f"Exception updating task in database: {e}")
                raise

    def prioritize_task(self, id: str, priority: int) -> TaskTable:
        """0 means move to top, -1 means move to bottom, otherwise set the exact priority

        For an exact priority, every task at or above that value is shifted
        down by one first; the shift and the assignment commit together.

        Raises:
            Exception: if no task with ``id`` exists.
        """
        with Session(self.engine) as session:
            try:
                row = session.get(TaskTable, id)
                if row is None:
                    raise Exception(f"Task with id {id} not found")
                if priority == 0:
                    row.priority = self._min_priority(session) - 1
                elif priority == -1:
                    row.priority = _timestamp_priority()
                else:
                    # Same session/transaction as the assignment, so the reorder
                    # is atomic instead of committing the shift separately.
                    self._shift_priorities_down(session, priority)
                    row.priority = priority
                session.commit()
                session.refresh(row)
                return row
            except Exception as e:
                print(f"Exception updating task in database: {e}")
                raise

    def delete_task(self, id: str):
        """Delete task ``id``.

        Raises:
            Exception: if no task with ``id`` exists.
        """
        with Session(self.engine) as session:
            try:
                row = session.get(TaskTable, id)
                if row is None:
                    raise Exception(f"Task with id {id} not found")
                session.delete(row)
                session.commit()
            except Exception as e:
                print(f"Exception deleting task from database: {e}")
                raise

    def delete_tasks_before(self, before: datetime, all: bool = False):
        """Bulk-delete tasks created before ``before``.

        Unless ``all`` is true, only tasks in the done/failed states are removed.
        """
        with Session(self.engine) as session:
            try:
                query = session.query(TaskTable).filter(TaskTable.created_at < before)
                if not all:
                    query = query.filter(
                        TaskTable.status.in_(
                            [TaskStatus.DONE.value, TaskStatus.FAILED.value]
                        )
                    )
                # Fresh session: no loaded objects to keep in sync.
                query.delete(synchronize_session=False)
                session.commit()
            except Exception as e:
                print(f"Exception deleting tasks from database: {e}")
                raise

    def __get_min_priority(self) -> int:
        """Return the smallest stored priority (0 if none) using its own session."""
        with Session(self.engine) as session:
            try:
                return self._min_priority(session)
            except Exception as e:
                print(f"Exception getting min priority from database: {e}")
                raise

    def __move_tasks_down(self, priority: int):
        """Increment every priority >= ``priority`` by one and commit."""
        with Session(self.engine) as session:
            try:
                self._shift_priorities_down(session, priority)
                session.commit()
            except Exception as e:
                print(f"Exception moving tasks down in database: {e}")
                raise