first commit

pull/7/head
AutoAgentX 2023-05-29 21:31:46 +07:00
commit 21d150a0bd
19 changed files with 2518 additions and 0 deletions

40
.gitignore vendored Normal file
View File

@ -0,0 +1,40 @@
__pycache__
*.ckpt
*.safetensors
*.pth
/ESRGAN/*
/SwinIR/*
/repositories
/venv
/tmp
/model.ckpt
/models/**/*
/GFPGANv1.3.pth
/gfpgan/weights/*.pth
/ui-config.json
/outputs
/config.json
/log
/webui.settings.bat
/embeddings
/styles.csv
/params.txt
/styles.csv.bak
/webui-user.bat
/webui-user.sh
/interrogate
/user.css
/.idea
notification.mp3
/SwinIR
/textual_inversion
.vscode
/extensions
/test/stdout.txt
/test/stderr.txt
/cache.json
*.sql
*.db
*.sqlite
*.sqlite3
tailwind.*

9
.prettierrc Normal file
View File

@ -0,0 +1,9 @@
{
"singleQuote": true,
"jsxSingleQuote": false,
"arrowParens": "always",
"trailingComma": "all",
"semi": true,
"tabWidth": 2,
"printWidth": 100
}

124
README.md Normal file
View File

@ -0,0 +1,124 @@
# Agent Scheduler
Introducing AgentScheduler, an A1111/Vladmandic Stable Diffusion Web UI extension to power up your image generation workflow!
## Table of Content
- [Compatibility](#compatibility)
- [Functionality](#functionality--as-of-current-version-)
- [Installation](#installation)
- [Using the built-in extension list](#using-the-built-in-extension-list)
- [Manual clone](#manual-clone)
- [Road Map](#road-map)
- [Contributing](#contributing)
- [License](#license)
- [Disclaimer](#disclaimer)
---
## Compatibility
This version of AgentScheduler is compatible with:
- A1111: [commit 5ab7f213](https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/5ab7f213bec2f816f9c5644becb32eb72c8ffb89)
- Vladmandic: [commit 0a46f8ad](https://github.com/vladmandic/automatic/commit/0a46f8ada7751ee993c565198ec4c3327ff43c04)
> Older versions may not working properly.
## Functionality [as of current version]
![Extension Walkthrough](https://user-images.githubusercontent.com/90659883/236373779-43bb9625-d5d9-4450-abc5-93e7af7251fd.jpg)
1⃣ Input your usual Prompts & Settings. **Enqueue** to send your current prompts, settings, controlnets to **AgentScheduler**.
2**AgentScheduler** Tab Navigation.
3**Pause** function to stop auto generation. **Refresh** to update.
4⃣ See all queued tasks, current image being generated and tasks' associated information.
5**Delete** tasks that you no longer want. Press ▶️ to prioritize selected task.
6**Show/Hide** Column of Information that you want.
7**Drag and Drop** to reorder columns.
https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/50c74922-b85f-493c-9be8-b8e78f0cd061
## Road Map
To list possible feature upgrades for this extension
- Sync with GenAI Management Platform **ArtVenture**
- See history of completed jobs (Logs)
## Installation
### Using the built-in extension list
1. Open the Extensions tab
2. Open the "Install From URL" sub-tab
3. Paste the repo url: https://github.com/ArtVentureX/sd-webui-agent-scheduler.git
4. Click "Install"
![Install](/docs/images/install.png)
### Manual clone
```bash
git clone "https://github.com/ArtVentureX/sd-webui-agent-scheduler.git" extensions/agent-scheduler
```
(The second argument specifies the name of the folder, you can choose whatever you like).
## Contributing
We welcome contributions to the Agent Scheduler Extension project! Please feel free to submit issues, bug reports, and feature requests through the GitHub repository.
Please give us a ⭐ if you find this extension helpful!
## License
This project is licensed under the Apache License 2.0.
## Disclaimer
The author(s) of this project are not responsible for any damages or legal issues arising from the use of this software. Users are solely responsible for ensuring that they comply with any applicable laws and regulations when using this software and assume all risks associated with its use. The author(s) are not responsible for any copyright violations or legal issues arising from the use of input or output content.
---
## CRAFTED BY THE PEOPLE BUILDING **ARTVENTURE**, [**ATHERLABS**](https://atherlabs.com/) & [**SIPHER ODYSSEY**](http://playsipher.com/)
### About ArtVenture (coming soon™)
ArtVenture offers powerful collaboration features for Generative AI Image workflows. It is designed to help designers and creative professionals of all levels collaborate more efficiently, unleash their creativity, and have full transparency and tracking over the creation process.
![ArtVenture Teaser](https://user-images.githubusercontent.com/90659883/236376930-831ac345-e979-4ec5-bece-49e4bc497b79.png)
![ArtVenture Teaser 2](https://user-images.githubusercontent.com/90659883/236376933-babe9d36-f42f-4c1c-b59a-08be572a1f4c.png)
### Current Features
ArtVenture offers the following key features:
- Seamless Access: available on desktop and mobile
- Multiplayer & Collaborative UX. Strong collaboration features, such as real-time commenting and feedback, version control, and image/file/project sharing.
- Powerful semantic search capabilities.
- Building on shoulders of Giants, leveraging A1111/Vladnmandic and other pioneers, provide collaboration process from Idea (Sketch/Thoughts/Business Request) to Final Results(Images/Copywriting Post/TaskCompleted) in 1 platform
- Automation tooling for certain repeated tasks
- Secure and transparent, leveraging hasing and metadata to track the origin and history of models, loras, images to allow for tracability and ease of collaboration.
- Personalize UX for both beginner and experienced users to quickly remix existing SD images by editing prompts and negative prompts, selecting new training models and output quality as desired.
### Target Audience
ArtVenture is designed for the following target audiences:
- Casual Creators
- Small Design Teams or Freelancers
- Design Agencies & Studios
## 🎉 Stay Tuned for Updates
We hope you find this extension to be useful. We will be adding new features and improvements over time as we enhance this extension to support our creative workflows.
To stay up-to-date with the latest news and updates, be sure to follow us on GitHub and Twitter (coming soon™). We welcome your feedback and suggestions, and are excited to hear how AgentScheduler can help you streamline your workflow and unleash your creativity!

BIN
docs/images/enqueue.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 236 KiB

BIN
docs/images/install.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

BIN
docs/images/manage.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 407 KiB

BIN
docs/images/walkthrough.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 MiB

4
install.py Normal file
View File

@ -0,0 +1,4 @@
import launch
if not launch.is_installed("sqlalchemy"):
launch.run_pip("install sqlalchemy", "requirement for task-scheduler")

View File

@ -0,0 +1,492 @@
(function agent_scheduler_init() {
const head = document.head || document.querySelector('head');
window.__loaded_scripts = [];
const insertStyleTag = (href) => {
const style = document.createElement('link');
style.rel = 'stylesheet';
style.href = href;
head.appendChild(style);
};
const insertScriptTag = (src, onload) => {
const script = document.createElement('script');
script.type = 'text/javascript';
head.appendChild(script);
script.onload = onload;
script.src = src;
};
// load ag-grid
insertStyleTag('https://cdn.jsdelivr.net/npm/ag-grid-community@29.3.3/styles/ag-grid.css');
insertStyleTag(
'https://cdn.jsdelivr.net/npm/ag-grid-community@29.3.3/styles/ag-theme-alpine.css',
);
insertScriptTag(
'https://cdn.jsdelivr.net/npm/ag-grid-community@29.3.3/dist/ag-grid-community.min.noStyle.js',
() => {
window.__loaded_scripts.push('agGrid');
},
);
// load rxjs
insertScriptTag('https://cdnjs.cloudflare.com/ajax/libs/rxjs/7.8.1/rxjs.umd.min.js', () => {
window.__loaded_scripts.push('rxjs');
const observable = new rxjs.Observable((observer) => {
function submit_enqueue() {
var id = randomId();
var res = create_submit_args(arguments);
res[0] = id;
document.querySelector('#txt2img_enqueue').innerHTML = 'Queued';
setTimeout(() => {
document.querySelector('#txt2img_enqueue').innerHTML = 'Enqueue';
observer.next({ type: 'txt2img', id });
}, 1000);
return res;
}
function submit_enqueue_img2img() {
var id = randomId();
var res = create_submit_args(arguments);
res[0] = id;
res[1] = get_tab_index('mode_img2img');
document.querySelector('#img2img_enqueue').innerHTML = 'Queued';
setTimeout(() => {
document.querySelector('#img2img_enqueue').innerHTML = 'Enqueue';
observer.next({ type: 'txt2img', id });
}, 1000);
return res;
}
submit_enqueue.subscribe = observable.subscribe.bind(observable);
submit_enqueue_img2img.subscribe = observable.subscribe.bind(observable);
window.submit_enqueue = submit_enqueue;
window.submit_enqueue_img2img = submit_enqueue_img2img;
});
// register a dummy subscriber
observable.subscribe({
next: console.log,
error: console.error,
complete: console.log,
});
});
// notyf
insertStyleTag('https://cdn.jsdelivr.net/npm/notyf@3/notyf.min.css');
insertScriptTag('https://cdn.jsdelivr.net/npm/notyf@3/notyf.min.js', () => {
window.__loaded_scripts.push('notyf');
});
})();
onUiLoaded(function initTaskScheduler() {
if (window.__loaded_scripts.length < 3) {
return setTimeout(() => {
initTaskScheduler()
}, 200);
}
// detect black-orange theme
if (document.querySelector('link[href*="black-orange.css"]')) {
document.body.classList.add('black-orange');
}
// init notyf
const notyf = new Notyf();
// init state
const subject = new rxjs.Subject();
const store = {
subject,
init: (initialState) => {
return (store.subject = subject.pipe(rxjs.startWith(initialState)));
},
subscribe: (callback) => {
return store.subject.pipe(rxjs.pairwise()).subscribe(callback);
},
refresh: async () => {
return fetch('/agent-scheduler/v1/queue?limit=1000')
.then((response) => response.json())
.then((data) => {
const pending_tasks = data.pending_tasks.map((item) => ({
...item,
params: JSON.parse(item.params),
status: item.id === data.current_task_id ? 'running' : 'pending',
}));
store.subject.next({
...data,
pending_tasks,
});
});
},
pauseQueue: async () => {
return fetch('/agent-scheduler/v1/pause', { method: 'POST' })
.then((response) => response.json())
.then((data) => {
if (data.success) {
notyf.success(data.message);
} else {
notyf.error(data.message);
}
return store.refresh();
})
},
resumeQueue: async () => {
return fetch('/agent-scheduler/v1/resume', { method: 'POST' })
.then((response) => response.json())
.then((data) => {
if (data.success) {
notyf.success(data.message);
} else {
notyf.error(data.message);
}
return store.refresh();
})
},
runTask: async (id) => {
return fetch(`/agent-scheduler/v1/run/${id}`, { method: 'POST' })
.then((response) => response.json())
.then((data) => {
if (data.success) {
notyf.success(data.message);
} else {
notyf.error(data.message);
}
return store.refresh();
});
},
deleteTask: async (id) => {
return fetch(`/agent-scheduler/v1/delete/${id}`, { method: 'POST' })
.then((response) => response.json())
.then((data) => {
if (data.success) {
notyf.success(data.message);
} else {
notyf.error(data.message);
}
return store.refresh();
});
},
moveTask: async (id, overId) => {
return fetch(`/agent-scheduler/v1/move/${id}/${overId}`, { method: 'POST' })
.then((response) => response.json())
.then((data) => {
if (data.success) {
notyf.success(data.message);
} else {
notyf.error(data.message);
}
return store.refresh();
});
},
};
store.init({
current_task_id: null,
total_pending_tasks: 0,
pending_tasks: [],
paused: false
});
// init actions
const refreshButton = gradioApp().querySelector('#agent_scheduler_action_refresh');
refreshButton.addEventListener('click', () => {
store.refresh();
});
const pauseButton = gradioApp().querySelector('#agent_scheduler_action_pause');
pauseButton.addEventListener('click', () => {
store.pauseQueue();
});
const resumeButton = gradioApp().querySelector('#agent_scheduler_action_resume');
resumeButton.addEventListener('click', () => {
store.resumeQueue();
});
const searchContainer = gradioApp().querySelector('#agent_scheduler_action_search');
searchContainer.className = "ts-search";
searchContainer.innerHTML = `
<div class="ts-search-icon">
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<path d="M10 10m-7 0a7 7 0 1 0 14 0a7 7 0 1 0 -14 0"/>
<path d="M21 21l-6 -6"/>
</svg>
</div>
<input type="text" id="agent_scheduler_search_input" class="ts-search-input" placeholder="Search" required>
`;
// init grid
const eGridDiv = gradioApp().querySelector('#agent_scheduler_pending_tasks_grid');
if (document.querySelector('.dark')) {
eGridDiv.className = 'ag-theme-alpine-dark';
}
const deleteIcon = `
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<path d="M4 7l16 0"/>
<path d="M10 11l0 6"/>
<path d="M14 11l0 6"/>
<path d="M5 7l1 12a2 2 0 0 0 2 2h8a2 2 0 0 0 2 -2l1 -12"/>
<path d="M9 7v-3a1 1 0 0 1 1 -1h4a1 1 0 0 1 1 1v3"/>
</svg>`;
const cancelIcon = `
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<path d="M18 6l-12 12"/>
<path d="M6 6l12 12"/>
</svg>
`;
const pendingTasksGridOptions = {
domLayout: 'autoHeight',
// each entry here represents one column
columnDefs: [
{
field: 'id',
headerName: 'Task Id',
minWidth: 240,
maxWidth: 240,
pinned: 'left',
rowDrag: true,
cellClass: ({ data }) => (data.status === 'running' ? 'task-running' : ''),
},
{ field: 'type', headerName: 'Type', minWidth: 80, maxWidth: 80 },
{
field: 'priority',
headerName: 'Priority',
minWidth: 120,
maxWidth: 120,
filter: 'agNumberColumnFilter',
valueGetter: ({ data }) => data.priority - 1_681_000_000_000,
hide: true,
},
{
headerName: 'Params',
children: [
{
field: 'params.checkpoint',
headerName: 'Checkpoint', minWidth: 150, maxWidth: 300,
valueFormatter: ({ value }) => value || 'System'
},
{
field: 'params.prompt',
headerName: 'Prompt',
width: 400,
minWidth: 200,
wrapText: true,
autoHeight: true,
cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' },
},
{
field: 'params.negative_prompt',
headerName: 'Negative Prompt',
width: 400,
minWidth: 200,
wrapText: true,
autoHeight: true,
cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' },
},
{
field: 'params.sampler_name',
headerName: 'Sampler',
width: 150,
minWidth: 150,
},
{
field: 'params.steps',
headerName: 'Steps',
minWidth: 80,
maxWidth: 80,
filter: 'agNumberColumnFilter',
},
{
field: 'params.cfg_scale',
headerName: 'CFG Scale',
width: 100,
minWidth: 100,
filter: 'agNumberColumnFilter',
},
{
field: 'params.size',
headerName: 'Size',
minWidth: 110,
maxWidth: 110,
valueGetter: ({ data }) => `${data.params.width}x${data.params.height}`,
},
{
field: 'params.batch',
headerName: 'Batching',
minWidth: 100,
maxWidth: 100,
valueGetter: ({ data }) => `${data.params.n_iter}x${data.params.batch_size}`,
},
],
},
{ field: 'created_at', headerName: 'Date', minWidth: 200 },
{
headerName: 'Action',
pinned: 'right',
minWidth: 110,
maxWidth: 110,
resizable: false,
valueGetter: ({ data }) => data.id,
cellRenderer: ({ api, value, data }) => {
const html = `
<div class="inline-flex rounded-md shadow-sm mt-1.5" role="group">
<button type="button" ${data.status === 'running' ? 'disabled' : ''} class="ts-btn-action ts-btn-run">
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<path d="M7 4v16l13 -8z"/>
</svg>
</button>
<button type="button" class="ts-btn-action ts-btn-delete">
${data.status === 'pending' ? deleteIcon : cancelIcon}
</button>
</div>
`;
const placeholder = document.createElement('div');
placeholder.innerHTML = html;
const node = placeholder.firstElementChild;
const btnRun = node.querySelector('button.ts-btn-run');
btnRun.addEventListener('click', () => {
console.log('run', value);
api.showLoadingOverlay();
store.runTask(value).then(() => api.hideOverlay());
});
const btnDelete = node.querySelector('button.ts-btn-delete');
btnDelete.addEventListener('click', () => {
console.log('delete', value);
api.showLoadingOverlay();
store.deleteTask(value).then(() => api.hideOverlay());
});
return node;
},
},
],
getRowId: ({ data }) => data.id,
// default col def properties get applied to all columns
defaultColDef: { sortable: false, filter: true, resizable: true, suppressMenu: true },
rowSelection: 'single', // allow rows to be selected
animateRows: true, // have rows animate to new positions when sorted
pagination: true,
paginationPageSize: 10,
getContextMenuItems: () => [],
suppressCopyRowsToClipboard: true,
sideBar: {
toolPanels: [
{
id: 'columns',
labelDefault: 'Columns',
labelKey: 'columns',
iconKey: 'columns',
toolPanel: 'agColumnsToolPanel',
toolPanelParams: {
suppressRowGroups: true,
suppressValues: true,
suppressPivots: true,
suppressPivotMode: true,
},
},
{
id: 'filters',
labelDefault: 'Filters',
labelKey: 'filters',
iconKey: 'filter',
toolPanel: 'agFiltersToolPanel',
},
],
position: 'right',
},
onGridReady: ({ api, columnApi }) => {
// init quick search input
const searchInput = searchContainer.querySelector('input#agent_scheduler_search_input');
rxjs
.fromEvent(searchInput, 'input')
.pipe(rxjs.debounce(() => rxjs.interval(200)))
.subscribe((e) => {
api.setQuickFilter(e.target.value);
});
store.subscribe({
next: ([_, newState]) => {
api.setRowData(newState.pending_tasks);
if (newState.current_task_id) {
const node = api.getRowNode(newState.current_task_id);
if (node) {
api.refreshCells({ rowNodes: [node], force: true });
}
}
columnApi.autoSizeColumns();
},
});
},
onRowDragEnd: ({ api, node, overNode }) => {
const id = node.data.id;
const overId = overNode.data.id;
api.showLoadingOverlay();
store.moveTask(id, overId).then(() => api.hideOverlay());
},
};
new agGrid.Grid(eGridDiv, pendingTasksGridOptions);
// watch for current task id change
const onTaskIdChange = (id) => {
if (id) {
requestProgress(
id,
gradioApp().getElementById('agent_scheduler_current_task_progress'),
gradioApp().getElementById('agent_scheduler_current_task_images'),
() => {
setTimeout(() => {
store.refresh();
}, 1000);
},
);
}
};
store.subscribe({
next: ([prev, curr]) => {
if (prev.current_task_id !== curr.current_task_id) {
onTaskIdChange(curr.current_task_id);
}
if (curr.paused) {
pauseButton.classList.add('hide');
resumeButton.classList.remove('hide');
} else {
pauseButton.classList.remove('hide');
resumeButton.classList.add('hide');
}
},
});
// watch for task submission
window.submit_enqueue.subscribe({
next: () => store.refresh(),
});
// refresh the state
store.refresh();
});

123
scripts/api.py Normal file
View File

@ -0,0 +1,123 @@
import json
import threading
from gradio.routes import App
import modules.shared as shared
from modules import progress, script_callbacks, sd_samplers
from scripts.db import TaskStatus, AppStateKey, task_manager, state_manager
from scripts.models import QueueStatusResponse
from scripts.task_runner import TaskRunner, get_instance
from scripts.helpers import log
task_runner: TaskRunner = None
def regsiter_apis(app: App):
log.info("[AgentScheduler] Registering APIs")
@app.get("/agent-scheduler/v1/queue", response_model=QueueStatusResponse)
def queue_status_api(limit: int = 20, offset: int = 0):
current_task_id = progress.current_task
total_pending_tasks = total_pending_tasks = task_manager.count_tasks(
status="pending"
)
pending_tasks = task_manager.get_tasks(
status=TaskStatus.PENDING, limit=limit, offset=offset
)
for task in pending_tasks:
task_args = TaskRunner.instance.parse_task_args(
task.params, task.script_params, deserialization=False
)
named_args = task_args["named_args"]
named_args["checkpoint"] = task_args["checkpoint"]
sampler_index = named_args.get("sampler_index", None)
if sampler_index is not None:
named_args["sampler_name"] = sd_samplers.samplers[
named_args["sampler_index"]
].name
task.params = json.dumps(named_args)
return QueueStatusResponse(
current_task_id=current_task_id,
pending_tasks=pending_tasks,
total_pending_tasks=total_pending_tasks,
paused=TaskRunner.instance.paused,
)
@app.post("/agent-scheduler/v1/run/{id}")
def run_task(id: str):
if progress.current_task is not None:
if progress.current_task == id:
return {"success": False, "message": f"Task {id} is already running"}
else:
# move task up in queue
task_manager.prioritize_task(id, 0)
return {
"success": True,
"message": f"Task {id} is scheduled to run next",
}
else:
# run task
task = task_manager.get_task(id)
current_thread = threading.Thread(
target=TaskRunner.instance.execute_task,
args=(
task,
lambda: None,
),
)
current_thread.daemon = True
current_thread.start()
return {"success": True, "message": f"Task {id} is executing"}
@app.post("/agent-scheduler/v1/delete/{id}")
def delete_task(id: str):
if progress.current_task == id:
shared.state.interrupt()
return {"success": True, "message": f"Task {id} is interrupted"}
task_manager.delete_task(id)
return {"success": True, "message": f"Task {id} is deleted"}
@app.post("/agent-scheduler/v1/move/{id}/{over_id}")
def move_task(id: str, over_id: str):
task = task_manager.get_task(id)
if task is None:
return {"success": False, "message": f"Task {id} not found"}
if over_id == "top":
task_manager.prioritize_task(id, 0)
return {"success": True, "message": f"Task {id} is moved to top"}
elif over_id == "bottom":
task_manager.prioritize_task(id, -1)
return {"success": True, "message": f"Task {id} is moved to bottom"}
else:
over_task = task_manager.get_task(over_id)
if over_task is None:
return {"success": False, "message": f"Task {over_id} not found"}
task_manager.prioritize_task(id, over_task.priority)
return {"success": True, "message": f"Task {id} is moved"}
@app.post("/agent-scheduler/v1/pause")
def pause_queue():
state_manager.set_value(AppStateKey.QueueState, "paused")
return {"success": True, "message": f"Queue is paused"}
@app.post("/agent-scheduler/v1/resume")
def resume_queue():
state_manager.set_value(AppStateKey.QueueState, "running")
TaskRunner.instance.execute_pending_tasks_threading()
return {"success": True, "message": f"Queue is resumed"}
def on_app_started(block, app: App):
global task_runner
task_runner = get_instance(block)
regsiter_apis(app)
script_callbacks.on_app_started(on_app_started)

74
scripts/db/__init__.py Normal file
View File

@ -0,0 +1,74 @@
from pathlib import Path
from sqlalchemy import create_engine, inspect, text, String, Text
from .base import Base, metadata, db_file
from .app_state import AppStateKey, AppState, AppStateManager
from .task import TaskStatus, Task, TaskManager
version = "2"
state_manager = AppStateManager()
task_manager = TaskManager()
def init():
engine = create_engine(f"sqlite:///{db_file}")
# check if database exists
if not Path(db_file).exists():
# create database
metadata.create_all(engine)
state_manager.set_value(AppStateKey.Version, version)
# check if app state exists
if state_manager.get_value(AppStateKey.QueueState) is None:
# create app state
state_manager.set_value(AppStateKey.QueueState, "running")
inspector = inspect(engine)
with engine.connect() as conn:
# check if table task has column result and add it if not
task_columns = inspector.get_columns("task")
if not any(col["name"] == "result" for col in task_columns):
conn.execute(text("ALTER TABLE task ADD COLUMN result TEXT"))
params_column = next(col for col in task_columns if col["name"] == "params")
if version > "1" and not isinstance(params_column["type"], Text):
transaction = conn.begin()
conn.execute(
text(
"""
CREATE TABLE task_temp (
id VARCHAR(64) NOT NULL,
type VARCHAR(20) NOT NULL,
params TEXT NOT NULL,
script_params BLOB NOT NULL,
priority INTEGER NOT NULL,
status VARCHAR(20) NOT NULL,
created_at DATETIME DEFAULT (datetime('now')) NOT NULL,
updated_at DATETIME DEFAULT (datetime('now')) NOT NULL,
result TEXT,
PRIMARY KEY (id)
)"""
)
)
conn.execute(text("INSERT INTO task_temp SELECT * FROM task"))
conn.execute(text("DROP TABLE task"))
conn.execute(text("ALTER TABLE task_temp RENAME TO task"))
transaction.commit()
conn.close()
__all__ = [
"init",
"Base",
"metadata",
"db_file",
"AppStateKey",
"AppState",
"TaskStatus",
"Task",
"task_manager",
"state_manager",
]

79
scripts/db/app_state.py Normal file
View File

@ -0,0 +1,79 @@
from enum import Enum
from sqlalchemy import Column, String
from sqlalchemy.orm import Session
from .base import BaseTableManager, Base
class AppStateKey(str, Enum):
Version = "version"
QueueState = "queue_state" # paused or running
class AppState:
def __init__(self, key: str, value: str):
self.key: str = key
self.value: str = value
@staticmethod
def from_table(table: "AppStateTable"):
return AppState(table.key, table.value)
def to_table(self):
return AppStateTable(key=self.key, value=self.value)
class AppStateTable(Base):
__tablename__ = "app_state"
key = Column(String(64), primary_key=True)
value = Column(String(255), nullable=True)
def __repr__(self):
return f"AppState(key={self.key!r}, value={self.value!r})"
class AppStateManager(BaseTableManager):
def get_value(self, key: str) -> str | None:
session = Session(self.engine)
try:
result = session.get(AppStateTable, key)
if result:
return result.value
else:
return None
except Exception as e:
print(f"Exception getting value from database: {e}")
raise e
finally:
session.close()
def set_value(self, key: str, value: str):
session = Session(self.engine)
try:
result = session.get(AppStateTable, key)
if result:
result.value = value
else:
result = AppStateTable(key=key, value=value)
session.add(result)
session.commit()
except Exception as e:
print(f"Exception setting value in database: {e}")
raise e
finally:
session.close()
def delete_value(self, key: str):
session = Session(self.engine)
try:
result = session.get(AppStateTable, key)
if result:
session.delete(result)
session.commit()
except Exception as e:
print(f"Exception deleting value from database: {e}")
raise e
finally:
session.close()

30
scripts/db/base.py Normal file
View File

@ -0,0 +1,30 @@
import os
from sqlalchemy import create_engine
from sqlalchemy.schema import MetaData
from sqlalchemy.orm import declarative_base
from modules import scripts
Base = declarative_base()
metadata: MetaData = Base.metadata
db_file = os.path.join(scripts.basedir(), "task_scheduler.sqlite3")
class BaseTableManager:
def __init__(self, engine = None):
# Get the db connection object, making the file and tables if needed.
try:
self.engine = engine if engine else create_engine(f"sqlite:///{db_file}")
except Exception as e:
print(f"Exception connecting to database: {e}")
raise e
def get_engine(self):
return self.engine
# Commit and close the database connection.
def quit(self):
self.engine.dispose()

289
scripts/db/task.py Normal file
View File

@ -0,0 +1,289 @@
from enum import Enum
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, String, Text, Integer, DateTime, LargeBinary, text, func
from sqlalchemy.orm import Session
from .base import BaseTableManager, Base
from ..models import TaskModel
class TaskStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
DONE = "done"
FAILED = "failed"
class Task(TaskModel):
script_params: bytes = None
def __init__(
self,
id: str = "",
type: str = "unknown",
params: str = "",
script_params: bytes = b"",
priority: int = None,
status: str = TaskStatus.PENDING.value,
result: str = None,
created_at: Optional[datetime] = None,
updated_at: Optional[datetime] = None,
):
priority = priority if priority else int(datetime.utcnow().timestamp() * 1000)
super().__init__(
id=id,
type=type,
params=params,
status=status,
priority=priority,
result=result,
created_at=created_at,
updated_at=created_at,
)
self.id: str = id
self.type: str = type
self.params: str = params
self.script_params: bytes = script_params
self.priority: int = priority
self.status: str = status
self.result: str = result
self.created_at: datetime = created_at
self.updated_at: datetime = updated_at
class Config(TaskModel.__config__):
exclude = ["script_params"]
@staticmethod
def from_table(table: "TaskTable"):
return Task(
id=table.id,
type=table.type,
params=table.params,
script_params=table.script_params,
priority=table.priority,
status=table.status,
created_at=table.created_at,
updated_at=table.updated_at,
)
def to_table(self):
return TaskTable(
id=self.id,
type=self.type,
params=self.params,
script_params=self.script_params,
priority=self.priority,
status=self.status,
)
class TaskTable(Base):
__tablename__ = "task"
id = Column(String(64), primary_key=True)
type = Column(String(20), nullable=False) # txt2img or img2txt
params = Column(Text, nullable=False) # task args
script_params = Column(LargeBinary, nullable=False) # script args
priority = Column(Integer, nullable=False, default=datetime.now)
status = Column(
String(20), nullable=False, default="pending"
) # pending, running, done, failed
result = Column(Text) # task result
created_at = Column(
DateTime,
nullable=False,
server_default=text("(datetime('now'))"),
)
updated_at = Column(
DateTime,
nullable=False,
server_default=text("(datetime('now'))"),
onupdate=text("(datetime('now'))"),
)
def __repr__(self):
return f"Task(id={self.id!r}, type={self.type!r}, params={self.params!r}, status={self.status!r}, created_at={self.created_at!r})"
class TaskManager(BaseTableManager):
def get_task(self, id: str) -> TaskTable | None:
session = Session(self.engine)
try:
task = session.get(TaskTable, id)
return Task.from_table(task) if task else None
except Exception as e:
print(f"Exception getting task from database: {e}")
raise e
finally:
session.close()
def get_tasks(
self,
type: str = None,
status: str = None,
limit: int = None,
offset: int = None,
) -> list[TaskTable]:
session = Session(self.engine)
try:
query = session.query(TaskTable)
if type:
query = query.filter(TaskTable.type == type)
if status:
query = query.filter(TaskTable.status == status)
query = query.order_by(TaskTable.priority.asc()).order_by(
TaskTable.created_at.asc()
)
if limit:
query = query.limit(limit)
if offset:
query = query.offset(offset)
all = query.all()
return [Task.from_table(t) for t in all]
except Exception as e:
print(f"Exception getting tasks from database: {e}")
raise e
finally:
session.close()
def count_tasks(
self,
type: str = None,
status: str = None,
) -> int:
session = Session(self.engine)
try:
query = session.query(TaskTable)
if type:
query = query.filter(TaskTable.type == type)
if status:
query = query.filter(TaskTable.status == status)
return query.count()
except Exception as e:
print(f"Exception counting tasks from database: {e}")
raise e
finally:
session.close()
def add_task(self, task: Task) -> TaskTable:
session = Session(self.engine)
try:
result = task.to_table()
session.add(result)
session.commit()
return result
except Exception as e:
print(f"Exception adding task to database: {e}")
raise e
finally:
session.close()
def update_task(self, id: str, status: str, result=None) -> TaskTable:
session = Session(self.engine)
try:
task = session.get(TaskTable, id)
if task:
task.status = status
task.result = result
session.commit()
return task
else:
raise Exception(f"Task with id {id} not found")
except Exception as e:
print(f"Exception updating task in database: {e}")
raise e
finally:
session.close()
def prioritize_task(self, id: str, priority: int) -> TaskTable:
"""0 means move to top, -1 means move to bottom, otherwise set the exact priority"""
session = Session(self.engine)
try:
result = session.get(TaskTable, id)
if result:
if priority == 0:
result.priority = self.__get_min_priority() - 1
elif priority == -1:
result.priority = int(datetime.utcnow().timestamp() * 1000)
else:
self.__move_tasks_down(priority)
session.execute(text("SELECT 1"))
result.priority = priority
session.commit()
return result
else:
raise Exception(f"Task with id {id} not found")
except Exception as e:
print(f"Exception updating task in database: {e}")
raise e
finally:
session.close()
def delete_task(self, id: str):
session = Session(self.engine)
try:
result = session.get(TaskTable, id)
if result:
session.delete(result)
session.commit()
else:
raise Exception(f"Task with id {id} not found")
except Exception as e:
print(f"Exception deleting task from database: {e}")
raise e
finally:
session.close()
def delete_tasks_before(self, before: datetime, all: bool = False):
session = Session(self.engine)
try:
query = session.query(TaskTable).filter(TaskTable.created_at < before)
if not all:
query = query.filter(
TaskTable.status.in_([TaskStatus.DONE, TaskStatus.FAILED])
)
query.delete()
session.commit()
except Exception as e:
print(f"Exception deleting tasks from database: {e}")
raise e
finally:
session.close()
def __get_min_priority(self) -> int:
session = Session(self.engine)
try:
min_priority = session.query(func.min(TaskTable.priority)).scalar()
return min_priority if min_priority else 0
except Exception as e:
print(f"Exception getting min priority from database: {e}")
raise e
finally:
session.close()
def __move_tasks_down(self, priority: int):
session = Session(self.engine)
try:
session.query(TaskTable).filter(TaskTable.priority >= priority).update(
{TaskTable.priority: TaskTable.priority + 1}
)
session.commit()
except Exception as e:
print(f"Exception moving tasks down in database: {e}")
raise e
finally:
session.close()

80
scripts/helpers.py Normal file
View File

@ -0,0 +1,80 @@
import abc
import logging
import gradio as gr
from gradio.blocks import Block, BlockContext
if not logging.getLogger().hasHandlers():
# Logging is not set up
logging.basicConfig(level=logging.INFO, format='%(message)s')
log = logging.getLogger("sd")
class Singleton(abc.ABCMeta, type):
"""
Singleton metaclass for ensuring only one instance of a class.
"""
_instances = {}
def __call__(cls, *args, **kwargs):
"""Call method for the singleton metaclass."""
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
def compare_components_with_ids(components: list[Block], ids: list[int]):
return len(components) == len(ids) and all(
component._id == _id for component, _id in zip(components, ids)
)
def get_component_by_elem_id(root: Block, elem_id: str):
if root.elem_id == elem_id:
return root
elem = None
if isinstance(root, BlockContext):
for block in root.children:
elem = get_component_by_elem_id(block, elem_id)
if elem is not None:
break
return elem
def get_components_by_ids(root: Block, ids: list[int]):
components: list[Block] = []
if root._id in ids:
components.append(root)
ids = [_id for _id in ids if _id != root._id]
if isinstance(root, BlockContext):
for block in root.children:
components.extend(get_components_by_ids(block, ids))
return components
def detect_control_net(root: gr.Blocks, submit: gr.Button):
UiControlNetUnit = None
dependencies: list[dict] = [
x
for x in root.dependencies
if x["trigger"] == "click" and submit._id in x["targets"]
]
for d in dependencies:
if len(d["outputs"]) == 1:
outputs = get_components_by_ids(root, d["outputs"])
output = outputs[0]
if (
isinstance(output, gr.State)
and type(output.value).__name__ == "UiControlNetUnit"
):
UiControlNetUnit = type(output.value)
return UiControlNetUnit

41
scripts/models.py Normal file
View File

@ -0,0 +1,41 @@
from datetime import datetime, timezone
from typing import Optional, List
from pydantic import BaseModel, Field
def convert_datetime_to_iso_8601_with_z_suffix(dt: datetime) -> str:
return dt.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z' if dt else None
def transform_to_utc_datetime(dt: datetime) -> datetime:
return dt.astimezone(tz=timezone.utc)
class QueueStatusAPI(BaseModel):
limit: Optional[int] = Field(title="Limit", description="The maximum number of tasks to return", default=20)
offset: Optional[int] = Field(title="Offset", description="The offset of the tasks to return", default=0)
class TaskModel(BaseModel):
id: str = Field(title="Task Id")
type: str = Field(title="Task Type", description="Either txt2img or img2img")
status: str = Field(title="Task Status", description="Either pending, running, done or failed")
params: str = Field(title="Task Parameters", description="The parameters of the task in JSON format")
priority: int = Field(title="Task Priority")
result: Optional[str] = Field(title="Task Result", description="The result of the task in JSON format")
created_at: Optional[datetime] = Field(title="Task Created At", description="The time when the task was created", default=None)
updated_at: Optional[datetime] = Field(title="Task Updated At", description="The time when the task was updated", default=None)
class Config:
json_encoders = {
# custom output conversion for datetime
datetime: convert_datetime_to_iso_8601_with_z_suffix
}
class QueueStatusResponse(BaseModel):
current_task_id: Optional[str] = Field(title="Current Task Id", description="The on progress task id")
pending_tasks: List[TaskModel] = Field(title="Pending Tasks", description="The pending tasks in the queue")
total_pending_tasks: int = Field(title="Queue length", description="The total pending tasks in the queue")
paused: bool = Field(title="Paused", description="Whether the queue is paused")

498
scripts/task_runner.py Normal file
View File

@ -0,0 +1,498 @@
import io
import sys
import json
import time
import zlib
import base64
import pickle
import inspect
import traceback
import threading
import numpy as np
from datetime import datetime, timedelta
from enum import Enum
from PIL import Image, PngImagePlugin
from typing import Any, Callable, Union
from fastapi import FastAPI
from modules import progress, shared, script_callbacks
from modules.call_queue import queue_lock, wrap_gradio_call
from modules.txt2img import txt2img
from modules.img2img import img2img
from modules.api.api import Api, encode_pil_to_base64
from modules.api.models import (
StableDiffusionTxt2ImgProcessingAPI,
StableDiffusionImg2ImgProcessingAPI,
)
from scripts.db import TaskStatus, AppStateKey, Task, task_manager, state_manager
from scripts.helpers import log, detect_control_net, get_component_by_elem_id
img2img_image_args_by_mode: dict[int, list[list[str]]] = {
0: [["init_img"]],
1: [["sketch"]],
2: [["init_img_with_mask", "image"], ["init_img_with_mask", "mask"]],
3: [["inpaint_color_sketch"], ["inpaint_color_sketch_orig"]],
4: [["init_img_inpaint"], ["init_mask_inpaint"]],
}
class TaskRunner:
instance = None
def __init__(self, UiControlNetUnit=None):
self.UiControlNetUnit = UiControlNetUnit
self.__total_pending_tasks: int = 0
self.__current_thread: threading.Thread = None
self.__api = Api(FastAPI(), queue_lock)
self.__saved_images_path: list[str] = []
script_callbacks.on_image_saved(self.__on_image_saved)
self.script_callbacks = {
"task_registered": [],
"task_started": [],
"task_finished": [],
"task_cleared": [],
}
# Mark this to True when reload UI
self.dispose = False
if TaskRunner.instance is not None:
raise Exception("TaskRunner instance already exists")
TaskRunner.instance = self
@property
def current_task_id(self) -> Union[str, None]:
return progress.current_task
@property
def is_executing_task(self) -> bool:
return self.__current_thread and self.__current_thread.is_alive()
@property
def paused(self) -> bool:
return state_manager.get_value(AppStateKey.QueueState) == "paused"
def __serialize_image(self, image):
if isinstance(image, np.ndarray):
shape = image.shape
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
return {"shape": shape, "data": data, "cls": "ndarray"}
elif isinstance(image, Image.Image):
size = image.size
mode = image.mode
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
return {
"size": size,
"mode": mode,
"data": data,
"cls": "Image",
}
else:
return image
def __deserialize_image(self, image_str):
if isinstance(image_str, dict) and image_str.get("cls", None):
cls = image_str["cls"]
data = zlib.decompress(base64.b64decode(image_str["data"]))
if cls == "ndarray":
shape = tuple(image_str["shape"])
image = np.frombuffer(data, dtype=np.uint8)
return image.reshape(shape)
else:
size = tuple(image_str["size"])
mode = image_str["mode"]
return Image.frombytes(mode, size, data)
else:
return image_str
def __serialize_img2img_images(self, args: dict, image_args: list):
for keys in image_args:
if len(keys) == 1:
image = args.get(keys[0], None)
args[keys[0]] = self.__serialize_image(image)
else:
value = args.get(keys[0], {})
image = value.get(keys[1], None)
value[keys[1]] = self.__serialize_image(image)
args[keys[0]] = value
def __deserialize_img2img_images(self, args: dict, image_args: list):
for keys in image_args:
if len(keys) == 1:
image = args.get(keys[0], None)
args[keys[0]] = self.__deserialize_image(image)
else:
value = args.get(keys[0], {})
image = value.get(keys[1], None)
value[keys[1]] = self.__deserialize_image(image)
args[keys[0]] = value
def __serialize_ui_task_args(self, is_img2img: bool, *args, checkpoint: str = None):
args_name = []
if is_img2img:
args_name = inspect.getfullargspec(img2img).args
else:
args_name = inspect.getfullargspec(txt2img).args
args = list(args)
named_args = dict(zip(args_name, args[0 : len(args_name)]))
script_args = args[len(args_name) :]
if checkpoint:
override_settings_texts = named_args.get("override_settings_texts", [])
override_settings_texts.append("Model hash: " + checkpoint)
named_args["override_settings_texts"] = override_settings_texts
# loop through named_args and serialize images
if is_img2img:
for mode, image_args in img2img_image_args_by_mode.items():
if mode == named_args["mode"]:
self.__serialize_img2img_images(named_args, image_args)
else:
# set None to unused image args to save space
for keys in image_args:
named_args[keys[0]] = None
# loop through script_args and serialize controlnets
if self.UiControlNetUnit is not None:
for i, a in enumerate(script_args):
if isinstance(a, self.UiControlNetUnit):
script_args[i] = a.__dict__
script_args[i]["is_cnet"] = True
for k, v in script_args[i].items():
if k == "image" and v is not None:
script_args[i][k] = {
"image": self.__serialize_image(v["image"]),
"mask": self.__serialize_image(v["mask"]),
}
if isinstance(v, Enum):
script_args[i][k] = v.value
return json.dumps(
{
"args": named_args,
"script_args": script_args,
"checkpoint": checkpoint,
"is_ui": True,
"is_img2img": is_img2img,
}
)
def __serialize_api_task_args(
self, is_img2img: bool, script_args: list = [], **named_args
):
override_settings = named_args.get("override_settings", {})
checkpoint = override_settings.get("sd_model_checkpoint", None)
return json.dumps(
{
"args": named_args,
"script_args": script_args,
"checkpoint": checkpoint,
"is_ui": False,
"is_img2img": is_img2img,
}
)
def __deserialize_ui_task_args(
self, is_img2img: bool, named_args: dict, script_args: list
):
# loop through image_args and deserialize images
if is_img2img:
for mode, image_args in img2img_image_args_by_mode.items():
if mode == named_args["mode"]:
self.__deserialize_img2img_images(named_args, image_args)
# loop through script_args and deserialize controlnets
if self.UiControlNetUnit is not None:
for i, arg in enumerate(script_args):
if isinstance(arg, dict) and arg.get("is_cnet", False):
arg.pop("is_cnet")
for k, v in arg.items():
if k == "image" and v is not None:
arg[k] = {
"image": self.__deserialize_image(v["image"]),
"mask": self.__deserialize_image(v["mask"]),
}
def parse_task_args(
self, params: str, script_params: bytes, deserialization: bool = True
):
parsed: dict[str, Any] = json.loads(params)
is_ui = parsed.get("is_ui", True)
is_img2img = parsed.get("is_img2img", None)
checkpoint = parsed.get("checkpoint", None)
named_args: dict[str, Any] = parsed["args"]
script_args: list[Any] = (
parsed["script_args"]
if "script_args" in parsed
else pickle.loads(script_params)
)
if is_ui and deserialization:
self.__deserialize_ui_task_args(is_img2img, named_args, script_args)
args = list(named_args.values()) + script_args
return {
"args": args,
"named_args": named_args,
"script_args": script_args,
"checkpoint": checkpoint,
"is_ui": is_ui,
}
def register_ui_task(
self, task_id: str, is_img2img: bool, *args, checkpoint: str = None
):
progress.add_task_to_queue(task_id)
params = self.__serialize_ui_task_args(is_img2img, *args, checkpoint=checkpoint)
task_type = "img2img" if is_img2img else "txt2img"
task_manager.add_task(Task(id=task_id, type=task_type, params=params))
self.__run_callbacks(
"task_registered", task_id, is_img2img=is_img2img, is_ui=True, args=params
)
self.__total_pending_tasks += 1
def register_api_task(self, task_id: str, is_img2img: bool, args: dict):
progress.add_task_to_queue(task_id)
params = self.__serialize_api_task_args(is_img2img, **args)
task_type = "img2img" if is_img2img else "txt2img"
task_manager.add_task(Task(id=task_id, type=task_type, params=params))
self.__run_callbacks(
"task_registered", task_id, is_img2img=is_img2img, is_ui=False, args=params
)
self.__total_pending_tasks += 1
def execute_task(self, task: Task, get_next_task: Callable):
while True:
if self.dispose:
sys.exit(0)
if progress.current_task is None:
task_id = task.id
is_img2img = task.type == "img2img"
log.info(f"[AgentScheduler] Executing task {task_id}")
task_args = self.parse_task_args(
task.params,
task.script_params,
)
task_meta = {"is_img2img": is_img2img, "is_ui": task_args["is_ui"]}
self.__saved_images_path = []
self.__run_callbacks("task_started", task_id, **task_meta)
res = self.__execute_task(task_id, is_img2img, task_args)
if not res or isinstance(res, Exception):
task_manager.update_task(id=task_id, status=TaskStatus.FAILED)
self.__run_callbacks(
"task_finished", task_id, status=TaskStatus.FAILED, **task_meta
)
else:
res = json.loads(res)
log.info(f"\n[AgentScheduler] Task {task.id} done")
infotexts = []
for line in res["infotexts"]:
infotexts.extend(line.split("\n"))
infotexts[0] = f"Prompt: {infotexts[0]}"
log.info("\n".join(["** " + text for text in infotexts]))
result = {
"images": self.__saved_images_path.copy(),
"infotexts": infotexts,
}
task_manager.update_task(
id=task_id,
status=TaskStatus.DONE,
result=json.dumps(result),
)
self.__run_callbacks(
"task_finished",
task_id,
status=TaskStatus.DONE,
result=result,
**task_meta,
)
self.__saved_images_path = []
else:
time.sleep(2)
continue
task = get_next_task()
if not task:
sys.exit(0)
def execute_pending_tasks_threading(self):
if self.paused:
log.info("[AgentScheduler] Runner is paused")
return
if self.is_executing_task:
log.info("[AgentScheduler] Runner already started")
return
pending_task = self.__get_pending_task()
if pending_task:
# Start the infinite loop in a separate thread
self.__current_thread = threading.Thread(
target=self.execute_task,
args=(
pending_task,
self.__get_pending_task,
),
)
self.__current_thread.daemon = True
self.__current_thread.start()
def get_task_info(self, task: Task) -> list[Any]:
task_args = self.parse_task_args(
task.params,
task.script_params,
)
return [
task.id,
task.type,
json.dumps(task_args["named_args"]),
task.created_at.strftime("%Y-%m-%d %H:%M:%S"),
]
def __execute_task(self, task_id: str, is_img2img: bool, task_args: dict):
if task_args["is_ui"]:
return self.__execute_ui_task(task_id, is_img2img, *task_args["args"])
else:
return self.__execute_api_task(
task_id,
is_img2img,
script_args=task_args["script_args"],
**task_args["named_args"],
)
def __execute_ui_task(self, task_id: str, is_img2img: bool, *args):
func = wrap_gradio_call(img2img if is_img2img else txt2img, add_stats=True)
with queue_lock:
shared.state.begin()
progress.start_task(task_id)
res = None
try:
result = func(*args)
res = result[1]
except Exception as e:
log.error(f"[AgentScheduler] Task {task_id} failed: {e}")
log.error(traceback.format_exc())
res = e
finally:
progress.finish_task(task_id)
shared.state.end()
return res
def __execute_api_task(self, task_id: str, is_img2img: bool, **kwargs):
progress.start_task(task_id)
res = None
try:
result = (
self.__api.img2imgapi(StableDiffusionImg2ImgProcessingAPI(**kwargs))
if is_img2img
else self.__api.text2imgapi(
StableDiffusionTxt2ImgProcessingAPI(**kwargs)
)
)
res = result.info
except Exception as e:
log.error(f"[AgentScheduler] Task {task_id} failed: {e}")
log.error(traceback.format_exc())
res = e
finally:
progress.finish_task(task_id)
return res
def __get_pending_task(self):
if self.dispose:
return None
# delete task that are 7 days old
task_manager.delete_tasks_before(datetime.now() - timedelta(days=7))
self.__total_pending_tasks = task_manager.count_tasks(status="pending")
# get more task if needed
if self.__total_pending_tasks > 0:
log.info(
f"[AgentScheduler] Total pending tasks: {self.__total_pending_tasks}"
)
pending_tasks = task_manager.get_tasks(status="pending", limit=1)
if len(pending_tasks) > 0:
return pending_tasks[0]
else:
log.info("[AgentScheduler] Task queue is empty")
self.__run_callbacks("task_cleared")
def __on_image_saved(self, data: script_callbacks.ImageSaveParams):
self.__saved_images_path.append(data.filename)
def on_task_registered(self, callback: Callable):
"""Callback when a task is registered
Callback signature: callback(task_id: str, is_img2img: bool, is_ui: bool, args: dict)
"""
self.script_callbacks["task_registered"].append(callback)
def on_task_started(self, callback: Callable):
"""Callback when a task is started
Callback signature: callback(task_id: str, is_img2img: bool, is_ui: bool)
"""
self.script_callbacks["task_started"].append(callback)
def on_task_finished(self, callback: Callable):
"""Callback when a task is finished
Callback signature: callback(task_id: str, is_img2img: bool, is_ui: bool, status: TaskStatus, result: dict)
"""
self.script_callbacks["task_finished"].append(callback)
def on_task_cleared(self, callback: Callable):
self.script_callbacks["task_cleared"].append(callback)
def __run_callbacks(self, name: str, *args, **kwargs):
for callback in self.script_callbacks[name]:
callback(*args, **kwargs)
def get_instance(block) -> TaskRunner:
if TaskRunner.instance is None:
txt2img_submit_button = get_component_by_elem_id(block, "txt2img_generate")
UiControlNetUnit = detect_control_net(block, txt2img_submit_button)
TaskRunner(UiControlNetUnit)
def on_before_reload():
# Tell old instance to stop
TaskRunner.instance.dispose = True
# force recreate the instance
TaskRunner.instance = None
script_callbacks.on_before_reload(on_before_reload)
return TaskRunner.instance

205
scripts/task_scheduler.py Normal file
View File

@ -0,0 +1,205 @@
import gradio as gr
from modules import shared, script_callbacks, scripts
from modules.shared import list_checkpoint_tiles, refresh_checkpoints
from modules.ui import create_refresh_button
from scripts.task_runner import TaskRunner, get_instance
from scripts.helpers import compare_components_with_ids, get_components_by_ids
from scripts.db import init, state_manager, AppStateKey
task_runner: TaskRunner = None
initialized = False
checkpoint_current = "Current Checkpoint"
checkpoint_runtime = "Runtime Checkpoint"
class Script(scripts.Script):
def __init__(self):
super().__init__()
script_callbacks.on_app_started(lambda block, _: self.on_app_started(block))
self.checkpoint_override = checkpoint_current
def title(self):
return "Agent Scheduler"
def show(self, is_img2img):
return True
def on_checkpoint_changed(self, checkpoint):
self.checkpoint_override = checkpoint
def after_component(self, component, **_kwargs):
elem_id = "txt2img_generate" if self.is_txt2img else "img2img_generate"
if component.elem_id == elem_id:
self.generate_button = component
def on_app_started(self, block):
self.add_enqueue_button(block, self.generate_button)
def add_enqueue_button(self, root: gr.Blocks, generate: gr.Button):
is_img2img = self.is_img2img
dependencies: list[dict] = [
x
for x in root.dependencies
if x["trigger"] == "click" and generate._id in x["targets"]
]
dependency: dict = None
cnet_dependency: dict = None
UiControlNetUnit = None
for d in dependencies:
if len(d["outputs"]) == 1:
outputs = get_components_by_ids(root, d["outputs"])
output = outputs[0]
if (
isinstance(output, gr.State)
and type(output.value).__name__ == "UiControlNetUnit"
):
cnet_dependency = d
UiControlNetUnit = type(output.value)
elif len(d["outputs"]) == 4:
dependency = d
fn_block = next(
fn
for fn in root.fns
if compare_components_with_ids(fn.inputs, dependency["inputs"])
)
fn = self.wrap_register_ui_task()
args = dict(
fn=fn,
_js="submit_enqueue_img2img" if is_img2img else "submit_enqueue",
inputs=fn_block.inputs,
outputs=fn_block.outputs,
show_progress=False,
)
with root:
with generate.parent:
id_part = "img2img" if is_img2img else "txt2img"
with gr.Row(elem_id=f"{id_part}_enqueue_wrapper"):
checkpoint = gr.Dropdown(
choices=get_checkpoint_choices(),
value=checkpoint_current,
show_label=False,
interactive=True,
)
create_refresh_button(
checkpoint,
refresh_checkpoints,
lambda: {"choices": get_checkpoint_choices()},
f"refresh_{id_part}_checkpoint",
)
submit = gr.Button(
"Enqueue", elem_id=f"{id_part}_enqueue", variant="primary"
)
checkpoint.change(fn=self.on_checkpoint_changed, inputs=[checkpoint])
submit.click(**args)
if cnet_dependency is not None:
cnet_fn_block = next(
fn
for fn in root.fns
if compare_components_with_ids(fn.inputs, cnet_dependency["inputs"])
)
with root:
submit.click(
fn=UiControlNetUnit,
inputs=cnet_fn_block.inputs,
outputs=cnet_fn_block.outputs,
queue=False,
)
def wrap_register_ui_task(self):
def f(*args, **kwargs):
if len(args) == 0 and len(kwargs) == 0:
raise Exception("Invalid call")
if len(args) > 0 and type(args[0]) == str:
task_id = args[0]
else:
# not a task, exit
return (None, "", "<p>Invalid params</p>", "")
checkpoint = None
if self.checkpoint_override == checkpoint_current:
checkpoint = shared.sd_model.sd_checkpoint_info.title
elif self.checkpoint_override != checkpoint_runtime:
checkpoint = self.checkpoint_override
task_runner.register_ui_task(
task_id, self.is_img2img, *args, checkpoint=checkpoint
)
task_runner.execute_pending_tasks_threading()
return (None, "", "<p>Task queued</p>", "")
return f
def get_checkpoint_choices():
choices = [checkpoint_current, checkpoint_runtime]
choices.extend(list_checkpoint_tiles())
return choices
def is_queue_paused():
return state_manager.get_value(AppStateKey.QueueState) == "paused"
def on_ui_tab(**_kwargs):
global initialized
if not initialized:
initialized = True
init()
with gr.Blocks(analytics_enabled=False) as scheduler_tab:
with gr.Row(elem_id="agent_scheduler_pending_tasks_wrapper"):
with gr.Column(scale=1):
with gr.Group(elem_id="agent_scheduler_actions"):
paused = is_queue_paused()
pause = gr.Button(
"Pause",
elem_id="agent_scheduler_action_pause",
variant="stop",
visible=not paused,
)
resume = gr.Button(
"Resume",
elem_id="agent_scheduler_action_resume",
variant="primary",
visible=paused,
)
gr.Button(
"Refresh",
elem_id="agent_scheduler_action_refresh",
variant="secondary",
)
gr.HTML('<div id="agent_scheduler_action_search"></div>')
gr.HTML(
'<div id="agent_scheduler_pending_tasks_grid" class="ag-theme-alpine"></div>'
)
with gr.Column(scale=1):
with gr.Group(elem_id="agent_scheduler_current_task_progress"):
gr.Gallery(
elem_id="agent_scheduler_current_task_images",
label="Output",
show_label=False,
).style(grid=4)
return [(scheduler_tab, "Agent Scheduler", "agent_scheduler")]
def on_app_started(block, _):
global task_runner
task_runner = get_instance(block)
task_runner.execute_pending_tasks_threading()
script_callbacks.on_ui_tabs(on_ui_tab)
script_callbacks.on_app_started(on_app_started)

430
style.css Normal file
View File

@ -0,0 +1,430 @@
.ts-search {
position: relative;
margin-left: auto;
width: 100%;
max-width: 20rem;
}
.ts-search-input {
display: block;
width: 100%;
border-radius: 0.375rem !important;
border-width: 1px;
--tw-border-opacity: 1;
border-color: rgb(209 213 219 / var(--tw-border-opacity));
--tw-bg-opacity: 1;
background-color: rgb(249 250 251 / var(--tw-bg-opacity));
padding: 0.5rem !important;
padding-left: 2.5rem !important;
font-size: 0.875rem;
line-height: 1.25rem;
--tw-text-opacity: 1;
color: rgb(17 24 39 / var(--tw-text-opacity));
}
.ts-search-input:focus {
--tw-border-opacity: 1;
border-color: rgb(59 130 246 / var(--tw-border-opacity));
--tw-ring-opacity: 1;
--tw-ring-color: rgb(59 130 246 / var(--tw-ring-opacity));
}
:is(.dark .ts-search-input) {
--tw-border-opacity: 1 !important;
border-color: rgb(75 85 99 / var(--tw-border-opacity)) !important;
--tw-bg-opacity: 1 !important;
background-color: rgb(55 65 81 / var(--tw-bg-opacity)) !important;
--tw-text-opacity: 1 !important;
color: rgb(255 255 255 / var(--tw-text-opacity)) !important;
}
:is(.dark .ts-search-input)::-moz-placeholder {
--tw-placeholder-opacity: 1 !important;
color: rgb(156 163 175 / var(--tw-placeholder-opacity)) !important;
}
:is(.dark .ts-search-input)::placeholder {
--tw-placeholder-opacity: 1 !important;
color: rgb(156 163 175 / var(--tw-placeholder-opacity)) !important;
}
:is(.dark .ts-search-input:focus) {
--tw-border-opacity: 1;
border-color: rgb(59 130 246 / var(--tw-border-opacity));
--tw-ring-opacity: 1;
--tw-ring-color: rgb(59 130 246 / var(--tw-ring-opacity));
}
.ts-search-icon {
pointer-events: none;
position: absolute;
top: 0px;
bottom: 0px;
left: 0px;
display: flex;
align-items: center;
padding-left: 0.75rem;
}
:is(.dark .ts-search-icon) {
--tw-text-opacity: 1;
color: rgb(255 255 255 / var(--tw-text-opacity));
}
.ts-btn-action {
margin: 0px !important;
display: inline-flex;
align-items: center;
border-width: 1px;
padding-left: 0.5rem !important;
padding-right: 0.5rem !important;
padding-top: 0.25rem !important;
padding-bottom: 0.25rem !important;
font-size: 0.875rem;
line-height: 1.25rem;
font-weight: 500;
}
.ts-btn-action:focus {
z-index: 10;
--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width)
var(--tw-ring-offset-color);
--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(2px + var(--tw-ring-offset-width))
var(--tw-ring-color);
box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
}
.ts-btn-action:disabled {
cursor: not-allowed;
opacity: 0.5;
}
.ts-btn-action:hover:disabled {
background-color: transparent !important;
}
.ts-btn-run {
border-top-left-radius: 0.375rem;
border-bottom-left-radius: 0.375rem;
--tw-border-opacity: 1;
border-color: rgb(34 197 94 / var(--tw-border-opacity));
--tw-text-opacity: 1 !important;
color: rgb(34 197 94 / var(--tw-text-opacity)) !important;
}
.ts-btn-run:hover {
--tw-bg-opacity: 1;
background-color: rgb(22 163 74 / var(--tw-bg-opacity));
--tw-text-opacity: 1 !important;
color: rgb(255 255 255 / var(--tw-text-opacity)) !important;
}
.ts-btn-run:focus {
--tw-ring-opacity: 1;
--tw-ring-color: rgb(74 222 128 / var(--tw-ring-opacity));
}
.ts-btn-run:hover:disabled {
--tw-text-opacity: 1 !important;
color: rgb(34 197 94 / var(--tw-text-opacity)) !important;
}
:is(.dark .ts-btn-run) {
--tw-border-opacity: 1;
border-color: rgb(34 197 94 / var(--tw-border-opacity));
}
:is(.dark .ts-btn-run:hover) {
--tw-bg-opacity: 1;
background-color: rgb(22 163 74 / var(--tw-bg-opacity));
}
:is(.dark .ts-btn-run:focus) {
--tw-ring-opacity: 1;
--tw-ring-color: rgb(20 83 45 / var(--tw-ring-opacity));
}
.ts-btn-delete {
border-top-right-radius: 0.375rem;
border-bottom-right-radius: 0.375rem;
--tw-border-opacity: 1;
border-color: rgb(220 38 38 / var(--tw-border-opacity));
--tw-text-opacity: 1 !important;
color: rgb(239 68 68 / var(--tw-text-opacity)) !important;
}
.ts-btn-delete:hover {
--tw-bg-opacity: 1;
background-color: rgb(220 38 38 / var(--tw-bg-opacity));
--tw-text-opacity: 1 !important;
color: rgb(255 255 255 / var(--tw-text-opacity)) !important;
}
.ts-btn-delete:focus {
--tw-ring-opacity: 1;
--tw-ring-color: rgb(252 165 165 / var(--tw-ring-opacity));
}
:is(.dark .ts-btn-delete) {
--tw-border-opacity: 1;
border-color: rgb(239 68 68 / var(--tw-border-opacity));
}
:is(.dark .ts-btn-delete:hover) {
--tw-bg-opacity: 1;
background-color: rgb(220 38 38 / var(--tw-bg-opacity));
}
:is(.dark .ts-btn-delete:focus) {
--tw-ring-opacity: 1;
--tw-ring-color: rgb(127 29 29 / var(--tw-ring-opacity));
}
.mt-1 {
margin-top: 0.25rem;
}
.mt-1\.5 {
margin-top: 0.375rem;
}
.inline-flex {
display: inline-flex;
}
.grid {
display: grid;
}
.rounded-md {
border-radius: 0.375rem;
}
.shadow-sm {
--tw-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05);
--tw-shadow-colored: 0 1px 2px 0 var(--tw-shadow-color);
box-shadow: var(--tw-ring-offset-shadow, 0 0 #0000), var(--tw-ring-shadow, 0 0 #0000),
var(--tw-shadow);
}
.filter {
filter: var(--tw-blur) var(--tw-brightness) var(--tw-contrast) var(--tw-grayscale)
var(--tw-hue-rotate) var(--tw-invert) var(--tw-saturate) var(--tw-sepia) var(--tw-drop-shadow);
}
/****************************************************************/
#agent_scheduler_pending_tasks_wrapper {
gap: var(--layout-gap);
border: none;
box-shadow: none;
border-width: 0;
}
@media (max-width: 1024px) {
#agent_scheduler_pending_tasks_wrapper {
flex-wrap: wrap;
}
}
#agent_scheduler_pending_tasks_wrapper > div:last-child {
width: 100%;
max-width: 512px;
}
#agent_scheduler_current_task_images {
width: 100%;
padding-top: 100%;
position: relative;
}
#agent_scheduler_current_task_images > div {
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
height: 100%;
}
#agent_scheduler_pending_tasks_wrapper {
justify-content: flex-end;
gap: var(--layout-gap);
padding: 0 var(--layout-gap) var(--layout-gap) var(--layout-gap);
}
#agent_scheduler_pending_tasks_wrapper > button {
flex: 0 0 auto;
}
#agent_scheduler_actions {
display: flex;
gap: var(--layout-gap);
}
#agent_scheduler_actions > button {
border-radius: var(--radius-lg) !important;
}
@keyframes blink {
from,
to {
opacity: 0;
}
50% {
opacity: 1;
}
}
@-moz-keyframes blink {
from,
to {
opacity: 0;
}
50% {
opacity: 1;
}
}
@-webkit-keyframes blink {
from,
to {
opacity: 0;
}
50% {
opacity: 1;
}
}
@-ms-keyframes blink {
from,
to {
opacity: 0;
}
50% {
opacity: 1;
}
}
@-o-keyframes blink {
from,
to {
opacity: 0;
}
50% {
opacity: 1;
}
}
.ag-theme-alpine,
.ag-theme-alpine-dark {
--ag-row-height: 45px;
--ag-header-height: 45px;
/* --ag-grid-size: 6px; */
--ag-cell-horizontal-padding: calc(var(--ag-grid-size) * 2);
--body-text-color: 'inherit';
}
.task-running {
color: #52c41a !important;
-webkit-animation: 1s blink ease infinite;
-moz-animation: 1s blink ease infinite;
-ms-animation: 1s blink ease infinite;
-o-animation: 1s blink ease infinite;
animation: 1s blink ease infinite;
}
.generate-box {
gap: unset !important;
}
#txt2img_generate,
#img2img_generate {
min-height: unset !important;
}
.generate-box #txt2img_interrupt {
position: initial !important;
height: 42px;
}
.generate-box #txt2img_interrupt,
.generate-box #img2img_interrupt,
.generate-box #txt2img_skip,
.generate-box #img2img_skip {
position: initial !important;
height: 42px;
margin-top: 0 !important;
display: none !important;
max-width: 50% !important;
}
.black-orange .generate-box #txt2img_interrupt,
.black-orange .generate-box #img2img_interrupt,
.black-orange .generate-box #txt2img_skip,
.black-orange .generate-box #img2img_skip {
height: 36px !important;
}
#txt2img_enqueue_wrapper,
#img2img_enqueue_wrapper {
flex-wrap: nowrap;
margin-top: calc(var(--layout-gap) / 2);
min-width: 100%;
gap: calc(var(--layout-gap) / 2);
}
#txt2img_enqueue_wrapper > div:first-child,
#img2img_enqueue_wrapper > div:first-child {
flex: 1 1 auto;
}
#txt2img_enqueue_wrapper > .gradio-button.primary,
#img2img_enqueue_wrapper > .gradio-button.primary {
flex: 0 0 auto;
min-width: 0;
}
.black-orange #txt2img_enqueue_wrapper .gradio-button,
.black-orange #img2img_enqueue_wrapper .gradio-button,
.black-orange #txt2img_enqueue_wrapper .gradio-dropdown .wrap-inner,
.black-orange #img2img_enqueue_wrapper .gradio-dropdown .wrap-inner {
height: 36px;
padding: 6px 12px;
}
.black-orange #txt2img_tools,
.black-orange #img2img_tools {
margin-top: 0;
margin-left: 0;
scale: 1;
}
.black-orange #txt2img_tools .gradio-button,
.black-orange #img2img_tools .gradio-button {
min-width: 36px !important;
height: 36px;
}
.black-orange #txt2img_actions_column,
.black-orange #img2img_actions_column {
min-width: min(320px, 100%) !important;
}
.generate-box #txt2img_interrupt[style='display: block;'],
.generate-box #img2img_interrupt[style='display: block;'],
.generate-box #txt2img_skip[style='display: block;'],
.generate-box #img2img_skip[style='display: block;'] {
display: block !important;
}
.generate-box #txt2img_skip[style='display: block;'] + #txt2img_generate,
.generate-box #img2img_skip[style='display: block;'] + #img2img_generate {
display: none !important;
}
#agent_scheduler_current_task_progress .livePreview {
margin: 0;
}