290 lines
8.7 KiB
Python
290 lines
8.7 KiB
Python
from enum import Enum
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
from sqlalchemy import Column, String, Text, Integer, DateTime, LargeBinary, text, func
|
|
from sqlalchemy.orm import Session
|
|
|
|
from .base import BaseTableManager, Base
|
|
from ..models import TaskModel
|
|
|
|
|
|
class TaskStatus(str, Enum):
    """Lifecycle states of a task.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value.
    """

    # Not started yet.
    PENDING = "pending"
    # Currently being processed.
    RUNNING = "running"
    # Finished successfully.
    DONE = "done"
    # Finished with an error.
    FAILED = "failed"
|
|
|
|
|
|
class Task(TaskModel):
    """Task record that extends ``TaskModel`` with binary ``script_params``.

    Use ``from_table`` / ``to_table`` to convert to and from the ``TaskTable``
    ORM row.
    """

    # Binary script arguments; listed in ``Config.exclude`` below.
    script_params: bytes = None

    def __init__(
        self,
        id: str = "",
        type: str = "unknown",
        params: str = "",
        script_params: bytes = b"",
        priority: Optional[int] = None,
        status: str = TaskStatus.PENDING.value,
        result: Optional[str] = None,
        created_at: Optional[datetime] = None,
        updated_at: Optional[datetime] = None,
    ):
        """Build a task.

        Args:
            id: Task identifier.
            type: Task type name.
            params: Task arguments.
            script_params: Script arguments as raw bytes.
            priority: Ordering key; tasks are listed in ascending priority.
                A falsy value (``None`` or ``0``) is replaced by the current
                time in milliseconds.
            status: One of the ``TaskStatus`` values.
            result: Task result, if any.
            created_at: Creation timestamp, if known.
            updated_at: Last-update timestamp, if known.
        """
        priority = priority if priority else int(datetime.utcnow().timestamp() * 1000)

        super().__init__(
            id=id,
            type=type,
            params=params,
            status=status,
            priority=priority,
            result=result,
            created_at=created_at,
            # Previously received ``created_at`` by mistake.
            updated_at=updated_at,
        )
        self.id = id
        self.type = type
        self.params = params
        self.script_params = script_params
        self.priority = priority
        self.status = status
        self.result = result
        self.created_at = created_at
        self.updated_at = updated_at

    class Config(TaskModel.__config__):
        exclude = ["script_params"]

    @staticmethod
    def from_table(table: "TaskTable"):
        """Create a ``Task`` carrying every column of the ORM row ``table``."""
        return Task(
            id=table.id,
            type=table.type,
            params=table.params,
            script_params=table.script_params,
            priority=table.priority,
            status=table.status,
            # Previously omitted, so a stored result was lost on load.
            result=table.result,
            created_at=table.created_at,
            updated_at=table.updated_at,
        )

    def to_table(self) -> "TaskTable":
        """Create a new ``TaskTable`` row from this task.

        ``created_at`` and ``updated_at`` are not copied; the row is built
        without them.
        """
        return TaskTable(
            id=self.id,
            type=self.type,
            params=self.params,
            script_params=self.script_params,
            priority=self.priority,
            status=self.status,
            result=self.result,
        )
|
|
|
|
|
|
class TaskTable(Base):
    """ORM mapping for the ``task`` table."""

    __tablename__ = "task"

    id = Column(String(64), primary_key=True)
    type = Column(String(20), nullable=False)  # txt2img or img2txt
    params = Column(Text, nullable=False)  # task args
    script_params = Column(LargeBinary, nullable=False)  # script args
    # The default must be an integer: the previous ``default=datetime.now``
    # produced a ``datetime`` for this Integer column. The value is the current
    # time in milliseconds.
    priority = Column(
        Integer,
        nullable=False,
        default=lambda: int(datetime.utcnow().timestamp() * 1000),
    )
    status = Column(
        String(20), nullable=False, default="pending"
    )  # pending, running, done, failed
    result = Column(Text)  # task result
    # Both timestamps are filled in by the database when a row is inserted.
    created_at = Column(
        DateTime,
        nullable=False,
        server_default=text("(datetime('now'))"),
    )
    updated_at = Column(
        DateTime,
        nullable=False,
        server_default=text("(datetime('now'))"),
        onupdate=text("(datetime('now'))"),
    )

    def __repr__(self):
        """Return a short debug representation of the row."""
        return f"Task(id={self.id!r}, type={self.type!r}, params={self.params!r}, status={self.status!r}, created_at={self.created_at!r})"
|
|
|
|
|
|
class TaskManager(BaseTableManager):
    """Create, read, update, delete and reorder rows of the ``task`` table.

    Each public method opens its own ``Session`` on ``self.engine``, prints
    any failure, re-raises it, and closes the session on exit.
    """

    def get_task(self, id: str) -> Optional[Task]:
        """Return the task with primary key ``id``, or ``None`` if absent."""
        with Session(self.engine) as session:
            try:
                row = session.get(TaskTable, id)
                return Task.from_table(row) if row else None
            except Exception as e:
                print(f"Exception getting task from database: {e}")
                raise

    def get_tasks(
        self,
        type: str = None,
        status: str = None,
        limit: int = None,
        offset: int = None,
    ) -> list[Task]:
        """Return tasks ordered by ascending priority, then creation time.

        Args:
            type: If truthy, keep only tasks of this type.
            status: If truthy, keep only tasks with this status.
            limit: If truthy, maximum number of tasks to return.
            offset: If truthy, number of leading tasks to skip.
        """
        with Session(self.engine) as session:
            try:
                query = session.query(TaskTable)
                if type:
                    query = query.filter(TaskTable.type == type)
                if status:
                    query = query.filter(TaskTable.status == status)

                query = query.order_by(
                    TaskTable.priority.asc(), TaskTable.created_at.asc()
                )

                if limit:
                    query = query.limit(limit)
                if offset:
                    query = query.offset(offset)

                return [Task.from_table(row) for row in query.all()]
            except Exception as e:
                print(f"Exception getting tasks from database: {e}")
                raise

    def count_tasks(
        self,
        type: str = None,
        status: str = None,
    ) -> int:
        """Return the number of tasks matching the optional type and status."""
        with Session(self.engine) as session:
            try:
                query = session.query(TaskTable)
                if type:
                    query = query.filter(TaskTable.type == type)
                if status:
                    query = query.filter(TaskTable.status == status)
                return query.count()
            except Exception as e:
                print(f"Exception counting tasks from database: {e}")
                raise

    def add_task(self, task: Task) -> TaskTable:
        """Insert ``task`` and return the stored row."""
        with Session(self.engine) as session:
            try:
                row = task.to_table()
                session.add(row)
                session.commit()
                # commit() expires the row; reload it so its attributes stay
                # readable after the session is closed.
                session.refresh(row)
                return row
            except Exception as e:
                print(f"Exception adding task to database: {e}")
                raise

    def update_task(self, id: str, status: str, result=None) -> TaskTable:
        """Set the status and result of task ``id`` and return the stored row.

        Raises:
            Exception: If no task with this id exists.
        """
        with Session(self.engine) as session:
            try:
                row = session.get(TaskTable, id)
                if row is None:
                    raise Exception(f"Task with id {id} not found")

                row.status = status
                row.result = result
                session.commit()
                # Reload so the returned row is usable once detached.
                session.refresh(row)
                return row
            except Exception as e:
                print(f"Exception updating task in database: {e}")
                raise

    def prioritize_task(self, id: str, priority: int) -> TaskTable:
        """Change the queue position of task ``id`` and return the stored row.

        0 means move to top, -1 means move to bottom, otherwise set the exact
        priority (tasks already at or after that priority are shifted down by
        one first).

        Raises:
            Exception: If no task with this id exists.
        """
        with Session(self.engine) as session:
            try:
                row = session.get(TaskTable, id)
                if row is None:
                    raise Exception(f"Task with id {id} not found")

                if priority == 0:
                    row.priority = self.__get_min_priority() - 1
                elif priority == -1:
                    row.priority = int(datetime.utcnow().timestamp() * 1000)
                else:
                    # Shift the other tasks in this same transaction and leave
                    # the target row out, so its in-memory priority cannot go
                    # stale before the assignment below.
                    self.__move_tasks_down(priority, session=session, exclude_id=id)
                    row.priority = priority

                session.commit()
                # Reload so the returned row is usable once detached.
                session.refresh(row)
                return row
            except Exception as e:
                print(f"Exception updating task in database: {e}")
                raise

    def delete_task(self, id: str):
        """Delete task ``id``.

        Raises:
            Exception: If no task with this id exists.
        """
        with Session(self.engine) as session:
            try:
                row = session.get(TaskTable, id)
                if row is None:
                    raise Exception(f"Task with id {id} not found")

                session.delete(row)
                session.commit()
            except Exception as e:
                print(f"Exception deleting task from database: {e}")
                raise

    def delete_tasks_before(self, before: datetime, all: bool = False):
        """Delete tasks created before ``before``.

        Unless ``all`` is true, only tasks whose status is done or failed are
        removed.
        """
        with Session(self.engine) as session:
            try:
                query = session.query(TaskTable).filter(TaskTable.created_at < before)
                if not all:
                    finished = [TaskStatus.DONE.value, TaskStatus.FAILED.value]
                    query = query.filter(TaskTable.status.in_(finished))

                # No rows are loaded in this session, so nothing needs syncing.
                query.delete(synchronize_session=False)
                session.commit()
            except Exception as e:
                print(f"Exception deleting tasks from database: {e}")
                raise

    def __get_min_priority(self) -> int:
        """Return the smallest stored priority, or 0 when there is none."""
        with Session(self.engine) as session:
            try:
                lowest = session.query(func.min(TaskTable.priority)).scalar()
                return lowest or 0
            except Exception as e:
                print(f"Exception getting min priority from database: {e}")
                raise

    def __move_tasks_down(
        self,
        priority: int,
        session: Optional[Session] = None,
        exclude_id: Optional[str] = None,
    ):
        """Add one to the priority of every task whose priority is >= ``priority``.

        Args:
            priority: Lowest priority value to shift.
            session: If given, run inside this session and leave committing to
                the caller. Otherwise a new session is opened and committed.
            exclude_id: If given, the task with this id is left untouched.
        """
        if session is not None:
            self.__shift_priorities(session, priority, exclude_id)
            return

        with Session(self.engine) as own_session:
            try:
                self.__shift_priorities(own_session, priority, exclude_id)
                own_session.commit()
            except Exception as e:
                print(f"Exception moving tasks down in database: {e}")
                raise

    @staticmethod
    def __shift_priorities(session: Session, priority: int, exclude_id: Optional[str]):
        """Issue the bulk priority increment on ``session`` without committing."""
        query = session.query(TaskTable).filter(TaskTable.priority >= priority)
        if exclude_id is not None:
            query = query.filter(TaskTable.id != exclude_id)
        # Callers never hold a loaded row that this statement changes.
        query.update(
            {TaskTable.priority: TaskTable.priority + 1},
            synchronize_session=False,
        )
|