feat: new API to get task position in queue

pull/112/head
Tung Nguyen 2023-08-10 04:35:58 +07:00
parent ab9491c2d5
commit 6a0cf8be75
4 changed files with 38 additions and 1 deletions

View File

@ -135,6 +135,7 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
pending_tasks = task_manager.get_tasks(
status=TaskStatus.PENDING, limit=limit, offset=offset
)
position = offset
parsed_tasks = []
for task in pending_tasks:
params = format_task_args(task)
@ -143,7 +144,9 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
if task.id == current_task_id:
task_data["status"] = "running"
task_data["position"] = position
parsed_tasks.append(TaskModel(**task_data))
position += 1
return QueueStatusResponse(
current_task_id=current_task_id,
@ -193,9 +196,24 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
task_data["params"] = params
if task.id == progress.current_task:
task_data["status"] = "running"
if task_data["status"] == TaskStatus.PENDING:
task_data["position"] = task_manager.get_task_position(id)
return {"success": True, "data": TaskModel(**task_data)}
@app.get("/agent-scheduler/v1/task/{id}/position")
def get_task_position(id: str):
task = task_manager.get_task(id)
if task is None:
return {"success": False, "message": "Task not found"}
position = (
None
if task.status != TaskStatus.PENDING
else task_manager.get_task_position(id)
)
return {"success": True, "data": {"status": task.status, "position": position}}
@app.put("/agent-scheduler/v1/task/{id}")
def update_task(id: str, body: UpdateTaskArgs):
task = task_manager.get_task(id)

View File

@ -117,6 +117,25 @@ class TaskManager(BaseTableManager):
finally:
session.close()
def get_task_position(self, id: str) -> int:
session = Session(self.engine)
try:
task = session.get(TaskTable, id)
if task:
return (
session.query(func.count(TaskTable.id))
.filter(TaskTable.status == TaskStatus.PENDING)
.filter(TaskTable.priority < task.priority)
.scalar()
)
else:
raise Exception(f"Task with id {id} not found")
except Exception as e:
print(f"Exception getting task position from database: {e}")
raise e
finally:
session.close()
def get_tasks(
self,
type: str = None,

View File

@ -41,6 +41,7 @@ class TaskModel(BaseModel):
title="Task Parameters", description="The parameters of the task in JSON format"
)
priority: Optional[int] = Field(title="Task Priority")
position: Optional[int] = Field(title="Task Position")
result: Optional[str] = Field(
title="Task Result", description="The result of the task in JSON format"
)

View File

@ -134,7 +134,6 @@ class TaskRunner:
}
)
script_params = serialize_script_args(script_args)
print("__serialize_ui_task_args", named_args, script_args)
return (params, script_params)