feat: new API to get task position in queue
parent
ab9491c2d5
commit
6a0cf8be75
|
|
@ -135,6 +135,7 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
|
|||
pending_tasks = task_manager.get_tasks(
|
||||
status=TaskStatus.PENDING, limit=limit, offset=offset
|
||||
)
|
||||
position = offset
|
||||
parsed_tasks = []
|
||||
for task in pending_tasks:
|
||||
params = format_task_args(task)
|
||||
|
|
@ -143,7 +144,9 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
|
|||
if task.id == current_task_id:
|
||||
task_data["status"] = "running"
|
||||
|
||||
task_data["position"] = position
|
||||
parsed_tasks.append(TaskModel(**task_data))
|
||||
position += 1
|
||||
|
||||
return QueueStatusResponse(
|
||||
current_task_id=current_task_id,
|
||||
|
|
@ -193,9 +196,24 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
|
|||
task_data["params"] = params
|
||||
if task.id == progress.current_task:
|
||||
task_data["status"] = "running"
|
||||
if task_data["status"] == TaskStatus.PENDING:
|
||||
task_data["position"] = task_manager.get_task_position(id)
|
||||
|
||||
return {"success": True, "data": TaskModel(**task_data)}
|
||||
|
||||
@app.get("/agent-scheduler/v1/task/{id}/position")
|
||||
def get_task_position(id: str):
|
||||
task = task_manager.get_task(id)
|
||||
if task is None:
|
||||
return {"success": False, "message": "Task not found"}
|
||||
|
||||
position = (
|
||||
None
|
||||
if task.status != TaskStatus.PENDING
|
||||
else task_manager.get_task_position(id)
|
||||
)
|
||||
return {"success": True, "data": {"status": task.status, "position": position}}
|
||||
|
||||
@app.put("/agent-scheduler/v1/task/{id}")
|
||||
def update_task(id: str, body: UpdateTaskArgs):
|
||||
task = task_manager.get_task(id)
|
||||
|
|
|
|||
|
|
@ -117,6 +117,25 @@ class TaskManager(BaseTableManager):
|
|||
finally:
|
||||
session.close()
|
||||
|
||||
def get_task_position(self, id: str) -> int:
|
||||
session = Session(self.engine)
|
||||
try:
|
||||
task = session.get(TaskTable, id)
|
||||
if task:
|
||||
return (
|
||||
session.query(func.count(TaskTable.id))
|
||||
.filter(TaskTable.status == TaskStatus.PENDING)
|
||||
.filter(TaskTable.priority < task.priority)
|
||||
.scalar()
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Task with id {id} not found")
|
||||
except Exception as e:
|
||||
print(f"Exception getting task position from database: {e}")
|
||||
raise e
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_tasks(
|
||||
self,
|
||||
type: str = None,
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ class TaskModel(BaseModel):
|
|||
title="Task Parameters", description="The parameters of the task in JSON format"
|
||||
)
|
||||
priority: Optional[int] = Field(title="Task Priority")
|
||||
position: Optional[int] = Field(title="Task Position")
|
||||
result: Optional[str] = Field(
|
||||
title="Task Result", description="The result of the task in JSON format"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -134,7 +134,6 @@ class TaskRunner:
|
|||
}
|
||||
)
|
||||
script_params = serialize_script_args(script_args)
|
||||
print("__serialize_ui_task_args", named_args, script_args)
|
||||
|
||||
return (params, script_params)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue