feature update

* Add api to download task's generated images
* Add setting to render extension UI below the main UI
* Display task datetime in local timezone
* Persist the grid state (columns order, sorting) for next session
* Bugs fixing and code improvements
pull/71/head
Tung Nguyen 2023-06-20 22:01:47 +07:00
parent 45dfe5977e
commit ce2b9b4eb9
19 changed files with 3435 additions and 3036 deletions

View File

@ -103,7 +103,9 @@ By default, queued tasks use the currently loaded checkpoint. However, changing
All the functionality of this extension can be accessed through HTTP APIs. You can access the API documentation via `http://127.0.0.1:7860/docs`. Remember to include `--api` in your startup arguments.
![API docs](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/a1b1f2a6-b631-4a59-b904-d6eb3aa90b7d)
![API docs](https://github.com/ArtVentureX/sd-webui-agent-scheduler/assets/133728487/012ab2cc-b41f-4c68-8fa5-7ab4e49aa91d)
#### Queue Task
The two apis `/agent-scheduler/v1/queue/txt2img` and `/agent-scheduler/v1/queue/img2img` support all the parameters of the original webui apis. These apis respond with the task id, which can be used to perform updates later.
@ -113,6 +115,30 @@ The two apis `/agent-scheduler/v1/queue/txt2img` and `/agent-scheduler/v1/queue/
}
```
#### Download Results
Use the api `/agent-scheduler/v1/results/{id}` to get the generated images. The api supports two response formats:
- json with base64-encoded images (the default)
```json
{
"success": true,
"data": [
{
"image": "data:image/png;base64,iVBORw0KGgoAAAAN...",
"infotext": "1girl\nNegative prompt: EasyNegative, badhandv4..."
},
{
"image": "data:image/png;base64,iVBORw0KGgoAAAAN...",
"infotext": "1girl\nNegative prompt: EasyNegative, badhandv4..."
}
]
}
```
- zip file with querystring `zip=true`
## Road Map
A list of possible feature upgrades for this extension:

View File

@ -1,6 +1,12 @@
import io
import json
import threading
from uuid import uuid4
from zipfile import ZipFile
from pathlib import Path
from typing import Optional
from gradio.routes import App
from fastapi.responses import StreamingResponse
from modules import shared, progress
@ -15,7 +21,7 @@ from .models import (
)
from .task_runner import TaskRunner
from .helpers import log
from .task_helpers import serialize_api_task_args
from .task_helpers import encode_image_to_base64
def regsiter_apis(app: App, task_runner: TaskRunner):
@ -23,16 +29,15 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
@app.post("/agent-scheduler/v1/queue/txt2img", response_model=QueueTaskResponse)
def queue_txt2img(body: Txt2ImgApiTaskArgs):
params = body.dict()
task_id = str(uuid4())
checkpoint = params.pop("model_hash", None)
task_args = serialize_api_task_args(
params,
is_img2img=False,
checkpoint=checkpoint,
)
args = body.dict()
checkpoint = args.pop("model_hash", None)
task_runner.register_api_task(
task_id, api_task_id=False, is_img2img=False, args=task_args
task_id,
api_task_id=None,
is_img2img=False,
args=args,
checkpoint=checkpoint,
)
task_runner.execute_pending_tasks_threading()
@ -40,21 +45,30 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
@app.post("/agent-scheduler/v1/queue/img2img", response_model=QueueTaskResponse)
def queue_img2img(body: Img2ImgApiTaskArgs):
params = body.dict()
task_id = str(uuid4())
checkpoint = params.pop("model_hash", None)
task_args = serialize_api_task_args(
params,
is_img2img=True,
checkpoint=checkpoint,
)
args = body.dict()
checkpoint = args.pop("model_hash", None)
task_runner.register_api_task(
task_id, api_task_id=False, is_img2img=True, args=task_args
task_id,
api_task_id=None,
is_img2img=True,
args=args,
checkpoint=checkpoint,
)
task_runner.execute_pending_tasks_threading()
return QueueTaskResponse(task_id=task_id)
def format_task_args(task):
    """Build a compact params dict for a task, suitable for queue/history API payloads."""
    parsed = TaskRunner.instance.parse_task_args(task, deserialization=False)
    payload = parsed.named_args
    payload["checkpoint"] = parsed.checkpoint

    # Drop the bulky fields so the response payload stays small.
    for heavy_key in ("alwayson_scripts", "script_args", "init_images"):
        payload.pop(heavy_key, None)

    return payload
@app.get("/agent-scheduler/v1/queue", response_model=QueueStatusResponse)
def queue_status_api(limit: int = 20, offset: int = 0):
current_task_id = progress.current_task
@ -64,14 +78,9 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
)
parsed_tasks = []
for task in pending_tasks:
task_args = TaskRunner.instance.parse_task_args(
task.params, task.script_params, deserialization=False
)
named_args = task_args.named_args
named_args["checkpoint"] = task_args.checkpoint
params = format_task_args(task)
task_data = task.dict()
task_data["params"] = named_args
task_data["params"] = params
if task.id == current_task_id:
task_data["status"] = "running"
@ -104,14 +113,9 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
)
parsed_tasks = []
for task in tasks:
task_args = TaskRunner.instance.parse_task_args(
task.params, task.script_params, deserialization=False
)
named_args = task_args.named_args
named_args["checkpoint"] = task_args.checkpoint
params = format_task_args(task)
task_data = task.dict()
task_data["params"] = named_args
task_data["params"] = params
parsed_tasks.append(TaskModel(**task_data))
return HistoryResponse(
@ -220,6 +224,50 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
task_manager.update_task(id, name=name)
return {"success": True, "message": f"Task {id} is renamed"}
@app.get("/agent-scheduler/v1/results/{id}")
def get_task_results(id: str, zip: Optional[bool] = False):
    """Return the generated images of a completed task.

    With the querystring ``zip=true`` the saved image files are streamed back
    as a zip archive; otherwise they are returned as base64-encoded JSON
    entries paired with their infotexts.
    """
    task = task_manager.get_task(id)
    if task is None:
        return {"success": False, "message": f"Task not found"}
    if task.status != TaskStatus.DONE:
        return {"success": False, "message": f"Task is {task.status.value}"}
    if task.result is None:
        return {"success": False, "message": f"Task result is not available"}

    # task.result is stored as a JSON string with "images" (paths) and "infotexts".
    result: dict = json.loads(task.result)
    infotexts = result["infotexts"]
    if zip:
        zip_buffer = io.BytesIO()

        # Create a new zip file in the in-memory buffer
        with ZipFile(zip_buffer, "w") as zip_file:
            # Add each result image that still exists on disk to the archive
            for image in result["images"]:
                if Path(image).is_file():
                    zip_file.write(Path(image), Path(image).name)

        # Reset the buffer position to the beginning to avoid truncation issues
        zip_buffer.seek(0)

        # Return the in-memory buffer as a streaming response with the appropriate headers
        return StreamingResponse(
            zip_buffer,
            media_type="application/zip",
            headers={
                "Content-Disposition": f"attachment; filename=results-{id}.zip"
            },
        )
    else:
        data = [
            {"image": encode_image_to_base64(image), "infotext": infotexts[i]}
            for i, image in enumerate(result["images"])
        ]

        return {"success": True, "data": data}
@app.post("/agent-scheduler/v1/pause")
def pause_queue():
shared.opts.queue_paused = True

View File

@ -31,35 +31,8 @@ class Task(TaskModel):
script_params: bytes = None
params: str
def __init__(
self,
id: str = "",
api_task_id: str = None,
name: str = None,
type: str = "unknown",
params: str = "",
priority: int = None,
status: str = TaskStatus.PENDING.value,
result: str = None,
bookmarked: bool = False,
created_at: Optional[datetime] = None,
updated_at: Optional[datetime] = None,
):
priority = priority if priority else int(datetime.utcnow().timestamp() * 1000)
super().__init__(
id=id,
api_task_id=api_task_id,
name=name,
type=type,
params=params,
status=status,
priority=priority,
result=result,
bookmarked=bookmarked,
created_at=created_at,
updated_at=updated_at,
)
def __init__(self, priority: int = None, **kwargs):
    """Create a Task, defaulting priority to the current UTC timestamp in ms.

    NOTE: the default must be computed here, at call time. A default
    expression like ``priority=int(datetime.utcnow().timestamp() * 1000)``
    is evaluated only once, when the class body is executed, so every task
    created without an explicit priority would share the same frozen
    timestamp instead of getting a fresh one.
    """
    if priority is None:
        priority = int(datetime.utcnow().timestamp() * 1000)
    super().__init__(priority=priority, **kwargs)
class Config(TaskModel.__config__):
exclude = ["script_params"]
@ -104,7 +77,7 @@ class TaskTable(Base):
type = Column(String(20), nullable=False) # txt2img or img2txt
params = Column(Text, nullable=False) # task args
script_params = Column(LargeBinary, nullable=False) # script args
priority = Column(Integer, nullable=False, default=datetime.now)
priority = Column(Integer, nullable=False)
status = Column(
String(20), nullable=False, default="pending"
) # pending, running, done, failed
@ -144,6 +117,7 @@ class TaskManager(BaseTableManager):
type: str = None,
status: Union[str, list[str]] = None,
bookmarked: bool = None,
api_task_id: str = None,
limit: int = None,
offset: int = None,
order: str = "asc",
@ -160,6 +134,9 @@ class TaskManager(BaseTableManager):
else:
query = query.filter(TaskTable.status == status)
if api_task_id:
query = query.filter(TaskTable.api_task_id == api_task_id)
if bookmarked == True:
query = query.filter(TaskTable.bookmarked == bookmarked)
else:
@ -189,6 +166,7 @@ class TaskManager(BaseTableManager):
self,
type: str = None,
status: Union[str, list[str]] = None,
api_task_id: str = None,
) -> int:
session = Session(self.engine)
try:
@ -202,6 +180,9 @@ class TaskManager(BaseTableManager):
else:
query = query.filter(TaskTable.status == status)
if api_task_id:
query = query.filter(TaskTable.api_task_id == api_task_id)
return query.count()
except Exception as e:
print(f"Exception counting tasks from database: {e}")
@ -262,7 +243,7 @@ class TaskManager(BaseTableManager):
result = session.get(TaskTable, id)
if result:
if priority == 0:
result.priority = self.__get_min_priority() - 1
result.priority = self.__get_min_priority(status=TaskStatus.PENDING) - 1
elif priority == -1:
result.priority = int(datetime.utcnow().timestamp() * 1000)
else:
@ -316,10 +297,14 @@ class TaskManager(BaseTableManager):
finally:
session.close()
def __get_min_priority(self) -> int:
def __get_min_priority(self, status: str = None) -> int:
session = Session(self.engine)
try:
min_priority = session.query(func.min(TaskTable.priority)).scalar()
query = session.query(func.min(TaskTable.priority))
if status is not None:
query = query.filter(TaskTable.status == status)
min_priority = query.scalar()
return min_priority if min_priority else 0
except Exception as e:
print(f"Exception getting min priority from database: {e}")

View File

@ -6,7 +6,7 @@ from gradio.blocks import Block, BlockContext
if not logging.getLogger().hasHandlers():
# Logging is not set up
logging.basicConfig(level=logging.INFO, format='%(message)s')
logging.basicConfig(level=logging.INFO, format="%(message)s")
log = logging.getLogger("sd")
@ -78,3 +78,43 @@ def detect_control_net(root: gr.Blocks, submit: gr.Button):
UiControlNetUnit = type(output.value)
return UiControlNetUnit
def get_dict_attribute(dict_inst: dict, name_string: str, default=None):
    """Read a nested dictionary value using dot notation.

    Parameters:
    - dict_inst: the dictionary to read from
    - name_string: the attribute path in dot notation (ex: 'attribute.name')
    - default: value returned when the path cannot be resolved

    Returns the resolved value, or ``default`` when any key on the path is
    missing or an intermediate value is not a dictionary.
    """
    value = dict_inst
    for key in name_string.split("."):
        # Guard: descending into a non-dict leaf would otherwise raise
        # AttributeError on .get (e.g. path 'a.b.c' when 'a.b' is a scalar).
        if not isinstance(value, dict):
            return default
        value = value.get(key, None)
        if value is None:
            return default
    return value
def set_dict_attribute(dict_inst: dict, name_string: str, value):
    """
    Set an attribute on a dictionary using dot notation.

    Missing intermediate keys are created as nested dictionaries along the way.

    Parameters:
    - dict_inst: the dictionary instance to set the attribute on
    - name_string: the attribute name in dot notation (ex: 'attribute.name')
    - value: the value to set for the attribute

    Returns:
    None
    """
    *parent_keys, leaf_key = name_string.split(".")

    # Walk down the path, materializing intermediate dicts as needed.
    node = dict_inst
    for part in parent_keys:
        if part not in node:
            node[part] = {}
        node = node[part]

    node[leaf_key] = value

View File

@ -32,12 +32,14 @@ class TaskModel(BaseModel):
name: Optional[str] = Field(title="Task Name")
type: str = Field(title="Task Type", description="Either txt2img or img2img")
status: str = Field(
title="Task Status", description="Either pending, running, done or failed"
"pending",
title="Task Status",
description="Either pending, running, done or failed",
)
params: dict[str, Any] = Field(
title="Task Parameters", description="The parameters of the task in JSON format"
)
priority: int = Field(title="Task Priority")
priority: Optional[int] = Field(title="Task Priority")
result: Optional[str] = Field(
title="Task Result", description="The result of the task in JSON format"
)
@ -53,12 +55,6 @@ class TaskModel(BaseModel):
default=None,
)
class Config:
json_encoders = {
# custom output conversion for datetime
datetime: convert_datetime_to_iso_8601_with_z_suffix
}
class Txt2ImgApiTaskArgs(StableDiffusionTxt2ImgProcessingAPI):
checkpoint: Optional[str] = Field(
@ -67,9 +63,7 @@ class Txt2ImgApiTaskArgs(StableDiffusionTxt2ImgProcessingAPI):
description="Custom checkpoint hash. If not specified, the latest checkpoint will be used.",
)
sampler_index: Optional[str] = Field(
sd_samplers.samplers[0].name,
title="Sampler name",
alias="sampler_name"
sd_samplers.samplers[0].name, title="Sampler name", alias="sampler_name"
)
class Config(StableDiffusionTxt2ImgProcessingAPI.__config__):
@ -87,9 +81,7 @@ class Img2ImgApiTaskArgs(StableDiffusionImg2ImgProcessingAPI):
description="Custom checkpoint hash. If not specified, the latest checkpoint will be used.",
)
sampler_index: Optional[str] = Field(
sd_samplers.samplers[0].name,
title="Sampler name",
alias="sampler_name"
sd_samplers.samplers[0].name, title="Sampler name", alias="sampler_name"
)
class Config(StableDiffusionImg2ImgProcessingAPI.__config__):
@ -116,7 +108,13 @@ class QueueStatusResponse(BaseModel):
)
paused: bool = Field(title="Paused", description="Whether the queue is paused")
class Config:
json_encoders = {datetime: lambda dt: int(dt.timestamp() * 1e3)}
class HistoryResponse(BaseModel):
tasks: List[TaskModel] = Field(title="Tasks")
total: int = Field(title="Task count")
class Config:
json_encoders = {datetime: lambda dt: int(dt.timestamp() * 1e3)}

View File

@ -5,6 +5,7 @@ import base64
import inspect
import requests
import numpy as np
from typing import Union
from enum import Enum
from PIL import Image, ImageOps, ImageChops, ImageEnhance, ImageFilter
@ -18,7 +19,7 @@ from modules.api.models import (
StableDiffusionImg2ImgProcessingAPI,
)
from .helpers import log
from .helpers import log, get_dict_attribute
img2img_image_args_by_mode: dict[int, list[list[str]]] = {
0: [["init_img"]],
@ -29,6 +30,22 @@ img2img_image_args_by_mode: dict[int, list[list[str]]] = {
}
def get_script_by_name(
    script_name: str, is_img2img: bool = False, is_always_on: bool = False
) -> scripts.Script:
    """Find a webui script by its title (case-insensitive), or return None."""
    runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
    candidates = (
        runner.alwayson_scripts if is_always_on else runner.selectable_scripts
    )

    wanted = script_name.lower()
    for candidate in candidates:
        if candidate.title().lower() == wanted:
            return candidate

    return None
def load_image_from_url(url: str):
try:
response = requests.get(url)
@ -39,43 +56,20 @@ def load_image_from_url(url: str):
return None
def load_image(image: str):
if not isinstance(image, str):
def encode_image_to_base64(image):
if isinstance(image, np.ndarray):
image = Image.fromarray(image.astype("uint8"))
elif isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image = load_image_from_url(image)
if not isinstance(image, Image.Image):
return image
pil_image = None
if os.path.exists(image):
pil_image = Image.open(image)
elif image.startswith(("http://", "https://")):
pil_image = load_image_from_url(image)
return pil_image
def load_image_to_nparray(image: str):
pil_image = load_image(image)
return (
np.array(pil_image).astype("uint8")
if isinstance(pil_image, Image.Image)
else None
)
def encode_pil_to_base64(image: Image.Image):
with io.BytesIO() as output_bytes:
image.save(output_bytes, format="PNG")
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data).decode("utf-8")
def load_image_to_base64(image: str):
pil_image = load_image(image)
if not isinstance(pil_image, Image.Image):
return image
return encode_pil_to_base64(pil_image)
return "data:image/png;base64," + base64.b64encode(bytes_data).decode("utf-8")
def __serialize_image(image):
@ -164,7 +158,6 @@ def serialize_controlnet_args(cnet_unit):
def deserialize_controlnet_args(args: dict):
# args.pop("is_cnet", None)
for k, v in args.items():
if k == "image" and v is not None:
args[k] = {
@ -177,6 +170,24 @@ def deserialize_controlnet_args(args: dict):
return args
def map_controlnet_args_to_api_task_args(args: dict):
if type(args).__name__ == "UiControlNetUnit":
args = args.__dict__
for k, v in args.items():
if k == "image" and v is not None:
args[k] = {
"image": encode_image_to_base64(v["image"]),
"mask": encode_image_to_base64(v["mask"])
if v.get("mask", None) is not None
else None,
}
if isinstance(v, Enum):
args[k] = v.value
return args
def map_ui_task_args_list_to_named_args(
args: list, is_img2img: bool, checkpoint: str = None
):
@ -195,7 +206,10 @@ def map_ui_task_args_list_to_named_args(
sampler_index = named_args.get("sampler_index", None)
if sampler_index is not None:
sampler_name = sd_samplers.samplers[named_args["sampler_index"]].name
available_samplers = (
sd_samplers.samplers_for_img2img if is_img2img else sd_samplers.samplers
)
sampler_name = available_samplers[named_args["sampler_index"]].name
named_args["sampler_name"] = sampler_name
log.debug(f"serialize sampler index: {str(sampler_index)} as {sampler_name}")
@ -230,6 +244,49 @@ def map_named_args_to_ui_task_args_list(
return args
def map_script_args_list_to_named(script: "scripts.Script", args: list):
    """Convert a positional script arg list into a dict keyed by parameter name.

    ControlNet units are handled specially (each unit is serialized via
    map_controlnet_args_to_api_task_args). For other scripts, positional args
    are zipped against the parameters of ``process()`` (alwayson scripts) or
    ``run()``, skipping the leading ``(self, p)`` parameters; any ``*varargs``
    tail is stored under the varargs name.
    """
    # Fixed: removed a leftover debug print() that ran for every script.
    script_name = script.title().lower()

    if script_name == "controlnet":
        for i, cnet_args in enumerate(args):
            args[i] = map_controlnet_args_to_api_task_args(cnet_args)
        return args

    fn = script.process if script.alwayson else script.run
    inspection = inspect.getfullargspec(fn)
    arg_names = inspection.args[2:]

    named_script_args = dict(zip(arg_names, args[: len(arg_names)]))
    if inspection.varargs is not None:
        named_script_args[inspection.varargs] = args[len(arg_names) :]

    return named_script_args
def map_named_script_args_to_list(
    script: scripts.Script, named_args: Union[dict, list]
):
    """Convert named script args back into the positional list the script expects.

    Dict input is ordered by the parameters of ``process()``/``run()`` (skipping
    ``self`` and ``p``); list input is passed through, with ControlNet units
    re-serialized for the API.
    """
    script_name = script.title().lower()

    if isinstance(named_args, dict):
        target = script.process if script.alwayson else script.run
        spec = inspect.getfullargspec(target)
        positional = spec.args[2:]

        ordered = [named_args.get(param, None) for param in positional]
        if spec.varargs is not None:
            ordered.extend(named_args.get(spec.varargs, []))
        return ordered

    if isinstance(named_args, list):
        if script_name == "controlnet":
            for index, unit_args in enumerate(named_args):
                named_args[index] = map_controlnet_args_to_api_task_args(unit_args)
        return named_args
def map_ui_task_args_to_api_task_args(
named_args: dict, script_args: list, is_img2img: bool
):
@ -255,45 +312,50 @@ def map_ui_task_args_to_api_task_args(
# the logic below is copied from modules/img2img.py
if mode == 0:
image = api_task_args.pop("init_img").convert("RGB")
image = api_task_args.pop("init_img")
image = image.convert("RGB") if image else None
mask = None
elif mode == 1:
image = api_task_args.pop("sketch").convert("RGB")
image = api_task_args.pop("sketch")
image = image.convert("RGB") if image else None
mask = None
elif mode == 2:
init_img_with_mask: dict = api_task_args.pop("init_img_with_mask")
image = init_img_with_mask.get("image").convert("RGB")
mask = init_img_with_mask.get("mask")
alpha_mask = (
ImageOps.invert(image.split()[-1])
.convert("L")
.point(lambda x: 255 if x > 0 else 0, mode="1")
)
mask = ImageChops.lighter(alpha_mask, mask.convert("L")).convert("L")
init_img_with_mask: dict = api_task_args.pop("init_img_with_mask") or {}
image = init_img_with_mask.get("image", None)
image = image.convert("RGB") if image else None
mask = init_img_with_mask.get("mask", None)
if mask:
alpha_mask = (
ImageOps.invert(image.split()[-1])
.convert("L")
.point(lambda x: 255 if x > 0 else 0, mode="1")
)
mask = ImageChops.lighter(alpha_mask, mask.convert("L")).convert("L")
elif mode == 3:
image = api_task_args.pop("inpaint_color_sketch")
orig = api_task_args.pop("inpaint_color_sketch_orig") or image
mask_alpha = api_task_args.pop("mask_alpha", 0)
mask_blur = api_task_args.get("mask_blur", 4)
pred = np.any(np.array(image) != np.array(orig), axis=-1)
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
blur = ImageFilter.GaussianBlur(mask_blur)
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
image = image.convert("RGB")
if image is not None:
mask_alpha = api_task_args.pop("mask_alpha", 0)
mask_blur = api_task_args.get("mask_blur", 4)
pred = np.any(np.array(image) != np.array(orig), axis=-1)
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
blur = ImageFilter.GaussianBlur(mask_blur)
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
image = image.convert("RGB")
elif mode == 4:
image = api_task_args.pop("init_img_inpaint")
mask = api_task_args.pop("init_mask_inpaint")
else:
raise Exception(f"Batch mode is not supported yet")
image = ImageOps.exif_transpose(image)
api_task_args["init_images"] = [encode_pil_to_base64(image)]
api_task_args["mask"] = encode_pil_to_base64(mask) if mask is not None else None
image = ImageOps.exif_transpose(image) if image else None
api_task_args["init_images"] = [encode_image_to_base64(image)] if image else []
api_task_args["mask"] = encode_image_to_base64(mask) if mask else None
selected_scale_tab = api_task_args.pop("selected_scale_tab", 0)
scale_by = api_task_args.pop("scale_by", 1)
if selected_scale_tab == 1:
scale_by = api_task_args.get("scale_by", 1)
if selected_scale_tab == 1 and image:
api_task_args["width"] = int(image.width * scale_by)
api_task_args["height"] = int(image.height * scale_by)
else:
@ -311,23 +373,27 @@ def map_ui_task_args_to_api_task_args(
api_task_args["script_name"] = None
api_task_args["script_args"] = []
else:
script = script_runner.selectable_scripts[script_id - 1]
api_task_args["script_name"] = script.title.lower()
api_task_args["script_args"] = script_args[script.args_from : script.args_to]
script: scripts.Script = script_runner.selectable_scripts[script_id - 1]
api_task_args["script_name"] = script.title().lower()
current_script_args = script_args[script.args_from : script.args_to]
api_task_args["script_args"] = map_script_args_list_to_named(
script, current_script_args
)
# alwayson scripts
alwayson_scripts = api_task_args.get("alwayson_scripts", None)
if not alwayson_scripts or not isinstance(alwayson_scripts, dict):
alwayson_scripts = {}
api_task_args["alwayson_scripts"] = alwayson_scripts
if not alwayson_scripts:
api_task_args["alwayson_scripts"] = {}
alwayson_scripts = api_task_args["alwayson_scripts"]
for script in script_runner.alwayson_scripts:
alwayson_script_args = script_args[script.args_from : script.args_to]
if script.title.lower() == "controlnet":
for i, cnet_args in enumerate(alwayson_script_args):
alwayson_script_args[i] = serialize_controlnet_args(cnet_args)
alwayson_scripts[script.title.lower()] = {"args": alwayson_script_args}
script_name = script.title().lower()
if script_name != "agent scheduler":
named_script_args = map_script_args_list_to_named(
script, alwayson_script_args
)
alwayson_scripts[script_name] = {"args": named_script_args}
return api_task_args
@ -336,8 +402,33 @@ def serialize_api_task_args(
params: dict,
is_img2img: bool,
checkpoint: str = None,
controlnet_args: list[dict] = None,
):
# handle named script args
script_name = params.get("script_name", None)
if script_name is not None:
script = get_script_by_name(script_name, is_img2img)
if script is None:
raise Exception(f"Not found script {script_name}")
script_args = params.get("script_args", {})
params["script_args"] = map_named_script_args_to_list(script, script_args)
# handle named alwayson script args
alwayson_scripts = get_dict_attribute(params, "alwayson_scripts", {})
valid_alwayson_scripts = {}
script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
for script in script_runner.alwayson_scripts:
script_name = script.title().lower()
if script_name == "agent scheduler":
continue
script_args = get_dict_attribute(alwayson_scripts, f"{script_name}.args", None)
if script_args:
arg_list = map_named_script_args_to_list(script, script_args)
valid_alwayson_scripts[script_name] = {"args": arg_list}
params["alwayson_scripts"] = valid_alwayson_scripts
args = (
StableDiffusionImg2ImgProcessingAPI(**params)
if is_img2img
@ -350,48 +441,19 @@ def serialize_api_task_args(
if checkpoint is not None:
checkpoint_info: CheckpointInfo = get_closet_checkpoint_match(checkpoint)
if not checkpoint_info:
log.warn(
f"[AgentScheduler] No checkpoint found for model hash {checkpoint}"
)
return
raise Exception(f"No checkpoint found for model hash {checkpoint}")
args.override_settings["sd_model_checkpoint"] = checkpoint_info.title
# load images from url or file if needed
if is_img2img:
init_images = args.init_images
if len(init_images) == 0:
raise Exception("At least one init image is required")
for i, image in enumerate(init_images):
init_images[i] = load_image_to_base64(image)
init_images[i] = encode_image_to_base64(image)
args.mask = load_image_to_base64(args.mask)
# handle custom controlnet args
if controlnet_args is not None:
if args.alwayson_scripts is None:
args.alwayson_scripts = {}
controlnets = []
for cnet in controlnet_args:
enabled = cnet.get("enabled", True)
cnet_image = cnet.get("image", None)
if not enabled:
continue
if not isinstance(cnet_image, dict):
log.error(f"[AgentScheduler] Controlnet image is required")
continue
image = cnet_image.get("image", None)
mask = cnet_image.get("mask", None)
if image is None:
log.error(f"[AgentScheduler] Controlnet image is required")
continue
# load controlnet images from url or file if needed
cnet_image["image"] = load_image_to_base64(image)
cnet_image["mask"] = load_image_to_base64(mask)
controlnets.append(cnet)
if len(controlnets) > 0:
args.alwayson_scripts["controlnet"] = {"args": controlnets}
args.mask = encode_image_to_base64(args.mask)
args.batch_size = len(init_images)
return args.dict()

View File

@ -1,3 +1,4 @@
import os
import json
import time
import traceback
@ -7,6 +8,7 @@ from pydantic import BaseModel
from datetime import datetime, timedelta
from typing import Any, Callable, Union, Optional
from fastapi import FastAPI
from PIL import Image
from modules import progress, shared, script_callbacks
from modules.call_queue import queue_lock, wrap_gradio_call
@ -19,12 +21,19 @@ from modules.api.models import (
)
from .db import TaskStatus, Task, task_manager
from .helpers import log, detect_control_net, get_component_by_elem_id
from .helpers import (
log,
detect_control_net,
get_component_by_elem_id,
get_dict_attribute,
)
from .task_helpers import (
encode_image_to_base64,
serialize_img2img_image_args,
deserialize_img2img_image_args,
serialize_controlnet_args,
deserialize_controlnet_args,
serialize_api_task_args,
map_ui_task_args_list_to_named_args,
map_named_args_to_ui_task_args_list,
)
@ -119,11 +128,18 @@ class TaskRunner:
)
def __serialize_api_task_args(
self, is_img2img: bool, script_args: list = [], **named_args
self,
is_img2img: bool,
script_args: list = [],
checkpoint: str = None,
**api_args,
):
# serialization steps are done in task_helpers.register_api_task
override_settings = named_args.get("override_settings", {})
checkpoint = override_settings.get("sd_model_checkpoint", None)
named_args = serialize_api_task_args(
api_args, is_img2img, checkpoint=checkpoint
)
checkpoint = get_dict_attribute(
named_args, "override_settings.sd_model_checkpoint", None
)
return json.dumps(
{
@ -149,13 +165,19 @@ class TaskRunner:
script_args[i] = deserialize_controlnet_args(arg)
def __deserialize_api_task_args(self, is_img2img: bool, named_args: dict):
# API task use base64 images as input, no need to deserialize
pass
# load images from disk
if is_img2img:
init_images = named_args.get("init_images")
for i, img in enumerate(init_images):
if isinstance(img, str) and os.path.isfile(img):
print("loading image", img)
image = Image.open(img)
init_images[i] = encode_image_to_base64(image)
def parse_task_args(
self, params: str, script_params: bytes = None, deserialization: bool = True
):
parsed: dict[str, Any] = json.loads(params)
named_args.update({"save_images": True, "send_images": False})
def parse_task_args(self, task: Task, deserialization: bool = True):
parsed: dict[str, Any] = json.loads(task.params)
is_ui = parsed.get("is_ui", True)
is_img2img = parsed.get("is_img2img", None)
@ -198,13 +220,18 @@ class TaskRunner:
self.__total_pending_tasks += 1
def register_api_task(
self, task_id: str, api_task_id: str, is_img2img: bool, args: dict
self,
task_id: str,
api_task_id: str,
is_img2img: bool,
args: dict,
checkpoint: str = None,
):
progress.add_task_to_queue(task_id)
args = args.copy()
args.update({"save_images": True, "send_images": False})
params = self.__serialize_api_task_args(is_img2img, **args)
params = self.__serialize_api_task_args(
is_img2img, checkpoint=checkpoint, **args
)
task_type = "img2img" if is_img2img else "txt2img"
task_manager.add_task(
@ -216,7 +243,7 @@ class TaskRunner:
)
self.__total_pending_tasks += 1
def execute_task(self, task: Task, get_next_task: Callable):
def execute_task(self, task: Task, get_next_task: Callable[[], Task]):
while True:
if self.dispose:
break
@ -226,10 +253,7 @@ class TaskRunner:
is_img2img = task.type == "img2img"
log.info(f"[AgentScheduler] Executing task {task_id}")
task_args = self.parse_task_args(
task.params,
task.script_params,
)
task_args = self.parse_task_args(task)
task_meta = {
"is_img2img": is_img2img,
"is_ui": task_args.is_ui,
@ -439,7 +463,9 @@ class TaskRunner:
self.__run_callbacks("task_cleared")
def __on_image_saved(self, data: script_callbacks.ImageSaveParams):
self.__saved_images_path.append((data.filename, data.pnginfo.get("parameters", "")))
self.__saved_images_path.append(
(data.filename, data.pnginfo.get("parameters", ""))
)
def on_task_registered(self, callback: Callable):
"""Callback when a task is registered

File diff suppressed because it is too large Load Diff

View File

@ -31,6 +31,9 @@ initialized = False
checkpoint_current = "Current Checkpoint"
checkpoint_runtime = "Runtime Checkpoint"
ui_placement_as_tab = "As a tab"
ui_placement_append_to_main = "Append to main UI"
placement_under_generate = "Under Generate button"
placement_between_prompt_and_generate = "Between Prompt and Generate button"
@ -48,7 +51,7 @@ class Script(scripts.Script):
return "Agent Scheduler"
def show(self, is_img2img):
return True
return scripts.AlwaysVisible
def on_checkpoint_changed(self, checkpoint):
self.checkpoint_override = checkpoint
@ -409,15 +412,39 @@ def on_ui_settings():
section=section,
),
)
shared.opts.add_option(
"queue_ui_placement",
shared.OptionInfo(
ui_placement_as_tab,
"Task queue UI placement",
gr.Radio,
lambda: {
"choices": [
ui_placement_as_tab,
ui_placement_append_to_main,
]
},
section=section,
),
)
def on_app_started(block, app):
def on_app_started(block: gr.Blocks, app):
global task_runner
task_runner = get_instance(block)
task_runner.execute_pending_tasks_threading()
regsiter_apis(app, task_runner)
if (
getattr(shared.opts, "queue_ui_placement", "") == ui_placement_append_to_main
and block
):
with block:
with block.children[1]:
on_ui_tab()
if getattr(shared.opts, "queue_ui_placement", "") != ui_placement_append_to_main:
script_callbacks.on_ui_tabs(on_ui_tab)
script_callbacks.on_ui_tabs(on_ui_tab)
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_app_started(on_app_started)

File diff suppressed because one or more lines are too long

2
ui/.eslintignore Normal file
View File

@ -0,0 +1,2 @@
dist
.eslintrc.cjs

View File

@ -1,4 +1,5 @@
module.exports = {
root: true,
env: { browser: true, es2020: true },
extends: [
'eslint:recommended',
@ -7,8 +8,34 @@ module.exports = {
],
parser: '@typescript-eslint/parser',
parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
plugins: ['react-refresh'],
plugins: ['react-refresh', 'simple-import-sort'],
rules: {
'react-refresh/only-export-components': 'warn',
'simple-import-sort/imports': [
'error',
{
groups: [
// Side effect imports.
['^\\u0000'],
// Node.js builtins.
[`^(${require('module').builtinModules.join('|')})(/|$)`],
// Packages. `react` related packages come first.
['^react', '^\\w', '^@\\w'],
// Type
[`^(@@types)(/.*|$)`],
// Internal packages.
[
`^(~)(/.*|$)`,
],
// Parent imports. Put `..` last.
['^\\.\\.(?!/?$)', '^\\.\\./?$'],
// Other relative imports. Put same-folder imports and `.` last.
['^\\./(?=.*/)(?!/?$)', '^\\.(?!/?$)', '^\\./?$'],
// Style imports.
['^.+\\.s?css$'],
],
},
],
'@typescript-eslint/no-explicit-any': 'off'
},
}

View File

@ -35,6 +35,10 @@
/* ========================================================================= */
#tabs > #agent_scheduler_tabs {
margin-top: var(--layout-gap);
}
#agent_scheduler_pending_tasks_wrapper,
#agent_scheduler_history_wrapper {
border: none;

View File

@ -1,34 +1,36 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import { Grid, GridOptions } from 'ag-grid-community';
import { Notyf } from 'notyf';
import bookmark from '../assets/icons/bookmark.svg?raw';
import bookmarked from '../assets/icons/bookmark-filled.svg?raw';
import cancelIcon from '../assets/icons/cancel.svg?raw';
import deleteIcon from '../assets/icons/delete.svg?raw';
import playIcon from '../assets/icons/play.svg?raw';
import rotateIcon from '../assets/icons/rotate.svg?raw';
import searchIcon from '../assets/icons/search.svg?raw';
import { debounce } from '../utils/debounce';
import { extractArgs } from '../utils/extract-args';
import { formatDate } from '../utils/format-date';
import { createHistoryTasksStore } from './stores/history.store';
import { createPendingTasksStore } from './stores/pending.store';
import { createSharedStore } from './stores/shared.store';
import { ProgressResponse, ResponseStatus, Task, TaskStatus } from './types';
import 'ag-grid-community/styles/ag-grid.css';
import 'ag-grid-community/styles/ag-theme-alpine.css';
import 'notyf/notyf.min.css';
import './index.scss';
import { createPendingTasksStore } from './stores/pending.store';
import { ProgressResponse, ResponseStatus, Task, TaskStatus } from './types';
import { debounce } from '../utils/debounce';
import { extractArgs } from '../utils/extract-args';
import { createHistoryTasksStore } from './stores/history.store';
import { createSharedStore } from './stores/shared.store';
import deleteIcon from '../assets/icons/delete.svg?raw';
import cancelIcon from '../assets/icons/cancel.svg?raw';
import searchIcon from '../assets/icons/search.svg?raw';
import playIcon from '../assets/icons/play.svg?raw';
import rotateIcon from '../assets/icons/rotate.svg?raw';
import bookmark from '../assets/icons/bookmark.svg?raw';
import bookmarked from '../assets/icons/bookmark-filled.svg?raw';
const notyf = new Notyf();
declare global {
var country: string;
function gradioApp(): HTMLElement;
function randomId(): string;
function get_tab_index(name: string): number;
function create_submit_args(args: IArguments): any[];
function create_submit_args(args: any[]): any[];
function requestProgress(
id: string,
progressContainer: HTMLElement,
@ -37,12 +39,16 @@ declare global {
onProgress?: (res: ProgressResponse) => void,
): void;
function onUiLoaded(callback: () => void): void;
function submit_enqueue(): any[];
function submit_enqueue_img2img(): any[];
function notify(response: ResponseStatus): void;
function submit(...args: any[]): any[];
function submit_img2img(...args: any[]): any[];
function submit_enqueue(...args: any[]): any[];
function submit_enqueue_img2img(...args: any[]): any[];
function agent_scheduler_status_filter_changed(value: string): void;
}
const sharedStore = createSharedStore({
uiAsTab: true,
selectedTab: 'pending',
});
@ -98,7 +104,8 @@ const sharedGridOptions: GridOptions<Task> = {
{
field: 'params.prompt',
headerName: 'Prompt',
minWidth: 400,
minWidth: 200,
maxWidth: 400,
autoHeight: true,
wrapText: true,
cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' },
@ -106,7 +113,8 @@ const sharedGridOptions: GridOptions<Task> = {
{
field: 'params.negative_prompt',
headerName: 'Negative Prompt',
minWidth: 400,
minWidth: 200,
maxWidth: 400,
autoHeight: true,
wrapText: true,
cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' },
@ -143,7 +151,8 @@ const sharedGridOptions: GridOptions<Task> = {
headerName: 'Size',
minWidth: 110,
maxWidth: 110,
valueGetter: ({ data }) => (data ? `${data.params.width}x${data.params.height}` : ''),
valueGetter: ({ data }) =>
data?.params?.width ? `${data.params.width}x${data.params.height}` : '',
},
{
field: 'params.batch',
@ -151,11 +160,22 @@ const sharedGridOptions: GridOptions<Task> = {
minWidth: 100,
maxWidth: 100,
valueGetter: ({ data }) =>
data ? `${data.params.n_iter}x${data.params.batch_size}` : '1x1',
data?.params?.n_iter ? `${data.params.n_iter}x${data.params.batch_size}` : '1x1',
},
],
},
{ field: 'created_at', headerName: 'Date', minWidth: 200 },
{
field: 'created_at',
headerName: 'Queued At',
minWidth: 170,
valueFormatter: ({ value }) => value && formatDate(new Date(value)),
},
{
field: 'updated_at',
headerName: 'Updated At',
minWidth: 170,
valueFormatter: ({ value }) => value && formatDate(new Date(value)),
},
],
getRowId: ({ data }) => data.id,
@ -209,6 +229,8 @@ function notify(response: ResponseStatus) {
}
}
window.notify = notify;
function showTaskProgress(task_id: string, callback: () => void) {
const args = extractArgs(requestProgress);
@ -246,6 +268,16 @@ function showTaskProgress(task_id: string, callback: () => void) {
}
function initTabChangeHandler() {
sharedStore.subscribe((curr, prev) => {
if (!curr.uiAsTab || curr.selectedTab !== prev.selectedTab) {
if (curr.selectedTab === 'pending') {
pendingStore.refresh();
} else {
historyStore.refresh();
}
}
});
// watch for tab activation
const observer = new MutationObserver(function (mutationsList) {
mutationsList.forEach((styleChange) => {
@ -260,15 +292,19 @@ function initTabChangeHandler() {
historyStore.refresh();
}
} else if (tab.id === 'agent_scheduler_pending_tasks_tab') {
sharedStore.selectSelectedTab('pending');
pendingStore.refresh();
sharedStore.setSelectedTab('pending');
} else if (tab.id === 'agent_scheduler_history_tab') {
sharedStore.selectSelectedTab('history');
historyStore.refresh();
sharedStore.setSelectedTab('history');
}
});
});
observer.observe(document.getElementById('tab_agent_scheduler')!, { attributeFilter: ['style'] });
if (document.getElementById('tab_agent_scheduler')) {
observer.observe(document.getElementById('tab_agent_scheduler')!, {
attributeFilter: ['style'],
});
} else {
sharedStore.setState({ uiAsTab: false });
}
observer.observe(document.getElementById('agent_scheduler_pending_tasks_tab')!, {
attributeFilter: ['style'],
});
@ -280,33 +316,38 @@ function initTabChangeHandler() {
function initPendingTab() {
const store = pendingStore;
window.submit_enqueue = function submit_enqueue() {
var id = randomId();
var res = create_submit_args(arguments);
res[0] = id;
window.submit_enqueue = function submit_enqueue(...args) {
const res = window.submit(...args);
const btnEnqueue = document.querySelector('#txt2img_enqueue');
if (btnEnqueue) {
btnEnqueue.innerHTML = 'Queued';
setTimeout(() => {
btnEnqueue.innerHTML = 'Enqueue';
if (!sharedStore.getState().uiAsTab) {
if (sharedStore.getState().selectedTab === 'pending') {
pendingStore.refresh();
}
}
}, 1000);
}
return res;
};
window.submit_enqueue_img2img = function submit_enqueue_img2img() {
var id = randomId();
var res = create_submit_args(arguments);
res[0] = id;
res[1] = get_tab_index('mode_img2img');
window.submit_enqueue_img2img = function submit_enqueue_img2img(...args) {
const res = window.submit_img2img(...args);
const btnEnqueue = document.querySelector('#img2img_enqueue');
if (btnEnqueue) {
btnEnqueue.innerHTML = 'Queued';
setTimeout(() => {
btnEnqueue.innerHTML = 'Enqueue';
if (!sharedStore.getState().uiAsTab) {
if (sharedStore.getState().selectedTab === 'pending') {
pendingStore.refresh();
}
}
}, 1000);
}
@ -406,7 +447,22 @@ function initPendingTab() {
},
},
],
onGridReady: ({ api }) => {
onColumnMoved({ columnApi }) {
const colState = columnApi.getColumnState();
const colStateStr = JSON.stringify(colState);
localStorage.setItem('agent_scheduler:queue_col_state', colStateStr);
},
onSortChanged({ columnApi }) {
const colState = columnApi.getColumnState();
const colStateStr = JSON.stringify(colState);
localStorage.setItem('agent_scheduler:queue_col_state', colStateStr);
},
onColumnResized({ columnApi }) {
const colState = columnApi.getColumnState();
const colStateStr = JSON.stringify(colState);
localStorage.setItem('agent_scheduler:queue_col_state', colStateStr);
},
onGridReady: ({ api, columnApi }) => {
// init quick search input
const searchContainer = initSearchInput('#agent_scheduler_action_search');
const searchInput: HTMLInputElement = searchContainer.querySelector('input.ts-search-input')!;
@ -427,8 +483,15 @@ function initPendingTab() {
}
}
api.sizeColumnsToFit();
columnApi.autoSizeAllColumns();
});
// restore col state
const colStateStr = localStorage.getItem('agent_scheduler:queue_col_state');
if (colStateStr) {
const colState = JSON.parse(colStateStr);
columnApi.applyColumnState({ state: colState, applyOrder: true });
}
},
onRowDragEnd: ({ api, node, overNode }) => {
const id = node.data?.id;
@ -581,7 +644,22 @@ function initHistoryTab() {
],
rowSelection: 'single',
suppressRowDeselection: true,
onGridReady: ({ api }) => {
onColumnMoved({ columnApi }) {
const colState = columnApi.getColumnState();
const colStateStr = JSON.stringify(colState);
localStorage.setItem('agent_scheduler:history_col_state', colStateStr);
},
onSortChanged({ columnApi }) {
const colState = columnApi.getColumnState();
const colStateStr = JSON.stringify(colState);
localStorage.setItem('agent_scheduler:history_col_state', colStateStr);
},
onColumnResized({ columnApi }) {
const colState = columnApi.getColumnState();
const colStateStr = JSON.stringify(colState);
localStorage.setItem('agent_scheduler:history_col_state', colStateStr);
},
onGridReady: ({ api, columnApi }) => {
// init quick search input
const searchContainer = initSearchInput('#agent_scheduler_action_search_history');
const searchInput: HTMLInputElement = searchContainer.querySelector('input.ts-search-input')!;
@ -594,8 +672,15 @@ function initHistoryTab() {
store.subscribe((state) => {
api.setRowData(state.tasks);
api.sizeColumnsToFit();
columnApi.autoSizeAllColumns();
});
// restore col state
const colStateStr = localStorage.getItem('agent_scheduler:history_col_state');
if (colStateStr) {
const colState = JSON.parse(colStateStr);
columnApi.applyColumnState({ state: colState, applyOrder: true });
}
},
onSelectionChanged: (e) => {
const [selected] = e.api.getSelectedRows();
@ -619,7 +704,7 @@ let agentSchedulerInitialized = false;
onUiLoaded(function initAgentScheduler() {
// delay ui init until dom is available
if (!document.getElementById('tab_agent_scheduler')) {
if (!document.getElementById('agent_scheduler_tabs')) {
setTimeout(initAgentScheduler, 500);
return;
}

View File

@ -3,11 +3,12 @@ import { createStore } from 'zustand/vanilla';
type SelectedTab = 'history' | 'pending';
type SharedState = {
uiAsTab: boolean;
selectedTab: SelectedTab;
};
type SharedActions = {
selectSelectedTab: (tab: SelectedTab) => void;
setSelectedTab: (tab: SelectedTab) => void;
};
export const createSharedStore = (initialState: SharedState) => {
@ -15,7 +16,7 @@ export const createSharedStore = (initialState: SharedState) => {
const { getState, setState, subscribe } = store;
const actions: SharedActions = {
selectSelectedTab: (tab: SelectedTab) => {
setSelectedTab: (tab: SelectedTab) => {
setState({ selectedTab: tab });
},
};

View File

@ -35,20 +35,16 @@
}
}
.ag-cell.task-running .ag-cell-wrapper {
.ag-cell.task-running {
@apply !text-blue-500;
animation: 1s blink ease infinite;
}
/* .ag-cell.task-done .ag-cell-wrapper {
@apply !text-green-500;
} */
.ag-cell.task-failed .ag-cell-wrapper {
.ag-cell.task-failed {
@apply !text-red-500;
}
.ag-cell.task-interrupted .ag-cell-wrapper {
.ag-cell.task-interrupted {
@apply !text-gray-400;
}
}

View File

@ -1,4 +1,4 @@
export const debounce = (fn: Function, ms = 300) => {
export const debounce = (fn: (...args: any[]) => any, ms = 300) => {
let timeoutId: ReturnType<typeof setTimeout>;
return function (this: any, ...args: any[]) {
clearTimeout(timeoutId);

View File

@ -1,4 +1,4 @@
export const extractArgs = (func: Function) => {
export const extractArgs = (func: (...args: any[]) => any) => {
return (func + '')
.replace(/[/][/].*$/gm, '') // strip single-line comments
.replace(/\s+/g, '') // strip white space

View File

@ -0,0 +1,10 @@
/**
 * Format a Date as `YYYY-MM-DD HH:mm:ss` in the local timezone.
 */
export function formatDate(date: Date): string {
  // Zero-pad a date/time component to two digits.
  const pad = (value: number) => String(value).padStart(2, '0');

  const datePart = [date.getFullYear(), pad(date.getMonth() + 1), pad(date.getDate())].join('-');
  const timePart = [pad(date.getHours()), pad(date.getMinutes()), pad(date.getSeconds())].join(':');

  return `${datePart} ${timePart}`;
}