commit 21d150a0bd2e90940a142ac9bb57cdd0fcce4ac6 Author: AutoAgentX Date: Mon May 29 21:31:46 2023 +0700 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6ec740a --- /dev/null +++ b/.gitignore @@ -0,0 +1,40 @@ +__pycache__ +*.ckpt +*.safetensors +*.pth +/ESRGAN/* +/SwinIR/* +/repositories +/venv +/tmp +/model.ckpt +/models/**/* +/GFPGANv1.3.pth +/gfpgan/weights/*.pth +/ui-config.json +/outputs +/config.json +/log +/webui.settings.bat +/embeddings +/styles.csv +/params.txt +/styles.csv.bak +/webui-user.bat +/webui-user.sh +/interrogate +/user.css +/.idea +notification.mp3 +/SwinIR +/textual_inversion +.vscode +/extensions +/test/stdout.txt +/test/stderr.txt +/cache.json +*.sql +*.db +*.sqlite +*.sqlite3 +tailwind.* \ No newline at end of file diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000..e747179 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,9 @@ +{ + "singleQuote": true, + "jsxSingleQuote": false, + "arrowParens": "always", + "trailingComma": "all", + "semi": true, + "tabWidth": 2, + "printWidth": 100 +} diff --git a/README.md b/README.md new file mode 100644 index 0000000..e84841a --- /dev/null +++ b/README.md @@ -0,0 +1,124 @@ +# Agent Scheduler + +Introducing AgentScheduler, an A1111/Vladmandic Stable Diffusion Web UI extension to power up your image generation workflow! + +## Table of Content + +- [Compatibility](#compatibility) +- [Functionality](#functionality--as-of-current-version-) +- [Installation](#installation) + - [Using the built-in extension list](#using-the-built-in-extension-list) + - [Manual clone](#manual-clone) +- [Road Map](#road-map) +- [Contributing](#contributing) +- [License](#license) +- [Disclaimer](#disclaimer) + +--- + +## Compatibility + +This version of AgentScheduler is compatible with: + +- A1111: [commit 5ab7f213](https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/5ab7f213bec2f816f9c5644becb32eb72c8ffb89) +- Vladmandic: [commit 0a46f8ad](https://github.com/vladmandic/automatic/commit/0a46f8ada7751ee993c565198ec4c3327ff43c04) + +> Older versions may not working properly. + +## Functionality [as of current version] + +![Extension Walkthrough](https://user-images.githubusercontent.com/90659883/236373779-43bb9625-d5d9-4450-abc5-93e7af7251fd.jpg) + +1️⃣ Input your usual Prompts & Settings. **Enqueue** to send your current prompts, settings, controlnets to **AgentScheduler**. + +2️⃣ **AgentScheduler** Tab Navigation. + +3️⃣ **Pause** function to stop auto generation. **Refresh** to update. + +4️⃣ See all queued tasks, current image being generated and tasks' associated information. + +5️⃣ **Delete** tasks that you no longer want. Press ▶️ to prioritize selected task. + +6️⃣ **Show/Hide** Column of Information that you want. + +7️⃣ **Drag and Drop** to reorder columns. + +https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/50c74922-b85f-493c-9be8-b8e78f0cd061 + +## Road Map + +To list possible feature upgrades for this extension + +- Sync with GenAI Management Platform **ArtVenture** +- See history of completed jobs (Logs) + +## Installation + +### Using the built-in extension list + +1. Open the Extensions tab +2. Open the "Install From URL" sub-tab +3. Paste the repo url: https://github.com/ArtVentureX/sd-webui-agent-scheduler.git +4. Click "Install" + +![Install](/docs/images/install.png) + +### Manual clone + +```bash +git clone "https://github.com/ArtVentureX/sd-webui-agent-scheduler.git" extensions/agent-scheduler +``` + +(The second argument specifies the name of the folder, you can choose whatever you like). + +## Contributing + +We welcome contributions to the Agent Scheduler Extension project! Please feel free to submit issues, bug reports, and feature requests through the GitHub repository. + +Please give us a ⭐ if you find this extension helpful! + +## License + +This project is licensed under the Apache License 2.0. + +## Disclaimer + +The author(s) of this project are not responsible for any damages or legal issues arising from the use of this software. Users are solely responsible for ensuring that they comply with any applicable laws and regulations when using this software and assume all risks associated with its use. The author(s) are not responsible for any copyright violations or legal issues arising from the use of input or output content. + +--- + +## CRAFTED BY THE PEOPLE BUILDING **ARTVENTURE**, [**ATHERLABS**](https://atherlabs.com/) & [**SIPHER ODYSSEY**](http://playsipher.com/) + +### About ArtVenture (coming soon™️) + +ArtVenture offers powerful collaboration features for Generative AI Image workflows. It is designed to help designers and creative professionals of all levels collaborate more efficiently, unleash their creativity, and have full transparency and tracking over the creation process. + +![ArtVenture Teaser](https://user-images.githubusercontent.com/90659883/236376930-831ac345-e979-4ec5-bece-49e4bc497b79.png) + +![ArtVenture Teaser 2](https://user-images.githubusercontent.com/90659883/236376933-babe9d36-f42f-4c1c-b59a-08be572a1f4c.png) + +### Current Features + +ArtVenture offers the following key features: + +- Seamless Access: available on desktop and mobile +- Multiplayer & Collaborative UX. Strong collaboration features, such as real-time commenting and feedback, version control, and image/file/project sharing. +- Powerful semantic search capabilities. +- Building on shoulders of Giants, leveraging A1111/Vladnmandic and other pioneers, provide collaboration process from Idea (Sketch/Thoughts/Business Request) to Final Results(Images/Copywriting Post/TaskCompleted) in 1 platform +- Automation tooling for certain repeated tasks +- Secure and transparent, leveraging hasing and metadata to track the origin and history of models, loras, images to allow for tracability and ease of collaboration. +- Personalize UX for both beginner and experienced users to quickly remix existing SD images by editing prompts and negative prompts, selecting new training models and output quality as desired. + +### Target Audience + +ArtVenture is designed for the following target audiences: + +- Casual Creators +- Small Design Teams or Freelancers +- Design Agencies & Studios + +## 🎉 Stay Tuned for Updates + +We hope you find this extension to be useful. We will be adding new features and improvements over time as we enhance this extension to support our creative workflows. + +To stay up-to-date with the latest news and updates, be sure to follow us on GitHub and Twitter (coming soon™️). We welcome your feedback and suggestions, and are excited to hear how AgentScheduler can help you streamline your workflow and unleash your creativity! diff --git a/docs/images/enqueue.png b/docs/images/enqueue.png new file mode 100644 index 0000000..1064d11 Binary files /dev/null and b/docs/images/enqueue.png differ diff --git a/docs/images/install.png b/docs/images/install.png new file mode 100644 index 0000000..5cd91d4 Binary files /dev/null and b/docs/images/install.png differ diff --git a/docs/images/manage.png b/docs/images/manage.png new file mode 100644 index 0000000..b3969ad Binary files /dev/null and b/docs/images/manage.png differ diff --git a/docs/images/walkthrough.jpg b/docs/images/walkthrough.jpg new file mode 100644 index 0000000..8a99e79 Binary files /dev/null and b/docs/images/walkthrough.jpg differ diff --git a/install.py b/install.py new file mode 100644 index 0000000..9002916 --- /dev/null +++ b/install.py @@ -0,0 +1,4 @@ +import launch + +if not launch.is_installed("sqlalchemy"): + launch.run_pip("install sqlalchemy", "requirement for task-scheduler") diff --git a/javascript/task_scheduler.js b/javascript/task_scheduler.js new file mode 100644 index 0000000..6678be7 --- /dev/null +++ b/javascript/task_scheduler.js @@ -0,0 +1,492 @@ +(function agent_scheduler_init() { + const head = document.head || document.querySelector('head'); + + window.__loaded_scripts = []; + + const insertStyleTag = (href) => { + const style = document.createElement('link'); + style.rel = 'stylesheet'; + style.href = href; + head.appendChild(style); + }; + + const insertScriptTag = (src, onload) => { + const script = document.createElement('script'); + script.type = 'text/javascript'; + head.appendChild(script); + script.onload = onload; + script.src = src; + }; + + // load ag-grid + insertStyleTag('https://cdn.jsdelivr.net/npm/ag-grid-community@29.3.3/styles/ag-grid.css'); + insertStyleTag( + 'https://cdn.jsdelivr.net/npm/ag-grid-community@29.3.3/styles/ag-theme-alpine.css', + ); + insertScriptTag( + 'https://cdn.jsdelivr.net/npm/ag-grid-community@29.3.3/dist/ag-grid-community.min.noStyle.js', + () => { + window.__loaded_scripts.push('agGrid'); + }, + ); + + // load rxjs + insertScriptTag('https://cdnjs.cloudflare.com/ajax/libs/rxjs/7.8.1/rxjs.umd.min.js', () => { + window.__loaded_scripts.push('rxjs'); + + const observable = new rxjs.Observable((observer) => { + function submit_enqueue() { + var id = randomId(); + var res = create_submit_args(arguments); + res[0] = id; + + document.querySelector('#txt2img_enqueue').innerHTML = 'Queued'; + setTimeout(() => { + document.querySelector('#txt2img_enqueue').innerHTML = 'Enqueue'; + observer.next({ type: 'txt2img', id }); + }, 1000); + + return res; + } + + function submit_enqueue_img2img() { + var id = randomId(); + var res = create_submit_args(arguments); + res[0] = id; + res[1] = get_tab_index('mode_img2img'); + + document.querySelector('#img2img_enqueue').innerHTML = 'Queued'; + setTimeout(() => { + document.querySelector('#img2img_enqueue').innerHTML = 'Enqueue'; + observer.next({ type: 'txt2img', id }); + }, 1000); + + return res; + } + + submit_enqueue.subscribe = observable.subscribe.bind(observable); + submit_enqueue_img2img.subscribe = observable.subscribe.bind(observable); + + window.submit_enqueue = submit_enqueue; + window.submit_enqueue_img2img = submit_enqueue_img2img; + }); + + // register a dummy subscriber + observable.subscribe({ + next: console.log, + error: console.error, + complete: console.log, + }); + }); + + // notyf + insertStyleTag('https://cdn.jsdelivr.net/npm/notyf@3/notyf.min.css'); + insertScriptTag('https://cdn.jsdelivr.net/npm/notyf@3/notyf.min.js', () => { + window.__loaded_scripts.push('notyf'); + }); +})(); + +onUiLoaded(function initTaskScheduler() { + if (window.__loaded_scripts.length < 3) { + return setTimeout(() => { + initTaskScheduler() + }, 200); + } + + // detect black-orange theme + if (document.querySelector('link[href*="black-orange.css"]')) { + document.body.classList.add('black-orange'); + } + + // init notyf + const notyf = new Notyf(); + + // init state + const subject = new rxjs.Subject(); + + const store = { + subject, + init: (initialState) => { + return (store.subject = subject.pipe(rxjs.startWith(initialState))); + }, + subscribe: (callback) => { + return store.subject.pipe(rxjs.pairwise()).subscribe(callback); + }, + refresh: async () => { + return fetch('/agent-scheduler/v1/queue?limit=1000') + .then((response) => response.json()) + .then((data) => { + const pending_tasks = data.pending_tasks.map((item) => ({ + ...item, + params: JSON.parse(item.params), + status: item.id === data.current_task_id ? 'running' : 'pending', + })); + store.subject.next({ + ...data, + pending_tasks, + }); + }); + }, + pauseQueue: async () => { + return fetch('/agent-scheduler/v1/pause', { method: 'POST' }) + .then((response) => response.json()) + .then((data) => { + if (data.success) { + notyf.success(data.message); + } else { + notyf.error(data.message); + } + + return store.refresh(); + }) + }, + resumeQueue: async () => { + return fetch('/agent-scheduler/v1/resume', { method: 'POST' }) + .then((response) => response.json()) + .then((data) => { + if (data.success) { + notyf.success(data.message); + } else { + notyf.error(data.message); + } + + return store.refresh(); + }) + }, + runTask: async (id) => { + return fetch(`/agent-scheduler/v1/run/${id}`, { method: 'POST' }) + .then((response) => response.json()) + .then((data) => { + if (data.success) { + notyf.success(data.message); + } else { + notyf.error(data.message); + } + + return store.refresh(); + }); + }, + deleteTask: async (id) => { + return fetch(`/agent-scheduler/v1/delete/${id}`, { method: 'POST' }) + .then((response) => response.json()) + .then((data) => { + if (data.success) { + notyf.success(data.message); + } else { + notyf.error(data.message); + } + + return store.refresh(); + }); + }, + moveTask: async (id, overId) => { + return fetch(`/agent-scheduler/v1/move/${id}/${overId}`, { method: 'POST' }) + .then((response) => response.json()) + .then((data) => { + if (data.success) { + notyf.success(data.message); + } else { + notyf.error(data.message); + } + + return store.refresh(); + }); + }, + }; + + store.init({ + current_task_id: null, + total_pending_tasks: 0, + pending_tasks: [], + paused: false + }); + + // init actions + const refreshButton = gradioApp().querySelector('#agent_scheduler_action_refresh'); + refreshButton.addEventListener('click', () => { + store.refresh(); + }); + const pauseButton = gradioApp().querySelector('#agent_scheduler_action_pause'); + pauseButton.addEventListener('click', () => { + store.pauseQueue(); + }); + const resumeButton = gradioApp().querySelector('#agent_scheduler_action_resume'); + resumeButton.addEventListener('click', () => { + store.resumeQueue(); + }); + const searchContainer = gradioApp().querySelector('#agent_scheduler_action_search'); + searchContainer.className = "ts-search"; + searchContainer.innerHTML = ` +
+ + + + + +
+ + `; + + + // init grid + + const eGridDiv = gradioApp().querySelector('#agent_scheduler_pending_tasks_grid'); + if (document.querySelector('.dark')) { + eGridDiv.className = 'ag-theme-alpine-dark'; + } + + const deleteIcon = ` + + + + + + + + `; + const cancelIcon = ` + + + + + + `; + const pendingTasksGridOptions = { + domLayout: 'autoHeight', + // each entry here represents one column + columnDefs: [ + { + field: 'id', + headerName: 'Task Id', + minWidth: 240, + maxWidth: 240, + pinned: 'left', + rowDrag: true, + cellClass: ({ data }) => (data.status === 'running' ? 'task-running' : ''), + }, + { field: 'type', headerName: 'Type', minWidth: 80, maxWidth: 80 }, + { + field: 'priority', + headerName: 'Priority', + minWidth: 120, + maxWidth: 120, + filter: 'agNumberColumnFilter', + valueGetter: ({ data }) => data.priority - 1_681_000_000_000, + hide: true, + }, + { + headerName: 'Params', + children: [ + { + field: 'params.checkpoint', + headerName: 'Checkpoint', minWidth: 150, maxWidth: 300, + valueFormatter: ({ value }) => value || 'System' + }, + { + field: 'params.prompt', + headerName: 'Prompt', + width: 400, + minWidth: 200, + wrapText: true, + autoHeight: true, + cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' }, + }, + { + field: 'params.negative_prompt', + headerName: 'Negative Prompt', + width: 400, + minWidth: 200, + wrapText: true, + autoHeight: true, + cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' }, + }, + { + field: 'params.sampler_name', + headerName: 'Sampler', + width: 150, + minWidth: 150, + }, + { + field: 'params.steps', + headerName: 'Steps', + minWidth: 80, + maxWidth: 80, + filter: 'agNumberColumnFilter', + }, + { + field: 'params.cfg_scale', + headerName: 'CFG Scale', + width: 100, + minWidth: 100, + filter: 'agNumberColumnFilter', + }, + { + field: 'params.size', + headerName: 'Size', + minWidth: 110, + maxWidth: 110, + valueGetter: ({ data }) => `${data.params.width}x${data.params.height}`, + }, + { + field: 'params.batch', + headerName: 'Batching', + minWidth: 100, + maxWidth: 100, + valueGetter: ({ data }) => `${data.params.n_iter}x${data.params.batch_size}`, + }, + ], + }, + { field: 'created_at', headerName: 'Date', minWidth: 200 }, + { + headerName: 'Action', + pinned: 'right', + minWidth: 110, + maxWidth: 110, + resizable: false, + valueGetter: ({ data }) => data.id, + cellRenderer: ({ api, value, data }) => { + const html = ` +
+ + +
+ `; + + const placeholder = document.createElement('div'); + placeholder.innerHTML = html; + const node = placeholder.firstElementChild; + + const btnRun = node.querySelector('button.ts-btn-run'); + btnRun.addEventListener('click', () => { + console.log('run', value); + api.showLoadingOverlay(); + store.runTask(value).then(() => api.hideOverlay()); + }); + + const btnDelete = node.querySelector('button.ts-btn-delete'); + btnDelete.addEventListener('click', () => { + console.log('delete', value); + api.showLoadingOverlay(); + store.deleteTask(value).then(() => api.hideOverlay()); + }); + + return node; + }, + }, + ], + getRowId: ({ data }) => data.id, + + // default col def properties get applied to all columns + defaultColDef: { sortable: false, filter: true, resizable: true, suppressMenu: true }, + + rowSelection: 'single', // allow rows to be selected + animateRows: true, // have rows animate to new positions when sorted + pagination: true, + paginationPageSize: 10, + getContextMenuItems: () => [], + suppressCopyRowsToClipboard: true, + sideBar: { + toolPanels: [ + { + id: 'columns', + labelDefault: 'Columns', + labelKey: 'columns', + iconKey: 'columns', + toolPanel: 'agColumnsToolPanel', + toolPanelParams: { + suppressRowGroups: true, + suppressValues: true, + suppressPivots: true, + suppressPivotMode: true, + }, + }, + { + id: 'filters', + labelDefault: 'Filters', + labelKey: 'filters', + iconKey: 'filter', + toolPanel: 'agFiltersToolPanel', + }, + ], + position: 'right', + }, + + onGridReady: ({ api, columnApi }) => { + // init quick search input + const searchInput = searchContainer.querySelector('input#agent_scheduler_search_input'); + rxjs + .fromEvent(searchInput, 'input') + .pipe(rxjs.debounce(() => rxjs.interval(200))) + .subscribe((e) => { + api.setQuickFilter(e.target.value); + }); + + store.subscribe({ + next: ([_, newState]) => { + api.setRowData(newState.pending_tasks); + if (newState.current_task_id) { + const node = api.getRowNode(newState.current_task_id); + if (node) { + api.refreshCells({ rowNodes: [node], force: true }); + } + } + + columnApi.autoSizeColumns(); + }, + }); + }, + onRowDragEnd: ({ api, node, overNode }) => { + const id = node.data.id; + const overId = overNode.data.id; + + api.showLoadingOverlay(); + store.moveTask(id, overId).then(() => api.hideOverlay()); + }, + }; + new agGrid.Grid(eGridDiv, pendingTasksGridOptions); + + // watch for current task id change + const onTaskIdChange = (id) => { + if (id) { + requestProgress( + id, + gradioApp().getElementById('agent_scheduler_current_task_progress'), + gradioApp().getElementById('agent_scheduler_current_task_images'), + () => { + setTimeout(() => { + store.refresh(); + }, 1000); + }, + ); + } + }; + store.subscribe({ + next: ([prev, curr]) => { + if (prev.current_task_id !== curr.current_task_id) { + onTaskIdChange(curr.current_task_id); + } + if (curr.paused) { + pauseButton.classList.add('hide'); + resumeButton.classList.remove('hide'); + } else { + pauseButton.classList.remove('hide'); + resumeButton.classList.add('hide'); + } + }, + }); + + // watch for task submission + window.submit_enqueue.subscribe({ + next: () => store.refresh(), + }); + + // refresh the state + store.refresh(); +}); diff --git a/scripts/api.py b/scripts/api.py new file mode 100644 index 0000000..7f2c283 --- /dev/null +++ b/scripts/api.py @@ -0,0 +1,123 @@ +import json +import threading +from gradio.routes import App + +import modules.shared as shared +from modules import progress, script_callbacks, sd_samplers + +from scripts.db import TaskStatus, AppStateKey, task_manager, state_manager +from scripts.models import QueueStatusResponse +from scripts.task_runner import TaskRunner, get_instance +from scripts.helpers import log + +task_runner: TaskRunner = None + + +def regsiter_apis(app: App): + log.info("[AgentScheduler] Registering APIs") + + @app.get("/agent-scheduler/v1/queue", response_model=QueueStatusResponse) + def queue_status_api(limit: int = 20, offset: int = 0): + current_task_id = progress.current_task + total_pending_tasks = total_pending_tasks = task_manager.count_tasks( + status="pending" + ) + pending_tasks = task_manager.get_tasks( + status=TaskStatus.PENDING, limit=limit, offset=offset + ) + for task in pending_tasks: + task_args = TaskRunner.instance.parse_task_args( + task.params, task.script_params, deserialization=False + ) + named_args = task_args["named_args"] + named_args["checkpoint"] = task_args["checkpoint"] + sampler_index = named_args.get("sampler_index", None) + if sampler_index is not None: + named_args["sampler_name"] = sd_samplers.samplers[ + named_args["sampler_index"] + ].name + task.params = json.dumps(named_args) + + return QueueStatusResponse( + current_task_id=current_task_id, + pending_tasks=pending_tasks, + total_pending_tasks=total_pending_tasks, + paused=TaskRunner.instance.paused, + ) + + @app.post("/agent-scheduler/v1/run/{id}") + def run_task(id: str): + if progress.current_task is not None: + if progress.current_task == id: + return {"success": False, "message": f"Task {id} is already running"} + else: + # move task up in queue + task_manager.prioritize_task(id, 0) + return { + "success": True, + "message": f"Task {id} is scheduled to run next", + } + else: + # run task + task = task_manager.get_task(id) + current_thread = threading.Thread( + target=TaskRunner.instance.execute_task, + args=( + task, + lambda: None, + ), + ) + current_thread.daemon = True + current_thread.start() + + return {"success": True, "message": f"Task {id} is executing"} + + @app.post("/agent-scheduler/v1/delete/{id}") + def delete_task(id: str): + if progress.current_task == id: + shared.state.interrupt() + return {"success": True, "message": f"Task {id} is interrupted"} + + task_manager.delete_task(id) + return {"success": True, "message": f"Task {id} is deleted"} + + @app.post("/agent-scheduler/v1/move/{id}/{over_id}") + def move_task(id: str, over_id: str): + task = task_manager.get_task(id) + if task is None: + return {"success": False, "message": f"Task {id} not found"} + + if over_id == "top": + task_manager.prioritize_task(id, 0) + return {"success": True, "message": f"Task {id} is moved to top"} + elif over_id == "bottom": + task_manager.prioritize_task(id, -1) + return {"success": True, "message": f"Task {id} is moved to bottom"} + else: + over_task = task_manager.get_task(over_id) + if over_task is None: + return {"success": False, "message": f"Task {over_id} not found"} + + task_manager.prioritize_task(id, over_task.priority) + return {"success": True, "message": f"Task {id} is moved"} + + @app.post("/agent-scheduler/v1/pause") + def pause_queue(): + state_manager.set_value(AppStateKey.QueueState, "paused") + return {"success": True, "message": f"Queue is paused"} + + @app.post("/agent-scheduler/v1/resume") + def resume_queue(): + state_manager.set_value(AppStateKey.QueueState, "running") + TaskRunner.instance.execute_pending_tasks_threading() + return {"success": True, "message": f"Queue is resumed"} + + +def on_app_started(block, app: App): + global task_runner + task_runner = get_instance(block) + + regsiter_apis(app) + + +script_callbacks.on_app_started(on_app_started) diff --git a/scripts/db/__init__.py b/scripts/db/__init__.py new file mode 100644 index 0000000..85be6ff --- /dev/null +++ b/scripts/db/__init__.py @@ -0,0 +1,74 @@ +from pathlib import Path +from sqlalchemy import create_engine, inspect, text, String, Text + +from .base import Base, metadata, db_file +from .app_state import AppStateKey, AppState, AppStateManager +from .task import TaskStatus, Task, TaskManager + +version = "2" + +state_manager = AppStateManager() +task_manager = TaskManager() + + +def init(): + engine = create_engine(f"sqlite:///{db_file}") + + # check if database exists + if not Path(db_file).exists(): + # create database + metadata.create_all(engine) + + state_manager.set_value(AppStateKey.Version, version) + # check if app state exists + if state_manager.get_value(AppStateKey.QueueState) is None: + # create app state + state_manager.set_value(AppStateKey.QueueState, "running") + + inspector = inspect(engine) + with engine.connect() as conn: + # check if table task has column result and add it if not + task_columns = inspector.get_columns("task") + if not any(col["name"] == "result" for col in task_columns): + conn.execute(text("ALTER TABLE task ADD COLUMN result TEXT")) + + params_column = next(col for col in task_columns if col["name"] == "params") + if version > "1" and not isinstance(params_column["type"], Text): + transaction = conn.begin() + conn.execute( + text( + """ + CREATE TABLE task_temp ( + id VARCHAR(64) NOT NULL, + type VARCHAR(20) NOT NULL, + params TEXT NOT NULL, + script_params BLOB NOT NULL, + priority INTEGER NOT NULL, + status VARCHAR(20) NOT NULL, + created_at DATETIME DEFAULT (datetime('now')) NOT NULL, + updated_at DATETIME DEFAULT (datetime('now')) NOT NULL, + result TEXT, + PRIMARY KEY (id) + )""" + ) + ) + conn.execute(text("INSERT INTO task_temp SELECT * FROM task")) + conn.execute(text("DROP TABLE task")) + conn.execute(text("ALTER TABLE task_temp RENAME TO task")) + transaction.commit() + + conn.close() + + +__all__ = [ + "init", + "Base", + "metadata", + "db_file", + "AppStateKey", + "AppState", + "TaskStatus", + "Task", + "task_manager", + "state_manager", +] diff --git a/scripts/db/app_state.py b/scripts/db/app_state.py new file mode 100644 index 0000000..69da55c --- /dev/null +++ b/scripts/db/app_state.py @@ -0,0 +1,79 @@ +from enum import Enum + +from sqlalchemy import Column, String +from sqlalchemy.orm import Session + +from .base import BaseTableManager, Base + + +class AppStateKey(str, Enum): + Version = "version" + QueueState = "queue_state" # paused or running + + +class AppState: + def __init__(self, key: str, value: str): + self.key: str = key + self.value: str = value + + @staticmethod + def from_table(table: "AppStateTable"): + return AppState(table.key, table.value) + + def to_table(self): + return AppStateTable(key=self.key, value=self.value) + + +class AppStateTable(Base): + __tablename__ = "app_state" + + key = Column(String(64), primary_key=True) + value = Column(String(255), nullable=True) + + def __repr__(self): + return f"AppState(key={self.key!r}, value={self.value!r})" + + +class AppStateManager(BaseTableManager): + def get_value(self, key: str) -> str | None: + session = Session(self.engine) + try: + result = session.get(AppStateTable, key) + if result: + return result.value + else: + return None + except Exception as e: + print(f"Exception getting value from database: {e}") + raise e + finally: + session.close() + + def set_value(self, key: str, value: str): + session = Session(self.engine) + try: + result = session.get(AppStateTable, key) + if result: + result.value = value + else: + result = AppStateTable(key=key, value=value) + session.add(result) + session.commit() + except Exception as e: + print(f"Exception setting value in database: {e}") + raise e + finally: + session.close() + + def delete_value(self, key: str): + session = Session(self.engine) + try: + result = session.get(AppStateTable, key) + if result: + session.delete(result) + session.commit() + except Exception as e: + print(f"Exception deleting value from database: {e}") + raise e + finally: + session.close() diff --git a/scripts/db/base.py b/scripts/db/base.py new file mode 100644 index 0000000..285d4af --- /dev/null +++ b/scripts/db/base.py @@ -0,0 +1,30 @@ +import os + +from sqlalchemy import create_engine +from sqlalchemy.schema import MetaData +from sqlalchemy.orm import declarative_base + +from modules import scripts + + +Base = declarative_base() +metadata: MetaData = Base.metadata + +db_file = os.path.join(scripts.basedir(), "task_scheduler.sqlite3") + + +class BaseTableManager: + def __init__(self, engine = None): + # Get the db connection object, making the file and tables if needed. + try: + self.engine = engine if engine else create_engine(f"sqlite:///{db_file}") + except Exception as e: + print(f"Exception connecting to database: {e}") + raise e + + def get_engine(self): + return self.engine + + # Commit and close the database connection. + def quit(self): + self.engine.dispose() diff --git a/scripts/db/task.py b/scripts/db/task.py new file mode 100644 index 0000000..4156d62 --- /dev/null +++ b/scripts/db/task.py @@ -0,0 +1,289 @@ +from enum import Enum +from datetime import datetime +from typing import Optional + +from sqlalchemy import Column, String, Text, Integer, DateTime, LargeBinary, text, func +from sqlalchemy.orm import Session + +from .base import BaseTableManager, Base +from ..models import TaskModel + + +class TaskStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + DONE = "done" + FAILED = "failed" + + +class Task(TaskModel): + script_params: bytes = None + + def __init__( + self, + id: str = "", + type: str = "unknown", + params: str = "", + script_params: bytes = b"", + priority: int = None, + status: str = TaskStatus.PENDING.value, + result: str = None, + created_at: Optional[datetime] = None, + updated_at: Optional[datetime] = None, + ): + priority = priority if priority else int(datetime.utcnow().timestamp() * 1000) + + super().__init__( + id=id, + type=type, + params=params, + status=status, + priority=priority, + result=result, + created_at=created_at, + updated_at=created_at, + ) + self.id: str = id + self.type: str = type + self.params: str = params + self.script_params: bytes = script_params + self.priority: int = priority + self.status: str = status + self.result: str = result + self.created_at: datetime = created_at + self.updated_at: datetime = updated_at + + class Config(TaskModel.__config__): + exclude = ["script_params"] + + @staticmethod + def from_table(table: "TaskTable"): + return Task( + id=table.id, + type=table.type, + params=table.params, + script_params=table.script_params, + priority=table.priority, + status=table.status, + created_at=table.created_at, + updated_at=table.updated_at, + ) + + def to_table(self): + return TaskTable( + id=self.id, + type=self.type, + params=self.params, + script_params=self.script_params, + priority=self.priority, + status=self.status, + ) + + +class TaskTable(Base): + __tablename__ = "task" + + id = Column(String(64), primary_key=True) + type = Column(String(20), nullable=False) # txt2img or img2txt + params = Column(Text, nullable=False) # task args + script_params = Column(LargeBinary, nullable=False) # script args + priority = Column(Integer, nullable=False, default=datetime.now) + status = Column( + String(20), nullable=False, default="pending" + ) # pending, running, done, failed + result = Column(Text) # task result + created_at = Column( + DateTime, + nullable=False, + server_default=text("(datetime('now'))"), + ) + updated_at = Column( + DateTime, + nullable=False, + server_default=text("(datetime('now'))"), + onupdate=text("(datetime('now'))"), + ) + + def __repr__(self): + return f"Task(id={self.id!r}, type={self.type!r}, params={self.params!r}, status={self.status!r}, created_at={self.created_at!r})" + + +class TaskManager(BaseTableManager): + def get_task(self, id: str) -> TaskTable | None: + session = Session(self.engine) + try: + task = session.get(TaskTable, id) + + return Task.from_table(task) if task else None + except Exception as e: + print(f"Exception getting task from database: {e}") + raise e + finally: + session.close() + + def get_tasks( + self, + type: str = None, + status: str = None, + limit: int = None, + offset: int = None, + ) -> list[TaskTable]: + session = Session(self.engine) + try: + query = session.query(TaskTable) + if type: + query = query.filter(TaskTable.type == type) + + if status: + query = query.filter(TaskTable.status == status) + + query = query.order_by(TaskTable.priority.asc()).order_by( + TaskTable.created_at.asc() + ) + + if limit: + query = query.limit(limit) + + if offset: + query = query.offset(offset) + + all = query.all() + return [Task.from_table(t) for t in all] + except Exception as e: + print(f"Exception getting tasks from database: {e}") + raise e + finally: + session.close() + + def count_tasks( + self, + type: str = None, + status: str = None, + ) -> int: + session = Session(self.engine) + try: + query = session.query(TaskTable) + if type: + query = query.filter(TaskTable.type == type) + + if status: + query = query.filter(TaskTable.status == status) + + return query.count() + except Exception as e: + print(f"Exception counting tasks from database: {e}") + raise e + finally: + session.close() + + def add_task(self, task: Task) -> TaskTable: + session = Session(self.engine) + try: + result = task.to_table() + session.add(result) + session.commit() + return result + except Exception as e: + print(f"Exception adding task to database: {e}") + raise e + finally: + session.close() + + def update_task(self, id: str, status: str, result=None) -> TaskTable: + session = Session(self.engine) + try: + task = session.get(TaskTable, id) + if task: + task.status = status + task.result = result + session.commit() + return task + else: + raise Exception(f"Task with id {id} not found") + except Exception as e: + print(f"Exception updating task in database: {e}") + raise e + finally: + session.close() + + def prioritize_task(self, id: str, priority: int) -> TaskTable: + """0 means move to top, -1 means move to bottom, otherwise set the exact priority""" + + session = Session(self.engine) + try: + result = session.get(TaskTable, id) + if result: + if priority == 0: + result.priority = self.__get_min_priority() - 1 + elif priority == -1: + result.priority = int(datetime.utcnow().timestamp() * 1000) + else: + self.__move_tasks_down(priority) + session.execute(text("SELECT 1")) + result.priority = priority + + session.commit() + return result + else: + raise Exception(f"Task with id {id} not found") + except Exception as e: + print(f"Exception updating task in database: {e}") + raise e + finally: + session.close() + + def delete_task(self, id: str): + session = Session(self.engine) + try: + result = session.get(TaskTable, id) + if result: + session.delete(result) + session.commit() + else: + raise Exception(f"Task with id {id} not found") + except Exception as e: + print(f"Exception deleting task from database: {e}") + raise e + finally: + session.close() + + def delete_tasks_before(self, before: datetime, all: bool = False): + session = Session(self.engine) + try: + query = session.query(TaskTable).filter(TaskTable.created_at < before) + if not all: + query = query.filter( + TaskTable.status.in_([TaskStatus.DONE, TaskStatus.FAILED]) + ) + + query.delete() + session.commit() + except Exception as e: + print(f"Exception deleting tasks from database: {e}") + raise e + finally: + session.close() + + def __get_min_priority(self) -> int: + session = Session(self.engine) + try: + min_priority = session.query(func.min(TaskTable.priority)).scalar() + return min_priority if min_priority else 0 + except Exception as e: + print(f"Exception getting min priority from database: {e}") + raise e + finally: + session.close() + + def __move_tasks_down(self, priority: int): + session = Session(self.engine) + try: + session.query(TaskTable).filter(TaskTable.priority >= priority).update( + {TaskTable.priority: TaskTable.priority + 1} + ) + session.commit() + except Exception as e: + print(f"Exception moving tasks down in database: {e}") + raise e + finally: + session.close() diff --git a/scripts/helpers.py b/scripts/helpers.py new file mode 100644 index 0000000..be62351 --- /dev/null +++ b/scripts/helpers.py @@ -0,0 +1,80 @@ +import abc +import logging + +import gradio as gr +from gradio.blocks import Block, BlockContext + +if not logging.getLogger().hasHandlers(): + # Logging is not set up + logging.basicConfig(level=logging.INFO, format='%(message)s') + +log = logging.getLogger("sd") + + +class Singleton(abc.ABCMeta, type): + """ + Singleton metaclass for ensuring only one instance of a class. + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + """Call method for the singleton metaclass.""" + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +def compare_components_with_ids(components: list[Block], ids: list[int]): + return len(components) == len(ids) and all( + component._id == _id for component, _id in zip(components, ids) + ) + + +def get_component_by_elem_id(root: Block, elem_id: str): + if root.elem_id == elem_id: + return root + + elem = None + if isinstance(root, BlockContext): + for block in root.children: + elem = get_component_by_elem_id(block, elem_id) + if elem is not None: + break + + return elem + + +def get_components_by_ids(root: Block, ids: list[int]): + components: list[Block] = [] + + if root._id in ids: + components.append(root) + ids = [_id for _id in ids if _id != root._id] + + if isinstance(root, BlockContext): + for block in root.children: + components.extend(get_components_by_ids(block, ids)) + + return components + + +def detect_control_net(root: gr.Blocks, submit: gr.Button): + UiControlNetUnit = None + + dependencies: list[dict] = [ + x + for x in root.dependencies + if x["trigger"] == "click" and submit._id in x["targets"] + ] + for d in dependencies: + if len(d["outputs"]) == 1: + outputs = get_components_by_ids(root, d["outputs"]) + output = outputs[0] + if ( + isinstance(output, gr.State) + and type(output.value).__name__ == "UiControlNetUnit" + ): + UiControlNetUnit = type(output.value) + + return UiControlNetUnit diff --git a/scripts/models.py b/scripts/models.py new file mode 100644 index 0000000..7a1df37 --- /dev/null +++ b/scripts/models.py @@ -0,0 +1,41 @@ +from datetime import datetime, timezone + +from typing import Optional, List + +from pydantic import BaseModel, Field + + +def convert_datetime_to_iso_8601_with_z_suffix(dt: datetime) -> str: + return dt.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z' if dt else None + + +def transform_to_utc_datetime(dt: datetime) -> datetime: + return dt.astimezone(tz=timezone.utc) + + +class QueueStatusAPI(BaseModel): + limit: Optional[int] = Field(title="Limit", description="The maximum number of tasks to return", default=20) + offset: Optional[int] = Field(title="Offset", description="The offset of the tasks to return", default=0) + + +class TaskModel(BaseModel): + id: str = Field(title="Task Id") + type: str = Field(title="Task Type", description="Either txt2img or img2img") + status: str = Field(title="Task Status", description="Either pending, running, done or failed") + params: str = Field(title="Task Parameters", description="The parameters of the task in JSON format") + priority: int = Field(title="Task Priority") + result: Optional[str] = Field(title="Task Result", description="The result of the task in JSON format") + created_at: Optional[datetime] = Field(title="Task Created At", description="The time when the task was created", default=None) + updated_at: Optional[datetime] = Field(title="Task Updated At", description="The time when the task was updated", default=None) + + class Config: + json_encoders = { + # custom output conversion for datetime + datetime: convert_datetime_to_iso_8601_with_z_suffix + } + +class QueueStatusResponse(BaseModel): + current_task_id: Optional[str] = Field(title="Current Task Id", description="The on progress task id") + pending_tasks: List[TaskModel] = Field(title="Pending Tasks", description="The pending tasks in the queue") + total_pending_tasks: int = Field(title="Queue length", description="The total pending tasks in the queue") + paused: bool = Field(title="Paused", description="Whether the queue is paused") diff --git a/scripts/task_runner.py b/scripts/task_runner.py new file mode 100644 index 0000000..d0e68d6 --- /dev/null +++ b/scripts/task_runner.py @@ -0,0 +1,498 @@ +import io +import sys +import json +import time +import zlib +import base64 +import pickle +import inspect +import traceback +import threading +import numpy as np + +from datetime import datetime, timedelta +from enum import Enum +from PIL import Image, PngImagePlugin +from typing import Any, Callable, Union +from fastapi import FastAPI + +from modules import progress, shared, script_callbacks +from modules.call_queue import queue_lock, wrap_gradio_call +from modules.txt2img import txt2img +from modules.img2img import img2img +from modules.api.api import Api, encode_pil_to_base64 +from modules.api.models import ( + StableDiffusionTxt2ImgProcessingAPI, + StableDiffusionImg2ImgProcessingAPI, +) + +from scripts.db import TaskStatus, AppStateKey, Task, task_manager, state_manager +from scripts.helpers import log, detect_control_net, get_component_by_elem_id + +img2img_image_args_by_mode: dict[int, list[list[str]]] = { + 0: [["init_img"]], + 1: [["sketch"]], + 2: [["init_img_with_mask", "image"], ["init_img_with_mask", "mask"]], + 3: [["inpaint_color_sketch"], ["inpaint_color_sketch_orig"]], + 4: [["init_img_inpaint"], ["init_mask_inpaint"]], +} + + +class TaskRunner: + instance = None + + def __init__(self, UiControlNetUnit=None): + self.UiControlNetUnit = UiControlNetUnit + + self.__total_pending_tasks: int = 0 + self.__current_thread: threading.Thread = None + self.__api = Api(FastAPI(), queue_lock) + + self.__saved_images_path: list[str] = [] + script_callbacks.on_image_saved(self.__on_image_saved) + + self.script_callbacks = { + "task_registered": [], + "task_started": [], + "task_finished": [], + "task_cleared": [], + } + + # Mark this to True when reload UI + self.dispose = False + + if TaskRunner.instance is not None: + raise Exception("TaskRunner instance already exists") + TaskRunner.instance = self + + @property + def current_task_id(self) -> Union[str, None]: + return progress.current_task + + @property + def is_executing_task(self) -> bool: + return self.__current_thread and self.__current_thread.is_alive() + + @property + def paused(self) -> bool: + return state_manager.get_value(AppStateKey.QueueState) == "paused" + + def __serialize_image(self, image): + if isinstance(image, np.ndarray): + shape = image.shape + data = base64.b64encode(zlib.compress(image.tobytes())).decode() + return {"shape": shape, "data": data, "cls": "ndarray"} + elif isinstance(image, Image.Image): + size = image.size + mode = image.mode + data = base64.b64encode(zlib.compress(image.tobytes())).decode() + return { + "size": size, + "mode": mode, + "data": data, + "cls": "Image", + } + else: + return image + + def __deserialize_image(self, image_str): + if isinstance(image_str, dict) and image_str.get("cls", None): + cls = image_str["cls"] + data = zlib.decompress(base64.b64decode(image_str["data"])) + + if cls == "ndarray": + shape = tuple(image_str["shape"]) + image = np.frombuffer(data, dtype=np.uint8) + return image.reshape(shape) + else: + size = tuple(image_str["size"]) + mode = image_str["mode"] + return Image.frombytes(mode, size, data) + else: + return image_str + + def __serialize_img2img_images(self, args: dict, image_args: list): + for keys in image_args: + if len(keys) == 1: + image = args.get(keys[0], None) + args[keys[0]] = self.__serialize_image(image) + else: + value = args.get(keys[0], {}) + image = value.get(keys[1], None) + value[keys[1]] = self.__serialize_image(image) + args[keys[0]] = value + + def __deserialize_img2img_images(self, args: dict, image_args: list): + for keys in image_args: + if len(keys) == 1: + image = args.get(keys[0], None) + args[keys[0]] = self.__deserialize_image(image) + else: + value = args.get(keys[0], {}) + image = value.get(keys[1], None) + value[keys[1]] = self.__deserialize_image(image) + args[keys[0]] = value + + def __serialize_ui_task_args(self, is_img2img: bool, *args, checkpoint: str = None): + args_name = [] + if is_img2img: + args_name = inspect.getfullargspec(img2img).args + else: + args_name = inspect.getfullargspec(txt2img).args + + args = list(args) + named_args = dict(zip(args_name, args[0 : len(args_name)])) + script_args = args[len(args_name) :] + if checkpoint: + override_settings_texts = named_args.get("override_settings_texts", []) + override_settings_texts.append("Model hash: " + checkpoint) + named_args["override_settings_texts"] = override_settings_texts + + # loop through named_args and serialize images + if is_img2img: + for mode, image_args in img2img_image_args_by_mode.items(): + if mode == named_args["mode"]: + self.__serialize_img2img_images(named_args, image_args) + else: + # set None to unused image args to save space + for keys in image_args: + named_args[keys[0]] = None + + # loop through script_args and serialize controlnets + if self.UiControlNetUnit is not None: + for i, a in enumerate(script_args): + if isinstance(a, self.UiControlNetUnit): + script_args[i] = a.__dict__ + script_args[i]["is_cnet"] = True + for k, v in script_args[i].items(): + if k == "image" and v is not None: + script_args[i][k] = { + "image": self.__serialize_image(v["image"]), + "mask": self.__serialize_image(v["mask"]), + } + if isinstance(v, Enum): + script_args[i][k] = v.value + + return json.dumps( + { + "args": named_args, + "script_args": script_args, + "checkpoint": checkpoint, + "is_ui": True, + "is_img2img": is_img2img, + } + ) + + def __serialize_api_task_args( + self, is_img2img: bool, script_args: list = [], **named_args + ): + override_settings = named_args.get("override_settings", {}) + checkpoint = override_settings.get("sd_model_checkpoint", None) + + return json.dumps( + { + "args": named_args, + "script_args": script_args, + "checkpoint": checkpoint, + "is_ui": False, + "is_img2img": is_img2img, + } + ) + + def __deserialize_ui_task_args( + self, is_img2img: bool, named_args: dict, script_args: list + ): + # loop through image_args and deserialize images + if is_img2img: + for mode, image_args in img2img_image_args_by_mode.items(): + if mode == named_args["mode"]: + self.__deserialize_img2img_images(named_args, image_args) + + # loop through script_args and deserialize controlnets + if self.UiControlNetUnit is not None: + for i, arg in enumerate(script_args): + if isinstance(arg, dict) and arg.get("is_cnet", False): + arg.pop("is_cnet") + for k, v in arg.items(): + if k == "image" and v is not None: + arg[k] = { + "image": self.__deserialize_image(v["image"]), + "mask": self.__deserialize_image(v["mask"]), + } + + def parse_task_args( + self, params: str, script_params: bytes, deserialization: bool = True + ): + parsed: dict[str, Any] = json.loads(params) + + is_ui = parsed.get("is_ui", True) + is_img2img = parsed.get("is_img2img", None) + checkpoint = parsed.get("checkpoint", None) + named_args: dict[str, Any] = parsed["args"] + script_args: list[Any] = ( + parsed["script_args"] + if "script_args" in parsed + else pickle.loads(script_params) + ) + + if is_ui and deserialization: + self.__deserialize_ui_task_args(is_img2img, named_args, script_args) + + args = list(named_args.values()) + script_args + + return { + "args": args, + "named_args": named_args, + "script_args": script_args, + "checkpoint": checkpoint, + "is_ui": is_ui, + } + + def register_ui_task( + self, task_id: str, is_img2img: bool, *args, checkpoint: str = None + ): + progress.add_task_to_queue(task_id) + + params = self.__serialize_ui_task_args(is_img2img, *args, checkpoint=checkpoint) + + task_type = "img2img" if is_img2img else "txt2img" + task_manager.add_task(Task(id=task_id, type=task_type, params=params)) + + self.__run_callbacks( + "task_registered", task_id, is_img2img=is_img2img, is_ui=True, args=params + ) + self.__total_pending_tasks += 1 + + def register_api_task(self, task_id: str, is_img2img: bool, args: dict): + progress.add_task_to_queue(task_id) + + params = self.__serialize_api_task_args(is_img2img, **args) + + task_type = "img2img" if is_img2img else "txt2img" + task_manager.add_task(Task(id=task_id, type=task_type, params=params)) + + self.__run_callbacks( + "task_registered", task_id, is_img2img=is_img2img, is_ui=False, args=params + ) + self.__total_pending_tasks += 1 + + def execute_task(self, task: Task, get_next_task: Callable): + while True: + if self.dispose: + sys.exit(0) + + if progress.current_task is None: + task_id = task.id + is_img2img = task.type == "img2img" + log.info(f"[AgentScheduler] Executing task {task_id}") + + task_args = self.parse_task_args( + task.params, + task.script_params, + ) + task_meta = {"is_img2img": is_img2img, "is_ui": task_args["is_ui"]} + + self.__saved_images_path = [] + self.__run_callbacks("task_started", task_id, **task_meta) + res = self.__execute_task(task_id, is_img2img, task_args) + if not res or isinstance(res, Exception): + task_manager.update_task(id=task_id, status=TaskStatus.FAILED) + self.__run_callbacks( + "task_finished", task_id, status=TaskStatus.FAILED, **task_meta + ) + else: + res = json.loads(res) + log.info(f"\n[AgentScheduler] Task {task.id} done") + infotexts = [] + for line in res["infotexts"]: + infotexts.extend(line.split("\n")) + infotexts[0] = f"Prompt: {infotexts[0]}" + log.info("\n".join(["** " + text for text in infotexts])) + + result = { + "images": self.__saved_images_path.copy(), + "infotexts": infotexts, + } + task_manager.update_task( + id=task_id, + status=TaskStatus.DONE, + result=json.dumps(result), + ) + self.__run_callbacks( + "task_finished", + task_id, + status=TaskStatus.DONE, + result=result, + **task_meta, + ) + + self.__saved_images_path = [] + else: + time.sleep(2) + continue + + task = get_next_task() + if not task: + sys.exit(0) + + def execute_pending_tasks_threading(self): + if self.paused: + log.info("[AgentScheduler] Runner is paused") + return + + if self.is_executing_task: + log.info("[AgentScheduler] Runner already started") + return + + pending_task = self.__get_pending_task() + if pending_task: + # Start the infinite loop in a separate thread + self.__current_thread = threading.Thread( + target=self.execute_task, + args=( + pending_task, + self.__get_pending_task, + ), + ) + self.__current_thread.daemon = True + self.__current_thread.start() + + def get_task_info(self, task: Task) -> list[Any]: + task_args = self.parse_task_args( + task.params, + task.script_params, + ) + + return [ + task.id, + task.type, + json.dumps(task_args["named_args"]), + task.created_at.strftime("%Y-%m-%d %H:%M:%S"), + ] + + def __execute_task(self, task_id: str, is_img2img: bool, task_args: dict): + if task_args["is_ui"]: + return self.__execute_ui_task(task_id, is_img2img, *task_args["args"]) + else: + return self.__execute_api_task( + task_id, + is_img2img, + script_args=task_args["script_args"], + **task_args["named_args"], + ) + + def __execute_ui_task(self, task_id: str, is_img2img: bool, *args): + func = wrap_gradio_call(img2img if is_img2img else txt2img, add_stats=True) + + with queue_lock: + shared.state.begin() + progress.start_task(task_id) + + res = None + try: + result = func(*args) + res = result[1] + except Exception as e: + log.error(f"[AgentScheduler] Task {task_id} failed: {e}") + log.error(traceback.format_exc()) + res = e + finally: + progress.finish_task(task_id) + + shared.state.end() + + return res + + def __execute_api_task(self, task_id: str, is_img2img: bool, **kwargs): + progress.start_task(task_id) + + res = None + try: + result = ( + self.__api.img2imgapi(StableDiffusionImg2ImgProcessingAPI(**kwargs)) + if is_img2img + else self.__api.text2imgapi( + StableDiffusionTxt2ImgProcessingAPI(**kwargs) + ) + ) + res = result.info + except Exception as e: + log.error(f"[AgentScheduler] Task {task_id} failed: {e}") + log.error(traceback.format_exc()) + res = e + finally: + progress.finish_task(task_id) + + return res + + def __get_pending_task(self): + if self.dispose: + return None + + # delete task that are 7 days old + task_manager.delete_tasks_before(datetime.now() - timedelta(days=7)) + + self.__total_pending_tasks = task_manager.count_tasks(status="pending") + + # get more task if needed + if self.__total_pending_tasks > 0: + log.info( + f"[AgentScheduler] Total pending tasks: {self.__total_pending_tasks}" + ) + pending_tasks = task_manager.get_tasks(status="pending", limit=1) + if len(pending_tasks) > 0: + return pending_tasks[0] + else: + log.info("[AgentScheduler] Task queue is empty") + self.__run_callbacks("task_cleared") + + def __on_image_saved(self, data: script_callbacks.ImageSaveParams): + self.__saved_images_path.append(data.filename) + + def on_task_registered(self, callback: Callable): + """Callback when a task is registered + + Callback signature: callback(task_id: str, is_img2img: bool, is_ui: bool, args: dict) + """ + + self.script_callbacks["task_registered"].append(callback) + + def on_task_started(self, callback: Callable): + """Callback when a task is started + + Callback signature: callback(task_id: str, is_img2img: bool, is_ui: bool) + """ + + self.script_callbacks["task_started"].append(callback) + + def on_task_finished(self, callback: Callable): + """Callback when a task is finished + + Callback signature: callback(task_id: str, is_img2img: bool, is_ui: bool, status: TaskStatus, result: dict) + """ + + self.script_callbacks["task_finished"].append(callback) + + def on_task_cleared(self, callback: Callable): + self.script_callbacks["task_cleared"].append(callback) + + def __run_callbacks(self, name: str, *args, **kwargs): + for callback in self.script_callbacks[name]: + callback(*args, **kwargs) + + +def get_instance(block) -> TaskRunner: + if TaskRunner.instance is None: + txt2img_submit_button = get_component_by_elem_id(block, "txt2img_generate") + UiControlNetUnit = detect_control_net(block, txt2img_submit_button) + TaskRunner(UiControlNetUnit) + + def on_before_reload(): + # Tell old instance to stop + TaskRunner.instance.dispose = True + # force recreate the instance + TaskRunner.instance = None + + script_callbacks.on_before_reload(on_before_reload) + + return TaskRunner.instance diff --git a/scripts/task_scheduler.py b/scripts/task_scheduler.py new file mode 100644 index 0000000..f149f62 --- /dev/null +++ b/scripts/task_scheduler.py @@ -0,0 +1,205 @@ +import gradio as gr + +from modules import shared, script_callbacks, scripts +from modules.shared import list_checkpoint_tiles, refresh_checkpoints +from modules.ui import create_refresh_button + +from scripts.task_runner import TaskRunner, get_instance +from scripts.helpers import compare_components_with_ids, get_components_by_ids +from scripts.db import init, state_manager, AppStateKey + +task_runner: TaskRunner = None +initialized = False + +checkpoint_current = "Current Checkpoint" +checkpoint_runtime = "Runtime Checkpoint" + + +class Script(scripts.Script): + def __init__(self): + super().__init__() + script_callbacks.on_app_started(lambda block, _: self.on_app_started(block)) + self.checkpoint_override = checkpoint_current + + def title(self): + return "Agent Scheduler" + + def show(self, is_img2img): + return True + + def on_checkpoint_changed(self, checkpoint): + self.checkpoint_override = checkpoint + + def after_component(self, component, **_kwargs): + elem_id = "txt2img_generate" if self.is_txt2img else "img2img_generate" + + if component.elem_id == elem_id: + self.generate_button = component + + def on_app_started(self, block): + self.add_enqueue_button(block, self.generate_button) + + def add_enqueue_button(self, root: gr.Blocks, generate: gr.Button): + is_img2img = self.is_img2img + dependencies: list[dict] = [ + x + for x in root.dependencies + if x["trigger"] == "click" and generate._id in x["targets"] + ] + + dependency: dict = None + cnet_dependency: dict = None + UiControlNetUnit = None + for d in dependencies: + if len(d["outputs"]) == 1: + outputs = get_components_by_ids(root, d["outputs"]) + output = outputs[0] + if ( + isinstance(output, gr.State) + and type(output.value).__name__ == "UiControlNetUnit" + ): + cnet_dependency = d + UiControlNetUnit = type(output.value) + + elif len(d["outputs"]) == 4: + dependency = d + + fn_block = next( + fn + for fn in root.fns + if compare_components_with_ids(fn.inputs, dependency["inputs"]) + ) + fn = self.wrap_register_ui_task() + args = dict( + fn=fn, + _js="submit_enqueue_img2img" if is_img2img else "submit_enqueue", + inputs=fn_block.inputs, + outputs=fn_block.outputs, + show_progress=False, + ) + with root: + with generate.parent: + id_part = "img2img" if is_img2img else "txt2img" + with gr.Row(elem_id=f"{id_part}_enqueue_wrapper"): + checkpoint = gr.Dropdown( + choices=get_checkpoint_choices(), + value=checkpoint_current, + show_label=False, + interactive=True, + ) + create_refresh_button( + checkpoint, + refresh_checkpoints, + lambda: {"choices": get_checkpoint_choices()}, + f"refresh_{id_part}_checkpoint", + ) + submit = gr.Button( + "Enqueue", elem_id=f"{id_part}_enqueue", variant="primary" + ) + + checkpoint.change(fn=self.on_checkpoint_changed, inputs=[checkpoint]) + submit.click(**args) + + if cnet_dependency is not None: + cnet_fn_block = next( + fn + for fn in root.fns + if compare_components_with_ids(fn.inputs, cnet_dependency["inputs"]) + ) + with root: + submit.click( + fn=UiControlNetUnit, + inputs=cnet_fn_block.inputs, + outputs=cnet_fn_block.outputs, + queue=False, + ) + + def wrap_register_ui_task(self): + def f(*args, **kwargs): + if len(args) == 0 and len(kwargs) == 0: + raise Exception("Invalid call") + + if len(args) > 0 and type(args[0]) == str: + task_id = args[0] + else: + # not a task, exit + return (None, "", "

Invalid params

", "") + + checkpoint = None + if self.checkpoint_override == checkpoint_current: + checkpoint = shared.sd_model.sd_checkpoint_info.title + elif self.checkpoint_override != checkpoint_runtime: + checkpoint = self.checkpoint_override + + task_runner.register_ui_task( + task_id, self.is_img2img, *args, checkpoint=checkpoint + ) + task_runner.execute_pending_tasks_threading() + + return (None, "", "

Task queued

", "") + + return f + + +def get_checkpoint_choices(): + choices = [checkpoint_current, checkpoint_runtime] + choices.extend(list_checkpoint_tiles()) + return choices + + +def is_queue_paused(): + return state_manager.get_value(AppStateKey.QueueState) == "paused" + + +def on_ui_tab(**_kwargs): + global initialized + if not initialized: + initialized = True + init() + + with gr.Blocks(analytics_enabled=False) as scheduler_tab: + with gr.Row(elem_id="agent_scheduler_pending_tasks_wrapper"): + with gr.Column(scale=1): + with gr.Group(elem_id="agent_scheduler_actions"): + paused = is_queue_paused() + + pause = gr.Button( + "Pause", + elem_id="agent_scheduler_action_pause", + variant="stop", + visible=not paused, + ) + resume = gr.Button( + "Resume", + elem_id="agent_scheduler_action_resume", + variant="primary", + visible=paused, + ) + gr.Button( + "Refresh", + elem_id="agent_scheduler_action_refresh", + variant="secondary", + ) + gr.HTML('') + gr.HTML( + '
' + ) + with gr.Column(scale=1): + with gr.Group(elem_id="agent_scheduler_current_task_progress"): + gr.Gallery( + elem_id="agent_scheduler_current_task_images", + label="Output", + show_label=False, + ).style(grid=4) + + return [(scheduler_tab, "Agent Scheduler", "agent_scheduler")] + + +def on_app_started(block, _): + global task_runner + task_runner = get_instance(block) + task_runner.execute_pending_tasks_threading() + + +script_callbacks.on_ui_tabs(on_ui_tab) +script_callbacks.on_app_started(on_app_started) diff --git a/style.css b/style.css new file mode 100644 index 0000000..0ba0250 --- /dev/null +++ b/style.css @@ -0,0 +1,430 @@ +.ts-search { + position: relative; + margin-left: auto; + width: 100%; + max-width: 20rem; +} + +.ts-search-input { + display: block; + width: 100%; + border-radius: 0.375rem !important; + border-width: 1px; + --tw-border-opacity: 1; + border-color: rgb(209 213 219 / var(--tw-border-opacity)); + --tw-bg-opacity: 1; + background-color: rgb(249 250 251 / var(--tw-bg-opacity)); + padding: 0.5rem !important; + padding-left: 2.5rem !important; + font-size: 0.875rem; + line-height: 1.25rem; + --tw-text-opacity: 1; + color: rgb(17 24 39 / var(--tw-text-opacity)); +} + +.ts-search-input:focus { + --tw-border-opacity: 1; + border-color: rgb(59 130 246 / var(--tw-border-opacity)); + --tw-ring-opacity: 1; + --tw-ring-color: rgb(59 130 246 / var(--tw-ring-opacity)); +} + +:is(.dark .ts-search-input) { + --tw-border-opacity: 1 !important; + border-color: rgb(75 85 99 / var(--tw-border-opacity)) !important; + --tw-bg-opacity: 1 !important; + background-color: rgb(55 65 81 / var(--tw-bg-opacity)) !important; + --tw-text-opacity: 1 !important; + color: rgb(255 255 255 / var(--tw-text-opacity)) !important; +} + +:is(.dark .ts-search-input)::-moz-placeholder { + --tw-placeholder-opacity: 1 !important; + color: rgb(156 163 175 / var(--tw-placeholder-opacity)) !important; +} + +:is(.dark .ts-search-input)::placeholder { + --tw-placeholder-opacity: 1 !important; + color: rgb(156 163 175 / var(--tw-placeholder-opacity)) !important; +} + +:is(.dark .ts-search-input:focus) { + --tw-border-opacity: 1; + border-color: rgb(59 130 246 / var(--tw-border-opacity)); + --tw-ring-opacity: 1; + --tw-ring-color: rgb(59 130 246 / var(--tw-ring-opacity)); +} + +.ts-search-icon { + pointer-events: none; + position: absolute; + top: 0px; + bottom: 0px; + left: 0px; + display: flex; + align-items: center; + padding-left: 0.75rem; +} + +:is(.dark .ts-search-icon) { + --tw-text-opacity: 1; + color: rgb(255 255 255 / var(--tw-text-opacity)); +} + +.ts-btn-action { + margin: 0px !important; + display: inline-flex; + align-items: center; + border-width: 1px; + padding-left: 0.5rem !important; + padding-right: 0.5rem !important; + padding-top: 0.25rem !important; + padding-bottom: 0.25rem !important; + font-size: 0.875rem; + line-height: 1.25rem; + font-weight: 500; +} + +.ts-btn-action:focus { + z-index: 10; + --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) + var(--tw-ring-offset-color); + --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(2px + var(--tw-ring-offset-width)) + var(--tw-ring-color); + box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); +} + +.ts-btn-action:disabled { + cursor: not-allowed; + opacity: 0.5; +} + +.ts-btn-action:hover:disabled { + background-color: transparent !important; +} + +.ts-btn-run { + border-top-left-radius: 0.375rem; + border-bottom-left-radius: 0.375rem; + --tw-border-opacity: 1; + border-color: rgb(34 197 94 / var(--tw-border-opacity)); + --tw-text-opacity: 1 !important; + color: rgb(34 197 94 / var(--tw-text-opacity)) !important; +} + +.ts-btn-run:hover { + --tw-bg-opacity: 1; + background-color: rgb(22 163 74 / var(--tw-bg-opacity)); + --tw-text-opacity: 1 !important; + color: rgb(255 255 255 / var(--tw-text-opacity)) !important; +} + +.ts-btn-run:focus { + --tw-ring-opacity: 1; + --tw-ring-color: rgb(74 222 128 / var(--tw-ring-opacity)); +} + +.ts-btn-run:hover:disabled { + --tw-text-opacity: 1 !important; + color: rgb(34 197 94 / var(--tw-text-opacity)) !important; +} + +:is(.dark .ts-btn-run) { + --tw-border-opacity: 1; + border-color: rgb(34 197 94 / var(--tw-border-opacity)); +} + +:is(.dark .ts-btn-run:hover) { + --tw-bg-opacity: 1; + background-color: rgb(22 163 74 / var(--tw-bg-opacity)); +} + +:is(.dark .ts-btn-run:focus) { + --tw-ring-opacity: 1; + --tw-ring-color: rgb(20 83 45 / var(--tw-ring-opacity)); +} + +.ts-btn-delete { + border-top-right-radius: 0.375rem; + border-bottom-right-radius: 0.375rem; + --tw-border-opacity: 1; + border-color: rgb(220 38 38 / var(--tw-border-opacity)); + --tw-text-opacity: 1 !important; + color: rgb(239 68 68 / var(--tw-text-opacity)) !important; +} + +.ts-btn-delete:hover { + --tw-bg-opacity: 1; + background-color: rgb(220 38 38 / var(--tw-bg-opacity)); + --tw-text-opacity: 1 !important; + color: rgb(255 255 255 / var(--tw-text-opacity)) !important; +} + +.ts-btn-delete:focus { + --tw-ring-opacity: 1; + --tw-ring-color: rgb(252 165 165 / var(--tw-ring-opacity)); +} + +:is(.dark .ts-btn-delete) { + --tw-border-opacity: 1; + border-color: rgb(239 68 68 / var(--tw-border-opacity)); +} + +:is(.dark .ts-btn-delete:hover) { + --tw-bg-opacity: 1; + background-color: rgb(220 38 38 / var(--tw-bg-opacity)); +} + +:is(.dark .ts-btn-delete:focus) { + --tw-ring-opacity: 1; + --tw-ring-color: rgb(127 29 29 / var(--tw-ring-opacity)); +} + +.mt-1 { + margin-top: 0.25rem; +} + +.mt-1\.5 { + margin-top: 0.375rem; +} + +.inline-flex { + display: inline-flex; +} + +.grid { + display: grid; +} + +.rounded-md { + border-radius: 0.375rem; +} + +.shadow-sm { + --tw-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05); + --tw-shadow-colored: 0 1px 2px 0 var(--tw-shadow-color); + box-shadow: var(--tw-ring-offset-shadow, 0 0 #0000), var(--tw-ring-shadow, 0 0 #0000), + var(--tw-shadow); +} + +.filter { + filter: var(--tw-blur) var(--tw-brightness) var(--tw-contrast) var(--tw-grayscale) + var(--tw-hue-rotate) var(--tw-invert) var(--tw-saturate) var(--tw-sepia) var(--tw-drop-shadow); +} + +/****************************************************************/ + +#agent_scheduler_pending_tasks_wrapper { + gap: var(--layout-gap); + border: none; + box-shadow: none; + border-width: 0; +} + +@media (max-width: 1024px) { + #agent_scheduler_pending_tasks_wrapper { + flex-wrap: wrap; + } +} + +#agent_scheduler_pending_tasks_wrapper > div:last-child { + width: 100%; + max-width: 512px; +} + +#agent_scheduler_current_task_images { + width: 100%; + padding-top: 100%; + position: relative; +} + +#agent_scheduler_current_task_images > div { + position: absolute; + top: 0; + left: 0; + right: 0; + bottom: 0; + height: 100%; +} + +#agent_scheduler_pending_tasks_wrapper { + justify-content: flex-end; + gap: var(--layout-gap); + padding: 0 var(--layout-gap) var(--layout-gap) var(--layout-gap); +} + +#agent_scheduler_pending_tasks_wrapper > button { + flex: 0 0 auto; +} + +#agent_scheduler_actions { + display: flex; + gap: var(--layout-gap); +} + +#agent_scheduler_actions > button { + border-radius: var(--radius-lg) !important; +} + +@keyframes blink { + from, + to { + opacity: 0; + } + 50% { + opacity: 1; + } +} + +@-moz-keyframes blink { + from, + to { + opacity: 0; + } + 50% { + opacity: 1; + } +} + +@-webkit-keyframes blink { + from, + to { + opacity: 0; + } + 50% { + opacity: 1; + } +} + +@-ms-keyframes blink { + from, + to { + opacity: 0; + } + 50% { + opacity: 1; + } +} + +@-o-keyframes blink { + from, + to { + opacity: 0; + } + 50% { + opacity: 1; + } +} + +.ag-theme-alpine, +.ag-theme-alpine-dark { + --ag-row-height: 45px; + --ag-header-height: 45px; + /* --ag-grid-size: 6px; */ + --ag-cell-horizontal-padding: calc(var(--ag-grid-size) * 2); + + --body-text-color: 'inherit'; +} + +.task-running { + color: #52c41a !important; + + -webkit-animation: 1s blink ease infinite; + -moz-animation: 1s blink ease infinite; + -ms-animation: 1s blink ease infinite; + -o-animation: 1s blink ease infinite; + animation: 1s blink ease infinite; +} + +.generate-box { + gap: unset !important; +} + +#txt2img_generate, +#img2img_generate { + min-height: unset !important; +} + +.generate-box #txt2img_interrupt { + position: initial !important; + height: 42px; +} + +.generate-box #txt2img_interrupt, +.generate-box #img2img_interrupt, +.generate-box #txt2img_skip, +.generate-box #img2img_skip { + position: initial !important; + height: 42px; + margin-top: 0 !important; + display: none !important; + max-width: 50% !important; +} + +.black-orange .generate-box #txt2img_interrupt, +.black-orange .generate-box #img2img_interrupt, +.black-orange .generate-box #txt2img_skip, +.black-orange .generate-box #img2img_skip { + height: 36px !important; +} + +#txt2img_enqueue_wrapper, +#img2img_enqueue_wrapper { + flex-wrap: nowrap; + margin-top: calc(var(--layout-gap) / 2); + min-width: 100%; + gap: calc(var(--layout-gap) / 2); +} + +#txt2img_enqueue_wrapper > div:first-child, +#img2img_enqueue_wrapper > div:first-child { + flex: 1 1 auto; +} + +#txt2img_enqueue_wrapper > .gradio-button.primary, +#img2img_enqueue_wrapper > .gradio-button.primary { + flex: 0 0 auto; + min-width: 0; +} + +.black-orange #txt2img_enqueue_wrapper .gradio-button, +.black-orange #img2img_enqueue_wrapper .gradio-button, +.black-orange #txt2img_enqueue_wrapper .gradio-dropdown .wrap-inner, +.black-orange #img2img_enqueue_wrapper .gradio-dropdown .wrap-inner { + height: 36px; + padding: 6px 12px; +} + +.black-orange #txt2img_tools, +.black-orange #img2img_tools { + margin-top: 0; + margin-left: 0; + scale: 1; +} + +.black-orange #txt2img_tools .gradio-button, +.black-orange #img2img_tools .gradio-button { + min-width: 36px !important; + height: 36px; +} + +.black-orange #txt2img_actions_column, +.black-orange #img2img_actions_column { + min-width: min(320px, 100%) !important; +} + +.generate-box #txt2img_interrupt[style='display: block;'], +.generate-box #img2img_interrupt[style='display: block;'], +.generate-box #txt2img_skip[style='display: block;'], +.generate-box #img2img_skip[style='display: block;'] { + display: block !important; +} + +.generate-box #txt2img_skip[style='display: block;'] + #txt2img_generate, +.generate-box #img2img_skip[style='display: block;'] + #img2img_generate { + display: none !important; +} + +#agent_scheduler_current_task_progress .livePreview { + margin: 0; +}