feat: import/export
parent
52f8590a03
commit
b6965ad05d
|
|
@ -1,6 +1,7 @@
|
|||
import io
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
import requests
|
||||
import threading
|
||||
from uuid import uuid4
|
||||
|
|
@ -15,6 +16,7 @@ from fastapi import Depends
|
|||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from fastapi.exceptions import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from modules import shared, progress, sd_models, sd_samplers
|
||||
|
||||
|
|
@ -191,6 +193,39 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
|
|||
paused=TaskRunner.instance.paused,
|
||||
)
|
||||
|
||||
@app.get("/agent-scheduler/v1/export")
|
||||
def export_queue(limit: int = 1000, offset: int = 0):
|
||||
pending_tasks = task_manager.get_tasks(status=TaskStatus.PENDING, limit=limit, offset=offset)
|
||||
pending_tasks = [Task.from_table(t).to_json() for t in pending_tasks]
|
||||
return pending_tasks
|
||||
|
||||
class StringRequestBody(BaseModel):
|
||||
content: str
|
||||
|
||||
@app.post("/agent-scheduler/v1/import")
|
||||
def import_queue(queue: StringRequestBody):
|
||||
try:
|
||||
objList = json.loads(queue.content)
|
||||
taskList: List[Task] = []
|
||||
for obj in objList:
|
||||
if "id" not in obj or not obj["id"] or obj["id"] == "":
|
||||
obj["id"] = str(uuid4())
|
||||
obj["result"] = None
|
||||
obj["status"] = TaskStatus.PENDING
|
||||
task = Task.from_json(obj)
|
||||
taskList.append(task)
|
||||
|
||||
for task in taskList:
|
||||
exists = task_manager.get_task(task.id)
|
||||
if exists:
|
||||
task_manager.update_task(task)
|
||||
else:
|
||||
task_manager.add_task(task)
|
||||
return {"success": True, "message": "Queue imported"}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return {"success": False, "message": "Import Failed"}
|
||||
|
||||
@app.get("/agent-scheduler/v1/history", response_model=HistoryResponse, dependencies=deps)
|
||||
def history_api(status: str = None, limit: int = 20, offset: int = 0):
|
||||
bookmarked = True if status == "bookmarked" else None
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import json
|
||||
import base64
|
||||
from enum import Enum
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union, List
|
||||
from typing import Optional, Union, List, Dict
|
||||
|
||||
from sqlalchemy import (
|
||||
TypeDecorator,
|
||||
|
|
@ -89,6 +91,40 @@ class Task(TaskModel):
|
|||
bookmarked=self.bookmarked,
|
||||
)
|
||||
|
||||
def from_json(json_obj: Dict):
|
||||
return Task(
|
||||
id=json_obj.get("id"),
|
||||
api_task_id=json_obj.get("api_task_id", None),
|
||||
api_task_callback=json_obj.get("api_task_callback", None),
|
||||
name=json_obj.get("name", None),
|
||||
type=json_obj.get("type"),
|
||||
status=json_obj.get("status", TaskStatus.PENDING),
|
||||
params=json.dumps(json_obj.get("params")),
|
||||
script_params=base64.b64decode(json_obj.get("script_params")),
|
||||
priority=json_obj.get("priority", int(datetime.now(timezone.utc).timestamp() * 1000)),
|
||||
result=json_obj.get("result", None),
|
||||
bookmarked=json_obj.get("bookmarked", False),
|
||||
created_at=datetime.fromtimestamp(json_obj.get("created_at", datetime.now(timezone.utc).timestamp())),
|
||||
updated_at=datetime.fromtimestamp(json_obj.get("updated_at", datetime.now(timezone.utc).timestamp())),
|
||||
)
|
||||
|
||||
def to_json(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"api_task_id": self.api_task_id,
|
||||
"api_task_callback": self.api_task_callback,
|
||||
"name": self.name,
|
||||
"type": self.type,
|
||||
"status": self.status,
|
||||
"params": json.loads(self.params),
|
||||
"script_params": base64.b64encode(self.script_params).decode("utf-8"),
|
||||
"priority": self.priority,
|
||||
"result": self.result,
|
||||
"bookmarked": self.bookmarked,
|
||||
"created_at": int(self.created_at.timestamp()),
|
||||
"updated_at": int(self.updated_at.timestamp()),
|
||||
}
|
||||
|
||||
|
||||
class TaskTable(Base):
|
||||
__tablename__ = "task"
|
||||
|
|
@ -101,9 +137,7 @@ class TaskTable(Base):
|
|||
params = Column(Text, nullable=False) # task args
|
||||
script_params = Column(LargeBinary, nullable=False) # script args
|
||||
priority = Column(Integer, nullable=False)
|
||||
status = Column(
|
||||
String(20), nullable=False, default="pending"
|
||||
) # pending, running, done, failed
|
||||
status = Column(String(20), nullable=False, default="pending") # pending, running, done, failed
|
||||
result = Column(Text) # task result
|
||||
bookmarked = Column(Boolean, nullable=True, default=False)
|
||||
created_at = Column(
|
||||
|
|
@ -184,11 +218,7 @@ class TaskManager(BaseTableManager):
|
|||
else:
|
||||
query = query.order_by(TaskTable.bookmarked.asc())
|
||||
|
||||
query = query.order_by(
|
||||
TaskTable.priority.asc()
|
||||
if order == "asc"
|
||||
else TaskTable.priority.desc()
|
||||
)
|
||||
query = query.order_by(TaskTable.priority.asc() if order == "asc" else TaskTable.priority.desc())
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
|
@ -270,9 +300,7 @@ class TaskManager(BaseTableManager):
|
|||
result = session.get(TaskTable, id)
|
||||
if result:
|
||||
if priority == 0:
|
||||
result.priority = (
|
||||
self.__get_min_priority(status=TaskStatus.PENDING) - 1
|
||||
)
|
||||
result.priority = self.__get_min_priority(status=TaskStatus.PENDING) - 1
|
||||
elif priority == -1:
|
||||
result.priority = int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||
else:
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -403,6 +403,17 @@ def on_ui_tab(**_kwargs):
|
|||
elem_id="agent_scheduler_action_clear_queue",
|
||||
variant="stop",
|
||||
)
|
||||
gr.Button(
|
||||
"Export",
|
||||
elem_id="agent_scheduler_action_export",
|
||||
variant="secondary",
|
||||
)
|
||||
gr.Button(
|
||||
"Import",
|
||||
elem_id="agent_scheduler_action_import",
|
||||
variant="secondary",
|
||||
)
|
||||
gr.HTML(f'<input type="file" id="agent_scheduler_import_file" style="display: none" accept="application/json" />')
|
||||
|
||||
with gr.Row(elem_classes=["agent_scheduler_filter_container", "flex-row", "ml-auto"]):
|
||||
gr.Textbox(
|
||||
|
|
|
|||
|
|
@ -437,4 +437,4 @@ button.ts-btn-action {
|
|||
#enqueue_keyboard_shortcut_disable {
|
||||
padding-left: 12px !important;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -283,6 +283,18 @@ function initSearchInput(selector: string) {
|
|||
return searchInput;
|
||||
}
|
||||
|
||||
// function initImport(selector: string) {
|
||||
// const importContainer = gradioApp().querySelector<HTMLDivElement>(selector);
|
||||
// if (importContainer == null) {
|
||||
// throw new Error(`Import container '${selector}' not found.`);
|
||||
// }
|
||||
// const importInput = importContainer.getElementsByTagName('input')[0];
|
||||
// if (importInput == null) {
|
||||
// throw new Error('Import input not found.');
|
||||
// }
|
||||
// return importInput;
|
||||
// }
|
||||
|
||||
async function notify(response: ResponseStatus) {
|
||||
if (notyf == null) {
|
||||
const Notyf = await import('notyf');
|
||||
|
|
@ -343,9 +355,9 @@ function showTaskProgress(task_id: string, type: string | undefined, callback: (
|
|||
// monkey patch randomId to return task_id, then call submit to trigger progress
|
||||
window.randomId = () => task_id;
|
||||
if (type === 'txt2img') {
|
||||
(window.submit || window.submit_txt2img)();
|
||||
submit();
|
||||
} else if (type === 'img2img') {
|
||||
window.submit_img2img();
|
||||
submit_img2img();
|
||||
}
|
||||
window.randomId = window.origRandomId;
|
||||
}
|
||||
|
|
@ -591,6 +603,48 @@ function initPendingTab() {
|
|||
}
|
||||
});
|
||||
|
||||
const importButton = gradioApp().querySelector<HTMLButtonElement>(
|
||||
'#agent_scheduler_action_import'
|
||||
)!;
|
||||
const importInput = gradioApp().querySelector<HTMLInputElement>('#agent_scheduler_import_file')!;
|
||||
|
||||
importButton.addEventListener('click', () => {
|
||||
importInput.click();
|
||||
});
|
||||
importInput.addEventListener('change', e => {
|
||||
if (e.target === null) return;
|
||||
|
||||
const files = importInput.files;
|
||||
if (files == null || files.length === 0) return;
|
||||
|
||||
const file = files[0];
|
||||
const reader = new FileReader();
|
||||
reader.onload = () => {
|
||||
const data = reader.result as string;
|
||||
store
|
||||
.importQueue(data)
|
||||
.then(notify)
|
||||
.then(() => {
|
||||
importInput.value = '';
|
||||
store.refresh();
|
||||
});
|
||||
};
|
||||
reader.readAsText(file);
|
||||
});
|
||||
|
||||
const exportButton = gradioApp().querySelector<HTMLButtonElement>(
|
||||
'#agent_scheduler_action_export'
|
||||
)!;
|
||||
exportButton.addEventListener('click', () => {
|
||||
store.exportQueue().then(data => {
|
||||
const dataStr = 'data:text/json;charset=utf-8,' + encodeURIComponent(JSON.stringify(data));
|
||||
const dlAnchorElem = document.createElement('a');
|
||||
dlAnchorElem.setAttribute('href', dataStr);
|
||||
dlAnchorElem.setAttribute('download', `agent-scheduler-${Date.now()}.json`);
|
||||
dlAnchorElem.click();
|
||||
});
|
||||
});
|
||||
|
||||
// watch for queue status change
|
||||
const updateUiState = (state: ReturnType<typeof store.getState>) => {
|
||||
if (state.paused) {
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ type PendingTasksState = {
|
|||
|
||||
type PendingTasksActions = {
|
||||
refresh: () => Promise<void>;
|
||||
exportQueue: () => Promise<string>;
|
||||
importQueue: (str: string) => Promise<ResponseStatus>;
|
||||
pauseQueue: () => Promise<ResponseStatus>;
|
||||
resumeQueue: () => Promise<ResponseStatus>;
|
||||
clearQueue: () => Promise<ResponseStatus>;
|
||||
|
|
@ -32,6 +34,26 @@ export const createPendingTasksStore = (initialState: PendingTasksState) => {
|
|||
.then(response => response.json())
|
||||
.then(setState);
|
||||
},
|
||||
exportQueue: async () => {
|
||||
return fetch('/agent-scheduler/v1/export').then(response => response.json());
|
||||
},
|
||||
importQueue: async (str: string) => {
|
||||
const bodyObj = {
|
||||
content: str,
|
||||
};
|
||||
return fetch(`/agent-scheduler/v1/import`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(bodyObj),
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
setTimeout(() => {
|
||||
actions.refresh();
|
||||
}, 3000);
|
||||
return data;
|
||||
});
|
||||
},
|
||||
pauseQueue: async () => {
|
||||
return fetch('/agent-scheduler/v1/queue/pause', { method: 'POST' })
|
||||
.then(response => response.json())
|
||||
|
|
@ -97,8 +119,9 @@ export const createPendingTasksStore = (initialState: PendingTasksState) => {
|
|||
}).then(response => response.json());
|
||||
},
|
||||
deleteTask: async (id: string) => {
|
||||
return fetch(`/agent-scheduler/v1/task/${id}`, { method: 'DELETE' })
|
||||
.then(response => response.json());
|
||||
return fetch(`/agent-scheduler/v1/task/${id}`, { method: 'DELETE' }).then(response =>
|
||||
response.json()
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue