sd-webui-agent-scheduler/scripts/api.py

245 lines
8.4 KiB
Python

import threading
from uuid import uuid4
from gradio.routes import App
from modules import shared, progress, script_callbacks
from scripts.db import TaskStatus, task_manager
from scripts.models import (
Txt2ImgApiTaskArgs,
Img2ImgApiTaskArgs,
QueueTaskResponse,
QueueStatusResponse,
HistoryResponse,
TaskModel,
)
from scripts.task_runner import TaskRunner, get_instance
from scripts.helpers import log
from scripts.task_helpers import serialize_api_task_args
# Module-wide TaskRunner handle. It is None at import time and is assigned by
# on_app_started (via get_instance) right before regsiter_apis registers the
# route handlers that use it.
task_runner: "TaskRunner | None" = None
def _queue_api_task(body, is_img2img: bool) -> QueueTaskResponse:
    """Serialize an API generation request and enqueue it as a new task.

    Shared implementation of the txt2img / img2img queue endpoints.

    Args:
        body: the validated request model (``Txt2ImgApiTaskArgs`` or
            ``Img2ImgApiTaskArgs``); converted with ``.dict()``.
        is_img2img: whether the task is an img2img task.

    Returns:
        A ``QueueTaskResponse`` carrying the newly generated task id.
    """
    params = body.dict()
    task_id = str(uuid4())
    # ``model_hash`` is removed from the generation params and handed to the
    # serializer separately as the checkpoint.
    checkpoint = params.pop("model_hash", None)
    task_args = serialize_api_task_args(
        params,
        is_img2img=is_img2img,
        checkpoint=checkpoint,
    )
    task_runner.register_api_task(
        task_id, api_task_id=False, is_img2img=is_img2img, args=task_args
    )
    # Ask the runner to process pending tasks (presumably a no-op when it is
    # already busy -- see TaskRunner).
    task_runner.execute_pending_tasks_threading()
    return QueueTaskResponse(task_id=task_id)


def _to_task_model(task, current_task_id=None) -> TaskModel:
    """Convert a stored task into the ``TaskModel`` returned by the API.

    The stored ``params`` are replaced with the task's named arguments, parsed
    without deserialization, plus a ``"checkpoint"`` entry. When
    ``current_task_id`` is given and equals the task's id, the reported status
    is overridden to ``"running"``.
    """
    task_args = TaskRunner.instance.parse_task_args(
        task.params, task.script_params, deserialization=False
    )
    named_args = task_args.named_args
    named_args["checkpoint"] = task_args.checkpoint
    task_data = task.dict()
    task_data["params"] = named_args
    if current_task_id is not None and task.id == current_task_id:
        task_data["status"] = "running"
    return TaskModel(**task_data)


def regsiter_apis(app: App):
    """Register all ``/agent-scheduler/v1`` REST endpoints on ``app``.

    NOTE: the name is misspelled ("regsiter") but is kept unchanged because
    on_app_started calls it by this name.
    """
    log.info("[AgentScheduler] Registering APIs")

    @app.post("/agent-scheduler/v1/queue/txt2img", response_model=QueueTaskResponse)
    def queue_txt2img(body: Txt2ImgApiTaskArgs):
        """Queue a txt2img task built from the request body."""
        return _queue_api_task(body, is_img2img=False)

    @app.post("/agent-scheduler/v1/queue/img2img", response_model=QueueTaskResponse)
    def queue_img2img(body: Img2ImgApiTaskArgs):
        """Queue an img2img task built from the request body."""
        return _queue_api_task(body, is_img2img=True)

    @app.get("/agent-scheduler/v1/queue", response_model=QueueStatusResponse)
    def queue_status_api(limit: int = 20, offset: int = 0):
        """Return one page of pending tasks plus overall queue state."""
        current_task_id = progress.current_task
        # Use the enum member, consistent with the get_tasks call below.
        total_pending_tasks = task_manager.count_tasks(status=TaskStatus.PENDING)
        pending_tasks = task_manager.get_tasks(
            status=TaskStatus.PENDING, limit=limit, offset=offset
        )
        parsed_tasks = [_to_task_model(task, current_task_id) for task in pending_tasks]
        return QueueStatusResponse(
            current_task_id=current_task_id,
            pending_tasks=parsed_tasks,
            total_pending_tasks=total_pending_tasks,
            paused=TaskRunner.instance.paused,
        )

    @app.get("/agent-scheduler/v1/history", response_model=HistoryResponse)
    def history_api(status: str = None, limit: int = 20, offset: int = 0):
        """Return one page of finished tasks, newest first.

        ``status`` may be a single task status, ``"all"``/empty for every
        finished status, or ``"bookmarked"`` for bookmarked finished tasks.
        """
        bookmarked = True if status == "bookmarked" else None
        if not status or status == "all" or bookmarked:
            status = [
                TaskStatus.DONE,
                TaskStatus.FAILED,
                TaskStatus.INTERRUPTED,
            ]
        # NOTE(review): ``total`` does not apply the bookmark filter, so for
        # status="bookmarked" it counts every finished task rather than only
        # the bookmarked ones -- confirm whether count_tasks can filter on it.
        total = task_manager.count_tasks(status=status)
        tasks = task_manager.get_tasks(
            status=status,
            bookmarked=bookmarked,
            limit=limit,
            offset=offset,
            order="desc",
        )
        return HistoryResponse(
            total=total,
            tasks=[_to_task_model(task) for task in tasks],
        )

    @app.post("/agent-scheduler/v1/run/{id}")
    def run_task(id: str):
        """Run a task now, or move it to the front if another task is running."""
        # Snapshot once so both checks below see the same value.
        current_task_id = progress.current_task
        if current_task_id == id:
            return {"success": False, "message": f"Task {id} is already running"}

        # Previously an unknown id was "prioritized" or handed to a worker
        # thread as None while still reporting success.
        task = task_manager.get_task(id)
        if task is None:
            return {"success": False, "message": f"Task {id} not found"}

        if current_task_id is not None:
            # Another task is running: move this one up in the queue.
            task_manager.prioritize_task(id, 0)
            return {
                "success": True,
                "message": f"Task {id} is scheduled to run next",
            }

        # Nothing is running: execute the task in a background daemon thread.
        current_thread = threading.Thread(
            target=TaskRunner.instance.execute_task,
            args=(task, lambda: None),
            daemon=True,
        )
        current_thread.start()
        return {"success": True, "message": f"Task {id} is executing"}

    @app.post("/agent-scheduler/v1/requeue/{id}")
    def requeue_task(id: str):
        """Add a fresh pending copy of an existing task to the queue."""
        task = task_manager.get_task(id)
        if task is None:
            return {"success": False, "message": f"Task {id} not found"}
        # Reuse the stored record as a template for a brand-new task.
        task.id = str(uuid4())
        task.result = None
        task.status = TaskStatus.PENDING
        task.bookmarked = False
        task.name = f"Copy of {task.name}" if task.name else None
        task_manager.add_task(task)
        task_runner.execute_pending_tasks_threading()
        return {"success": True, "message": f"Task {id} is requeued"}

    @app.post("/agent-scheduler/v1/delete/{id}")
    def delete_task(id: str):
        """Delete a task; if it is the running task, interrupt it instead."""
        if progress.current_task == id:
            shared.state.interrupt()
            # The record is not deleted here; the runner is told which task
            # was interrupted (presumably it handles the cleanup).
            task_runner.interrupted = id
            return {"success": True, "message": f"Task {id} is interrupted"}
        task_manager.delete_task(id)
        return {"success": True, "message": f"Task {id} is deleted"}

    @app.post("/agent-scheduler/v1/move/{id}/{over_id}")
    def move_task(id: str, over_id: str):
        """Reorder a task: to "top", to "bottom", or to another task's priority."""
        task = task_manager.get_task(id)
        if task is None:
            return {"success": False, "message": f"Task {id} not found"}
        if over_id == "top":
            task_manager.prioritize_task(id, 0)
            return {"success": True, "message": f"Task {id} is moved to top"}
        if over_id == "bottom":
            task_manager.prioritize_task(id, -1)
            return {"success": True, "message": f"Task {id} is moved to bottom"}
        over_task = task_manager.get_task(over_id)
        if over_task is None:
            return {"success": False, "message": f"Task {over_id} not found"}
        task_manager.prioritize_task(id, over_task.priority)
        return {"success": True, "message": f"Task {id} is moved"}

    @app.post("/agent-scheduler/v1/bookmark/{id}")
    def pin_task(id: str):
        """Mark a task as bookmarked."""
        task = task_manager.get_task(id)
        if task is None:
            return {"success": False, "message": f"Task {id} not found"}
        task_manager.update_task(id, bookmarked=True)
        return {"success": True, "message": f"Task {id} is bookmarked"}

    @app.post("/agent-scheduler/v1/unbookmark/{id}")
    def unpin_task(id: str):
        """Clear a task's bookmark."""
        task = task_manager.get_task(id)
        if task is None:
            return {"success": False, "message": f"Task {id} not found"}
        task_manager.update_task(id, bookmarked=False)
        return {"success": True, "message": f"Task {id} is unbookmarked"}

    @app.post("/agent-scheduler/v1/rename/{id}")
    def rename_task(id: str, name: str):
        """Rename a task (``name`` is taken from the query string)."""
        task = task_manager.get_task(id)
        if task is None:
            return {"success": False, "message": f"Task {id} not found"}
        task_manager.update_task(id, name=name)
        return {"success": True, "message": f"Task {id} is renamed"}

    @app.post("/agent-scheduler/v1/pause")
    def pause_queue():
        """Set the ``queue_paused`` option."""
        shared.opts.queue_paused = True
        return {"success": True, "message": "Queue is paused"}

    @app.post("/agent-scheduler/v1/resume")
    def resume_queue():
        """Clear the ``queue_paused`` option and resume processing pending tasks."""
        shared.opts.queue_paused = False
        TaskRunner.instance.execute_pending_tasks_threading()
        return {"success": True, "message": "Queue is resumed"}
def on_app_started(block, app: App):
    """Webui app-started hook: bind the shared TaskRunner, then mount the API.

    The module-level ``task_runner`` is assigned before the routes are
    registered, so every handler sees a ready runner.
    """
    global task_runner
    runner = get_instance(block)
    task_runner = runner
    regsiter_apis(app)


# Hook into webui start-up so the endpoints exist once the FastAPI app is up.
script_callbacks.on_app_started(on_app_started)