feat: import/export

pull/181/head
Tung Nguyen 2023-11-14 03:30:12 +07:00
parent 52f8590a03
commit b6965ad05d
7 changed files with 221 additions and 70 deletions

View File

@ -1,6 +1,7 @@
import io
import os
import json
import base64
import requests
import threading
from uuid import uuid4
@ -15,6 +16,7 @@ from fastapi import Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from pydantic import BaseModel
from modules import shared, progress, sd_models, sd_samplers
@ -191,6 +193,39 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
paused=TaskRunner.instance.paused,
)
@app.get("/agent-scheduler/v1/export")
def export_queue(limit: int = 1000, offset: int = 0):
pending_tasks = task_manager.get_tasks(status=TaskStatus.PENDING, limit=limit, offset=offset)
pending_tasks = [Task.from_table(t).to_json() for t in pending_tasks]
return pending_tasks
class StringRequestBody(BaseModel):
content: str
@app.post("/agent-scheduler/v1/import")
def import_queue(queue: StringRequestBody):
try:
objList = json.loads(queue.content)
taskList: List[Task] = []
for obj in objList:
if "id" not in obj or not obj["id"] or obj["id"] == "":
obj["id"] = str(uuid4())
obj["result"] = None
obj["status"] = TaskStatus.PENDING
task = Task.from_json(obj)
taskList.append(task)
for task in taskList:
exists = task_manager.get_task(task.id)
if exists:
task_manager.update_task(task)
else:
task_manager.add_task(task)
return {"success": True, "message": "Queue imported"}
except Exception as e:
print(e)
return {"success": False, "message": "Import Failed"}
@app.get("/agent-scheduler/v1/history", response_model=HistoryResponse, dependencies=deps)
def history_api(status: str = None, limit: int = 20, offset: int = 0):
bookmarked = True if status == "bookmarked" else None

View File

@ -1,6 +1,8 @@
import json
import base64
from enum import Enum
from datetime import datetime, timezone
from typing import Optional, Union, List
from typing import Optional, Union, List, Dict
from sqlalchemy import (
TypeDecorator,
@ -89,6 +91,40 @@ class Task(TaskModel):
bookmarked=self.bookmarked,
)
def from_json(json_obj: Dict):
return Task(
id=json_obj.get("id"),
api_task_id=json_obj.get("api_task_id", None),
api_task_callback=json_obj.get("api_task_callback", None),
name=json_obj.get("name", None),
type=json_obj.get("type"),
status=json_obj.get("status", TaskStatus.PENDING),
params=json.dumps(json_obj.get("params")),
script_params=base64.b64decode(json_obj.get("script_params")),
priority=json_obj.get("priority", int(datetime.now(timezone.utc).timestamp() * 1000)),
result=json_obj.get("result", None),
bookmarked=json_obj.get("bookmarked", False),
created_at=datetime.fromtimestamp(json_obj.get("created_at", datetime.now(timezone.utc).timestamp())),
updated_at=datetime.fromtimestamp(json_obj.get("updated_at", datetime.now(timezone.utc).timestamp())),
)
def to_json(self):
return {
"id": self.id,
"api_task_id": self.api_task_id,
"api_task_callback": self.api_task_callback,
"name": self.name,
"type": self.type,
"status": self.status,
"params": json.loads(self.params),
"script_params": base64.b64encode(self.script_params).decode("utf-8"),
"priority": self.priority,
"result": self.result,
"bookmarked": self.bookmarked,
"created_at": int(self.created_at.timestamp()),
"updated_at": int(self.updated_at.timestamp()),
}
class TaskTable(Base):
__tablename__ = "task"
@ -101,9 +137,7 @@ class TaskTable(Base):
params = Column(Text, nullable=False) # task args
script_params = Column(LargeBinary, nullable=False) # script args
priority = Column(Integer, nullable=False)
status = Column(
String(20), nullable=False, default="pending"
) # pending, running, done, failed
status = Column(String(20), nullable=False, default="pending") # pending, running, done, failed
result = Column(Text) # task result
bookmarked = Column(Boolean, nullable=True, default=False)
created_at = Column(
@ -184,11 +218,7 @@ class TaskManager(BaseTableManager):
else:
query = query.order_by(TaskTable.bookmarked.asc())
query = query.order_by(
TaskTable.priority.asc()
if order == "asc"
else TaskTable.priority.desc()
)
query = query.order_by(TaskTable.priority.asc() if order == "asc" else TaskTable.priority.desc())
if limit:
query = query.limit(limit)
@ -270,9 +300,7 @@ class TaskManager(BaseTableManager):
result = session.get(TaskTable, id)
if result:
if priority == 0:
result.priority = (
self.__get_min_priority(status=TaskStatus.PENDING) - 1
)
result.priority = self.__get_min_priority(status=TaskStatus.PENDING) - 1
elif priority == -1:
result.priority = int(datetime.now(timezone.utc).timestamp() * 1000)
else:

File diff suppressed because one or more lines are too long

View File

@ -403,6 +403,17 @@ def on_ui_tab(**_kwargs):
elem_id="agent_scheduler_action_clear_queue",
variant="stop",
)
gr.Button(
"Export",
elem_id="agent_scheduler_action_export",
variant="secondary",
)
gr.Button(
"Import",
elem_id="agent_scheduler_action_import",
variant="secondary",
)
gr.HTML(f'<input type="file" id="agent_scheduler_import_file" style="display: none" accept="application/json" />')
with gr.Row(elem_classes=["agent_scheduler_filter_container", "flex-row", "ml-auto"]):
gr.Textbox(

View File

@ -437,4 +437,4 @@ button.ts-btn-action {
#enqueue_keyboard_shortcut_disable {
padding-left: 12px !important;
}
}
}

View File

@ -283,6 +283,18 @@ function initSearchInput(selector: string) {
return searchInput;
}
// function initImport(selector: string) {
// const importContainer = gradioApp().querySelector<HTMLDivElement>(selector);
// if (importContainer == null) {
// throw new Error(`Import container '${selector}' not found.`);
// }
// const importInput = importContainer.getElementsByTagName('input')[0];
// if (importInput == null) {
// throw new Error('Import input not found.');
// }
// return importInput;
// }
async function notify(response: ResponseStatus) {
if (notyf == null) {
const Notyf = await import('notyf');
@ -343,9 +355,9 @@ function showTaskProgress(task_id: string, type: string | undefined, callback: (
// monkey patch randomId to return task_id, then call submit to trigger progress
window.randomId = () => task_id;
if (type === 'txt2img') {
(window.submit || window.submit_txt2img)();
submit();
} else if (type === 'img2img') {
window.submit_img2img();
submit_img2img();
}
window.randomId = window.origRandomId;
}
@ -591,6 +603,48 @@ function initPendingTab() {
}
});
const importButton = gradioApp().querySelector<HTMLButtonElement>(
'#agent_scheduler_action_import'
)!;
const importInput = gradioApp().querySelector<HTMLInputElement>('#agent_scheduler_import_file')!;
importButton.addEventListener('click', () => {
importInput.click();
});
importInput.addEventListener('change', e => {
if (e.target === null) return;
const files = importInput.files;
if (files == null || files.length === 0) return;
const file = files[0];
const reader = new FileReader();
reader.onload = () => {
const data = reader.result as string;
store
.importQueue(data)
.then(notify)
.then(() => {
importInput.value = '';
store.refresh();
});
};
reader.readAsText(file);
});
const exportButton = gradioApp().querySelector<HTMLButtonElement>(
'#agent_scheduler_action_export'
)!;
exportButton.addEventListener('click', () => {
store.exportQueue().then(data => {
const dataStr = 'data:text/json;charset=utf-8,' + encodeURIComponent(JSON.stringify(data));
const dlAnchorElem = document.createElement('a');
dlAnchorElem.setAttribute('href', dataStr);
dlAnchorElem.setAttribute('download', `agent-scheduler-${Date.now()}.json`);
dlAnchorElem.click();
});
});
// watch for queue status change
const updateUiState = (state: ReturnType<typeof store.getState>) => {
if (state.paused) {

View File

@ -11,6 +11,8 @@ type PendingTasksState = {
type PendingTasksActions = {
refresh: () => Promise<void>;
exportQueue: () => Promise<string>;
importQueue: (str: string) => Promise<ResponseStatus>;
pauseQueue: () => Promise<ResponseStatus>;
resumeQueue: () => Promise<ResponseStatus>;
clearQueue: () => Promise<ResponseStatus>;
@ -32,6 +34,26 @@ export const createPendingTasksStore = (initialState: PendingTasksState) => {
.then(response => response.json())
.then(setState);
},
exportQueue: async () => {
return fetch('/agent-scheduler/v1/export').then(response => response.json());
},
importQueue: async (str: string) => {
const bodyObj = {
content: str,
};
return fetch(`/agent-scheduler/v1/import`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(bodyObj),
})
.then(response => response.json())
.then(data => {
setTimeout(() => {
actions.refresh();
}, 3000);
return data;
});
},
pauseQueue: async () => {
return fetch('/agent-scheduler/v1/queue/pause', { method: 'POST' })
.then(response => response.json())
@ -97,8 +119,9 @@ export const createPendingTasksStore = (initialState: PendingTasksState) => {
}).then(response => response.json());
},
deleteTask: async (id: string) => {
return fetch(`/agent-scheduler/v1/task/${id}`, { method: 'DELETE' })
.then(response => response.json());
return fetch(`/agent-scheduler/v1/task/${id}`, { method: 'DELETE' }).then(response =>
response.json()
);
},
};