chore: feature updates

- add a flag to enable/disable queue auto processing
- add queue button placement setting
- add a flag to hide the custom checkpoint select
- rewrite frontend code in typescript
- extract serialization logic to task_helpers
- bugs fixing
pull/7/head
Tung Nguyen 2023-06-01 04:58:14 +07:00
parent 21d150a0bd
commit 4f7339468c
23 changed files with 38042 additions and 591 deletions

1
.gitignore vendored
View File

@ -37,4 +37,3 @@ notification.mp3
*.db
*.sqlite
*.sqlite3
tailwind.*

34722
javascript/extension.mjs Normal file

File diff suppressed because one or more lines are too long

View File

@ -5,7 +5,7 @@ from gradio.routes import App
import modules.shared as shared
from modules import progress, script_callbacks, sd_samplers
from scripts.db import TaskStatus, AppStateKey, task_manager, state_manager
from scripts.db import TaskStatus, task_manager
from scripts.models import QueueStatusResponse
from scripts.task_runner import TaskRunner, get_instance
from scripts.helpers import log
@ -29,8 +29,8 @@ def regsiter_apis(app: App):
task_args = TaskRunner.instance.parse_task_args(
task.params, task.script_params, deserialization=False
)
named_args = task_args["named_args"]
named_args["checkpoint"] = task_args["checkpoint"]
named_args = task_args.named_args
named_args["checkpoint"] = task_args.checkpoint
sampler_index = named_args.get("sampler_index", None)
if sampler_index is not None:
named_args["sampler_name"] = sd_samplers.samplers[
@ -103,21 +103,24 @@ def regsiter_apis(app: App):
@app.post("/agent-scheduler/v1/pause")
def pause_queue():
state_manager.set_value(AppStateKey.QueueState, "paused")
# state_manager.set_value(AppStateKey.QueueState, "paused")
shared.opts.queue_paused = True
return {"success": True, "message": f"Queue is paused"}
@app.post("/agent-scheduler/v1/resume")
def resume_queue():
state_manager.set_value(AppStateKey.QueueState, "running")
# state_manager.set_value(AppStateKey.QueueState, "running")
shared.opts.queue_paused = False
TaskRunner.instance.execute_pending_tasks_threading()
return {"success": True, "message": f"Queue is resumed"}
def on_app_started(block, app: App):
global task_runner
task_runner = get_instance(block)
if block is not None:
global task_runner
task_runner = get_instance(block)
regsiter_apis(app)
regsiter_apis(app)
script_callbacks.on_app_started(on_app_started)

View File

@ -27,11 +27,15 @@ def init():
inspector = inspect(engine)
with engine.connect() as conn:
# check if table task has column result and add it if not
task_columns = inspector.get_columns("task")
# add result column
if not any(col["name"] == "result" for col in task_columns):
conn.execute(text("ALTER TABLE task ADD COLUMN result TEXT"))
# add api_task_id column
if not any(col["name"] == "api_task_id" for col in task_columns):
conn.execute(text("ALTER TABLE task ADD COLUMN api_task_id VARCHAR(64)"))
params_column = next(col for col in task_columns if col["name"] == "params")
if version > "1" and not isinstance(params_column["type"], Text):
transaction = conn.begin()

View File

@ -22,6 +22,7 @@ class Task(TaskModel):
def __init__(
self,
id: str = "",
api_task_id: str = None,
type: str = "unknown",
params: str = "",
script_params: bytes = b"",
@ -35,6 +36,7 @@ class Task(TaskModel):
super().__init__(
id=id,
api_task_id=api_task_id,
type=type,
params=params,
status=status,
@ -44,6 +46,7 @@ class Task(TaskModel):
updated_at=created_at,
)
self.id: str = id
self.api_task_id: str = api_task_id
self.type: str = type
self.params: str = params
self.script_params: bytes = script_params
@ -60,6 +63,7 @@ class Task(TaskModel):
def from_table(table: "TaskTable"):
return Task(
id=table.id,
api_task_id=table.api_task_id,
type=table.type,
params=table.params,
script_params=table.script_params,
@ -72,6 +76,7 @@ class Task(TaskModel):
def to_table(self):
return TaskTable(
id=self.id,
api_task_id=self.api_task_id,
type=self.type,
params=self.params,
script_params=self.script_params,
@ -84,6 +89,7 @@ class TaskTable(Base):
__tablename__ = "task"
id = Column(String(64), primary_key=True)
api_task_id = Column(String(64), nullable=True)
type = Column(String(20), nullable=False) # txt2img or img2txt
params = Column(Text, nullable=False) # task args
script_params = Column(LargeBinary, nullable=False) # script args

View File

@ -20,6 +20,7 @@ class QueueStatusAPI(BaseModel):
class TaskModel(BaseModel):
id: str = Field(title="Task Id")
api_task_id: Optional[str] = Field(title="API Task Id", default=None)
type: str = Field(title="Task Type", description="Either txt2img or img2img")
status: str = Field(title="Task Status", description="Either pending, running, done or failed")
params: str = Field(title="Task Parameters", description="The parameters of the task in JSON format")
@ -34,6 +35,7 @@ class TaskModel(BaseModel):
datetime: convert_datetime_to_iso_8601_with_z_suffix
}
class QueueStatusResponse(BaseModel):
current_task_id: Optional[str] = Field(title="Current Task Id", description="The on progress task id")
pending_tasks: List[TaskModel] = Field(title="Pending Tasks", description="The pending tasks in the queue")

441
scripts/task_helpers.py Normal file
View File

@ -0,0 +1,441 @@
import os
import io
import zlib
import base64
import inspect
import requests
import numpy as np
from enum import Enum
from PIL import Image, ImageOps, ImageChops, ImageEnhance, ImageFilter
from typing import Optional, List
from pydantic import BaseModel, Field
from modules import sd_samplers, scripts
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.sd_models import CheckpointInfo, get_closet_checkpoint_match
from modules.txt2img import txt2img
from modules.img2img import img2img
from modules.api.models import (
StableDiffusionTxt2ImgProcessingAPI,
StableDiffusionImg2ImgProcessingAPI,
)
from scripts.helpers import log
img2img_image_args_by_mode: dict[int, list[list[str]]] = {
0: [["init_img"]],
1: [["sketch"]],
2: [["init_img_with_mask", "image"], ["init_img_with_mask", "mask"]],
3: [["inpaint_color_sketch"], ["inpaint_color_sketch_orig"]],
4: [["init_img_inpaint"], ["init_mask_inpaint"]],
}
class ControlNetImage(BaseModel):
image: str # base64 or url
mask: Optional[str] = None # base64 or url
class ControlNetUnit(BaseModel):
enabled: Optional[bool] = True
module: Optional[str] = "none"
model: Optional[str] = None
image: ControlNetImage = None
weight: Optional[float] = 1.0
resize_mode: Optional[str] = None
low_vram: Optional[bool] = False
processor_res: Optional[int] = 512
threshold_a: Optional[float] = 64
threshold_b: Optional[float] = 64
guidance_start: Optional[float] = 0.0
guidance_end: Optional[float] = 1.0
pixel_perfect: Optional[bool] = False
control_mode: Optional[str] = "Balanced"
class BaseApiTaskArgs(BaseModel):
task_id: str = Field(exclude=True)
model_hash: str = Field(exclude=True)
prompt: Optional[str] = ""
styles: Optional[List[str]] = []
negative_prompt: Optional[str] = ""
seed: Optional[int] = -1
subseed: Optional[int] = 1
subseed_strength: Optional[int] = 0
seed_resize_from_h: Optional[int] = -1
seed_resize_from_w: Optional[int] = -1
sampler_name: Optional[str] = "DPM++ 2M Karras"
n_iter: Optional[int] = 1
batch_size: Optional[int] = 1
steps: Optional[int] = 20
cfg_scale: Optional[int] = 7.0
restore_faces: Optional[bool] = False
tiling: Optional[bool] = False
width: Optional[int] = 512
height: Optional[int] = 512
script_name: Optional[str] = None
controlnet_args: Optional[List[ControlNetUnit]] = Field(exclude=True, default=[])
override_settings: Optional[dict] = Field(default={})
class Txt2ImgApiTaskArgs(BaseApiTaskArgs):
enable_hr: Optional[bool] = False
denoising_strength: Optional[int] = 0
hr_scale: Optional[int] = 1
hr_upscaler: Optional[str] = "Latent"
hr_second_pass_steps: Optional[int] = 0
hr_resize_x: Optional[int] = 0
hr_resize_y: Optional[int] = 0
class Img2ImgApiTaskArgs(BaseApiTaskArgs):
init_images: List[str]
mask: Optional[str] = None
resize_mode: Optional[int] = 0
denoising_strength: Optional[int] = 0.75
mask_blur: Optional[int] = 4
inpainting_fill: Optional[int] = 0
inpaint_full_res: Optional[bool] = True
inpaint_full_res_padding: Optional[int] = 0
inpainting_mask_invert: Optional[int] = 0
initial_noise_multiplier: Optional[float] = 0.0
def load_image_from_url(url: str):
try:
response = requests.get(url)
buffer = io.BytesIO(response.content)
return Image.open(buffer)
except Exception as e:
log.error(f"[AgentScheduler] Error downloading image from url: {e}")
return None
def load_image(image: str):
if not isinstance(image, str):
return image
pil_image = None
if os.path.exists(image):
pil_image = Image.open(image)
elif image.startswith(("http://", "https://")):
pil_image = load_image_from_url(image)
return pil_image
def load_image_to_nparray(image: str):
pil_image = load_image(image)
return (
np.array(pil_image).astype("uint8")
if isinstance(pil_image, Image.Image)
else None
)
def encode_pil_to_base64(image: Image.Image):
with io.BytesIO() as output_bytes:
image.save(output_bytes, format="PNG")
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data).decode("utf-8")
def load_image_to_base64(image: str):
pil_image = load_image(image)
if not isinstance(pil_image, Image.Image):
return image
return encode_pil_to_base64(pil_image)
def __serialize_image(image):
if isinstance(image, np.ndarray):
shape = image.shape
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
return {"shape": shape, "data": data, "cls": "ndarray"}
elif isinstance(image, Image.Image):
size = image.size
mode = image.mode
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
return {
"size": size,
"mode": mode,
"data": data,
"cls": "Image",
}
else:
return image
def __deserialize_image(image_str):
if isinstance(image_str, dict) and image_str.get("cls", None):
cls = image_str["cls"]
data = zlib.decompress(base64.b64decode(image_str["data"]))
if cls == "ndarray":
shape = tuple(image_str["shape"])
image = np.frombuffer(data, dtype=np.uint8)
return image.reshape(shape)
else:
size = tuple(image_str["size"])
mode = image_str["mode"]
return Image.frombytes(mode, size, data)
else:
return image_str
def serialize_img2img_image_args(args: dict):
for mode, image_args in img2img_image_args_by_mode.items():
for keys in image_args:
if mode != args["mode"]:
# set None to unused image args to save space
args[keys[0]] = None
elif len(keys) == 1:
image = args.get(keys[0], None)
args[keys[0]] = __serialize_image(image)
else:
value = args.get(keys[0], {})
image = value.get(keys[1], None)
value[keys[1]] = __serialize_image(image)
args[keys[0]] = value
def deserialize_img2img_image_args(args: dict):
for mode, image_args in img2img_image_args_by_mode.items():
if mode != args["mode"]:
continue
for keys in image_args:
if len(keys) == 1:
image = args.get(keys[0], None)
args[keys[0]] = __deserialize_image(image)
else:
value = args.get(keys[0], {})
image = value.get(keys[1], None)
value[keys[1]] = __deserialize_image(image)
args[keys[0]] = value
def serialize_controlnet_args(cnet_unit):
args: dict = cnet_unit.__dict__
args["is_cnet"] = True
for k, v in args.items():
if k == "image" and v is not None:
args[k] = {
"image": __serialize_image(v["image"]),
"mask": __serialize_image(v["mask"])
if v.get("mask", None) is not None
else None,
}
if isinstance(v, Enum):
args[k] = v.value
return args
def deserialize_controlnet_args(args: dict):
# args.pop("is_cnet", None)
for k, v in args.items():
if k == "image" and v is not None:
args[k] = {
"image": __deserialize_image(v["image"]),
"mask": __deserialize_image(v["mask"])
if v.get("mask", None) is not None
else None,
}
return args
def map_ui_task_args_list_to_named_args(
args: list, is_img2img: bool, checkpoint: str = None
):
args_name = []
if is_img2img:
args_name = inspect.getfullargspec(img2img).args
else:
args_name = inspect.getfullargspec(txt2img).args
named_args = dict(zip(args_name, args[0 : len(args_name)]))
script_args = args[len(args_name) :]
if checkpoint is not None:
override_settings_texts = named_args.get("override_settings_texts", [])
override_settings_texts.append("Model hash: " + checkpoint)
named_args["override_settings_texts"] = override_settings_texts
return (
named_args,
script_args,
)
def map_ui_task_args_to_api_task_args(
named_args: dict, script_args: list, is_img2img: bool
):
api_task_args: dict = named_args.copy()
prompt_styles = api_task_args.pop("prompt_styles", [])
api_task_args["styles"] = prompt_styles
sampler_index = api_task_args.pop("sampler_index", 0)
api_task_args["sampler_name"] = sd_samplers.samplers[sampler_index].name
override_settings_texts = api_task_args.pop("override_settings_texts", [])
api_task_args["override_settings"] = create_override_settings_dict(
override_settings_texts
)
if is_img2img:
mode = api_task_args.pop("mode", 0)
for arg_mode, image_args in img2img_image_args_by_mode.items():
if mode != arg_mode:
for keys in image_args:
api_task_args.pop(keys[0], None)
# the logic below is copied from modules/img2img.py
if mode == 0:
image = api_task_args.pop("init_img").convert("RGB")
mask = None
elif mode == 1:
image = api_task_args.pop("sketch").convert("RGB")
mask = None
elif mode == 2:
init_img_with_mask: dict = api_task_args.pop("init_img_with_mask")
image = init_img_with_mask.get("image").convert("RGB")
mask = init_img_with_mask.get("mask")
alpha_mask = (
ImageOps.invert(image.split()[-1])
.convert("L")
.point(lambda x: 255 if x > 0 else 0, mode="1")
)
mask = ImageChops.lighter(alpha_mask, mask.convert("L")).convert("L")
elif mode == 3:
image = api_task_args.pop("inpaint_color_sketch")
orig = api_task_args.pop("inpaint_color_sketch_orig") or image
mask_alpha = api_task_args.pop("mask_alpha", 0)
mask_blur = api_task_args.get("mask_blur", 4)
pred = np.any(np.array(image) != np.array(orig), axis=-1)
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
blur = ImageFilter.GaussianBlur(mask_blur)
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
image = image.convert("RGB")
elif mode == 4:
image = api_task_args.pop("init_img_inpaint")
mask = api_task_args.pop("init_mask_inpaint")
else:
raise Exception(f"Batch mode is not supported yet")
image = ImageOps.exif_transpose(image)
api_task_args["init_images"] = [encode_pil_to_base64(image)]
api_task_args["mask"] = encode_pil_to_base64(mask) if mask is not None else None
selected_scale_tab = api_task_args.pop("selected_scale_tab", 0)
scale_by = api_task_args.pop("scale_by", 1)
if selected_scale_tab == 1:
api_task_args["width"] = int(image.width * scale_by)
api_task_args["height"] = int(image.height * scale_by)
else:
hr_sampler_index = api_task_args.pop("hr_sampler_index", 0)
api_task_args["hr_sampler_name"] = (
sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name
if hr_sampler_index != 0
else None
)
# script
script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
script_id = script_args[0]
if script_id == 0:
api_task_args["script_name"] = None
api_task_args["script_args"] = []
else:
script = script_runner.selectable_scripts[script_id - 1]
api_task_args["script_name"] = script.title.lower()
api_task_args["script_args"] = script_args[script.args_from : script.args_to]
# alwayson scripts
alwayson_scripts = api_task_args.get("alwayson_scripts", None)
if not alwayson_scripts or not isinstance(alwayson_scripts, dict):
alwayson_scripts = {}
api_task_args["alwayson_scripts"] = alwayson_scripts
for script in script_runner.alwayson_scripts:
alwayson_script_args = script_args[script.args_from : script.args_to]
if script.title.lower() == "controlnet":
for i, cnet_args in enumerate(alwayson_script_args):
alwayson_script_args[i] = serialize_controlnet_args(cnet_args)
alwayson_scripts[script.title.lower()] = {"args": alwayson_script_args}
return api_task_args
def serialize_api_task_args(params: dict, is_img2img: bool):
# pop out custom params
model_hash = params.pop("model_hash", None)
controlnet_args = params.pop("controlnet_args", None)
args = (
StableDiffusionImg2ImgProcessingAPI(**params)
if is_img2img
else StableDiffusionTxt2ImgProcessingAPI(**params)
)
if args.override_settings is None:
args.override_settings = {}
if model_hash is None:
model_hash = args.override_settings.get("sd_model_checkpoint", None)
if model_hash is None:
log.error("[AgentScheduler] API task must supply model hash")
return
checkpoint: CheckpointInfo = get_closet_checkpoint_match(model_hash)
if not checkpoint:
log.warn(f"[AgentScheduler] No checkpoint found for model hash {model_hash}")
return
args.override_settings["sd_model_checkpoint"] = checkpoint.title
# load images from url or file if needed
if is_img2img:
init_images = args.init_images
for i, image in enumerate(init_images):
init_images[i] = load_image_to_base64(image)
args.mask = load_image_to_base64(args.mask)
# handle custom controlnet args
if controlnet_args is not None:
if args.alwayson_scripts is None:
args.alwayson_scripts = {}
controlnets = []
for cnet in controlnet_args:
enabled = cnet.get("enabled", True)
cnet_image = cnet.get("image", None)
if not enabled:
continue
if not isinstance(cnet_image, dict):
log.error(f"[AgentScheduler] Controlnet image is required")
continue
image = cnet_image.get("image", None)
mask = cnet_image.get("mask", None)
if image is None:
log.error(f"[AgentScheduler] Controlnet image is required")
continue
# load controlnet images from url or file if needed
cnet_image["image"] = load_image_to_base64(image)
cnet_image["mask"] = load_image_to_base64(mask)
controlnets.append(cnet)
if len(controlnets) > 0:
args.alwayson_scripts["controlnet"] = {"args": controlnets}
return args.dict()

View File

@ -1,18 +1,13 @@
import io
import sys
import json
import time
import zlib
import base64
import pickle
import inspect
import traceback
import threading
import numpy as np
from pydantic import BaseModel
from datetime import datetime, timedelta
from enum import Enum
from PIL import Image, PngImagePlugin
from typing import Any, Callable, Union
from fastapi import FastAPI
@ -20,22 +15,29 @@ from modules import progress, shared, script_callbacks
from modules.call_queue import queue_lock, wrap_gradio_call
from modules.txt2img import txt2img
from modules.img2img import img2img
from modules.api.api import Api, encode_pil_to_base64
from modules.api.api import Api
from modules.api.models import (
StableDiffusionTxt2ImgProcessingAPI,
StableDiffusionImg2ImgProcessingAPI,
)
from scripts.db import TaskStatus, AppStateKey, Task, task_manager, state_manager
from scripts.db import TaskStatus, Task, task_manager
from scripts.helpers import log, detect_control_net, get_component_by_elem_id
from scripts.task_helpers import (
serialize_img2img_image_args,
deserialize_img2img_image_args,
serialize_controlnet_args,
deserialize_controlnet_args,
map_ui_task_args_list_to_named_args,
)
img2img_image_args_by_mode: dict[int, list[list[str]]] = {
0: [["init_img"]],
1: [["sketch"]],
2: [["init_img_with_mask", "image"], ["init_img_with_mask", "mask"]],
3: [["inpaint_color_sketch"], ["inpaint_color_sketch_orig"]],
4: [["init_img_inpaint"], ["init_mask_inpaint"]],
}
class ParsedTaskArgs(BaseModel):
args: list[Any]
named_args: dict[str, Any]
script_args: list[Any]
checkpoint: str
is_ui: bool
class TaskRunner:
@ -75,103 +77,22 @@ class TaskRunner:
@property
def paused(self) -> bool:
return state_manager.get_value(AppStateKey.QueueState) == "paused"
def __serialize_image(self, image):
if isinstance(image, np.ndarray):
shape = image.shape
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
return {"shape": shape, "data": data, "cls": "ndarray"}
elif isinstance(image, Image.Image):
size = image.size
mode = image.mode
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
return {
"size": size,
"mode": mode,
"data": data,
"cls": "Image",
}
else:
return image
def __deserialize_image(self, image_str):
if isinstance(image_str, dict) and image_str.get("cls", None):
cls = image_str["cls"]
data = zlib.decompress(base64.b64decode(image_str["data"]))
if cls == "ndarray":
shape = tuple(image_str["shape"])
image = np.frombuffer(data, dtype=np.uint8)
return image.reshape(shape)
else:
size = tuple(image_str["size"])
mode = image_str["mode"]
return Image.frombytes(mode, size, data)
else:
return image_str
def __serialize_img2img_images(self, args: dict, image_args: list):
for keys in image_args:
if len(keys) == 1:
image = args.get(keys[0], None)
args[keys[0]] = self.__serialize_image(image)
else:
value = args.get(keys[0], {})
image = value.get(keys[1], None)
value[keys[1]] = self.__serialize_image(image)
args[keys[0]] = value
def __deserialize_img2img_images(self, args: dict, image_args: list):
for keys in image_args:
if len(keys) == 1:
image = args.get(keys[0], None)
args[keys[0]] = self.__deserialize_image(image)
else:
value = args.get(keys[0], {})
image = value.get(keys[1], None)
value[keys[1]] = self.__deserialize_image(image)
args[keys[0]] = value
return shared.opts.queue_paused
def __serialize_ui_task_args(self, is_img2img: bool, *args, checkpoint: str = None):
args_name = []
if is_img2img:
args_name = inspect.getfullargspec(img2img).args
else:
args_name = inspect.getfullargspec(txt2img).args
args = list(args)
named_args = dict(zip(args_name, args[0 : len(args_name)]))
script_args = args[len(args_name) :]
if checkpoint:
override_settings_texts = named_args.get("override_settings_texts", [])
override_settings_texts.append("Model hash: " + checkpoint)
named_args["override_settings_texts"] = override_settings_texts
named_args, script_args = map_ui_task_args_list_to_named_args(
list(args), is_img2img, checkpoint=checkpoint
)
# loop through named_args and serialize images
if is_img2img:
for mode, image_args in img2img_image_args_by_mode.items():
if mode == named_args["mode"]:
self.__serialize_img2img_images(named_args, image_args)
else:
# set None to unused image args to save space
for keys in image_args:
named_args[keys[0]] = None
serialize_img2img_image_args(named_args)
# loop through script_args and serialize controlnets
if self.UiControlNetUnit is not None:
for i, a in enumerate(script_args):
if isinstance(a, self.UiControlNetUnit):
script_args[i] = a.__dict__
script_args[i]["is_cnet"] = True
for k, v in script_args[i].items():
if k == "image" and v is not None:
script_args[i][k] = {
"image": self.__serialize_image(v["image"]),
"mask": self.__serialize_image(v["mask"]),
}
if isinstance(v, Enum):
script_args[i][k] = v.value
script_args[i] = serialize_controlnet_args(a)
return json.dumps(
{
@ -186,6 +107,7 @@ class TaskRunner:
def __serialize_api_task_args(
self, is_img2img: bool, script_args: list = [], **named_args
):
# serialization steps are done in task_helpers.register_api_task
override_settings = named_args.get("override_settings", {})
checkpoint = override_settings.get("sd_model_checkpoint", None)
@ -204,21 +126,17 @@ class TaskRunner:
):
# loop through image_args and deserialize images
if is_img2img:
for mode, image_args in img2img_image_args_by_mode.items():
if mode == named_args["mode"]:
self.__deserialize_img2img_images(named_args, image_args)
deserialize_img2img_image_args(named_args)
# loop through script_args and deserialize controlnets
if self.UiControlNetUnit is not None:
for i, arg in enumerate(script_args):
if isinstance(arg, dict) and arg.get("is_cnet", False):
arg.pop("is_cnet")
for k, v in arg.items():
if k == "image" and v is not None:
arg[k] = {
"image": self.__deserialize_image(v["image"]),
"mask": self.__deserialize_image(v["mask"]),
}
script_args[i] = deserialize_controlnet_args(arg)
def __deserialize_api_task_args(self, is_img2img: bool, named_args: dict):
# API task use base64 images as input, no need to deserialize
pass
def parse_task_args(
self, params: str, script_params: bytes, deserialization: bool = True
@ -237,16 +155,18 @@ class TaskRunner:
if is_ui and deserialization:
self.__deserialize_ui_task_args(is_img2img, named_args, script_args)
elif deserialization:
self.__deserialize_api_task_args(is_img2img, named_args)
args = list(named_args.values()) + script_args
return {
"args": args,
"named_args": named_args,
"script_args": script_args,
"checkpoint": checkpoint,
"is_ui": is_ui,
}
return ParsedTaskArgs(
args=args,
named_args=named_args,
script_args=script_args,
checkpoint=checkpoint,
is_ui=is_ui,
)
def register_ui_task(
self, task_id: str, is_img2img: bool, *args, checkpoint: str = None
@ -263,13 +183,19 @@ class TaskRunner:
)
self.__total_pending_tasks += 1
def register_api_task(self, task_id: str, is_img2img: bool, args: dict):
def register_api_task(
self, task_id: str, api_task_id: str, is_img2img: bool, args: dict
):
progress.add_task_to_queue(task_id)
args = args.copy()
args.update({"save_images": True, "send_images": False})
params = self.__serialize_api_task_args(is_img2img, **args)
task_type = "img2img" if is_img2img else "txt2img"
task_manager.add_task(Task(id=task_id, type=task_type, params=params))
task_manager.add_task(
Task(id=task_id, api_task_id=api_task_id, type=task_type, params=params)
)
self.__run_callbacks(
"task_registered", task_id, is_img2img=is_img2img, is_ui=False, args=params
@ -279,7 +205,11 @@ class TaskRunner:
def execute_task(self, task: Task, get_next_task: Callable):
while True:
if self.dispose:
sys.exit(0)
break
if self.paused:
log.info("[AgentScheduler] Runner is paused")
break
if progress.current_task is None:
task_id = task.id
@ -290,7 +220,11 @@ class TaskRunner:
task.params,
task.script_params,
)
task_meta = {"is_img2img": is_img2img, "is_ui": task_args["is_ui"]}
task_meta = {
"is_img2img": is_img2img,
"is_ui": task_args.is_ui,
"api_task_id": task.api_task_id,
}
self.__saved_images_path = []
self.__run_callbacks("task_started", task_id, **task_meta)
@ -333,7 +267,7 @@ class TaskRunner:
task = get_next_task()
if not task:
sys.exit(0)
break
def execute_pending_tasks_threading(self):
if self.paused:
@ -366,19 +300,19 @@ class TaskRunner:
return [
task.id,
task.type,
json.dumps(task_args["named_args"]),
json.dumps(task_args.named_args),
task.created_at.strftime("%Y-%m-%d %H:%M:%S"),
]
def __execute_task(self, task_id: str, is_img2img: bool, task_args: dict):
if task_args["is_ui"]:
return self.__execute_ui_task(task_id, is_img2img, *task_args["args"])
def __execute_task(self, task_id: str, is_img2img: bool, task_args: ParsedTaskArgs):
if task_args.is_ui:
return self.__execute_ui_task(task_id, is_img2img, *task_args.args)
else:
return self.__execute_api_task(
task_id,
is_img2img,
script_args=task_args["script_args"],
**task_args["named_args"],
script_args=task_args.script_args,
**task_args.named_args,
)
def __execute_ui_task(self, task_id: str, is_img2img: bool, *args):
@ -483,9 +417,12 @@ class TaskRunner:
def get_instance(block) -> TaskRunner:
if TaskRunner.instance is None:
txt2img_submit_button = get_component_by_elem_id(block, "txt2img_generate")
UiControlNetUnit = detect_control_net(block, txt2img_submit_button)
TaskRunner(UiControlNetUnit)
if block is not None:
txt2img_submit_button = get_component_by_elem_id(block, "txt2img_generate")
UiControlNetUnit = detect_control_net(block, txt2img_submit_button)
TaskRunner(UiControlNetUnit)
else:
TaskRunner()
def on_before_reload():
# Tell old instance to stop

View File

@ -6,7 +6,7 @@ from modules.ui import create_refresh_button
from scripts.task_runner import TaskRunner, get_instance
from scripts.helpers import compare_components_with_ids, get_components_by_ids
from scripts.db import init, state_manager, AppStateKey
from scripts.db import init
task_runner: TaskRunner = None
initialized = False
@ -14,6 +14,10 @@ initialized = False
checkpoint_current = "Current Checkpoint"
checkpoint_runtime = "Runtime Checkpoint"
placement_above_generate = "Above Generate button"
placement_under_generate = "Under Generate button"
placement_between_prompt_and_generate = "Between Prompt and Generate button"
class Script(scripts.Script):
def __init__(self):
@ -77,10 +81,11 @@ class Script(scripts.Script):
outputs=fn_block.outputs,
show_progress=False,
)
id_part = "img2img" if is_img2img else "txt2img"
with root:
with generate.parent:
id_part = "img2img" if is_img2img else "txt2img"
with gr.Row(elem_id=f"{id_part}_enqueue_wrapper"):
with gr.Row(elem_id=f"{id_part}_enqueue_wrapper") as row:
if not shared.opts.queue_button_hide_checkpoint:
checkpoint = gr.Dropdown(
choices=get_checkpoint_choices(),
value=checkpoint_current,
@ -93,13 +98,36 @@ class Script(scripts.Script):
lambda: {"choices": get_checkpoint_choices()},
f"refresh_{id_part}_checkpoint",
)
submit = gr.Button(
"Enqueue", elem_id=f"{id_part}_enqueue", variant="primary"
checkpoint.change(
fn=self.on_checkpoint_changed, inputs=[checkpoint]
)
checkpoint.change(fn=self.on_checkpoint_changed, inputs=[checkpoint])
submit = gr.Button(
"Enqueue", elem_id=f"{id_part}_enqueue", variant="primary"
)
submit.click(**args)
# relocation the enqueue button
root.children.pop()
if shared.opts.queue_button_placement == placement_between_prompt_and_generate:
if is_img2img:
# add to the iterrogate div
parent = generate.parent.parent.parent.children[1]
parent.add(row)
else:
# insert after the prompts
parent = generate.parent.parent.parent
row.parent = parent
parent.children.insert(1, row)
elif shared.opts.queue_button_placement == placement_under_generate:
# insert after the tools div
parent = generate.parent.parent
parent.children.insert(2, row)
else:
# insert after before the generate button
parent = generate.parent.parent
parent.children.insert(0, row)
if cnet_dependency is not None:
cnet_fn_block = next(
fn
@ -147,10 +175,6 @@ def get_checkpoint_choices():
return choices
def is_queue_paused():
return state_manager.get_value(AppStateKey.QueueState) == "paused"
def on_ui_tab(**_kwargs):
global initialized
if not initialized:
@ -158,18 +182,25 @@ def on_ui_tab(**_kwargs):
init()
with gr.Blocks(analytics_enabled=False) as scheduler_tab:
gr.Textbox(
shared.opts.queue_button_placement,
elem_id="agent_scheduler_queue_button_placement",
show_label=False,
visible=False,
interactive=False,
)
with gr.Row(elem_id="agent_scheduler_pending_tasks_wrapper"):
with gr.Column(scale=1):
with gr.Group(elem_id="agent_scheduler_actions"):
paused = is_queue_paused()
paused = shared.opts.queue_paused
pause = gr.Button(
gr.Button(
"Pause",
elem_id="agent_scheduler_action_pause",
variant="stop",
visible=not paused,
)
resume = gr.Button(
gr.Button(
"Resume",
elem_id="agent_scheduler_action_resume",
variant="primary",
@ -195,11 +226,53 @@ def on_ui_tab(**_kwargs):
return [(scheduler_tab, "Agent Scheduler", "agent_scheduler")]
def on_ui_settings():
section = ("agent_scheduler", "Agent Scheduler")
shared.opts.add_option(
"queue_paused",
shared.OptionInfo(
False,
"Disable queue auto processing",
gr.Checkbox,
{"interactive": True},
section=section,
),
)
shared.opts.add_option(
"queue_button_placement",
shared.OptionInfo(
placement_above_generate,
"Queue button placement",
gr.Radio,
lambda: {
"choices": [
placement_above_generate,
placement_under_generate,
placement_between_prompt_and_generate,
]
},
section=section,
),
)
shared.opts.add_option(
"queue_button_hide_checkpoint",
shared.OptionInfo(
True,
"Hide the checkpoint dropdown",
gr.Checkbox,
{},
section=section,
),
)
def on_app_started(block, _):
global task_runner
task_runner = get_instance(block)
task_runner.execute_pending_tasks_threading()
if block is not None:
global task_runner
task_runner = get_instance(block)
task_runner.execute_pending_tasks_threading()
script_callbacks.on_ui_tabs(on_ui_tab)
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_app_started(on_app_started)

431
style.css

File diff suppressed because one or more lines are too long

14
ui/.eslintrc.cjs Normal file
View File

@ -0,0 +1,14 @@
module.exports = {
env: { browser: true, es2020: true },
extends: [
'eslint:recommended',
'plugin:@typescript-eslint/recommended',
'plugin:react-hooks/recommended',
],
parser: '@typescript-eslint/parser',
parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
plugins: ['react-refresh'],
rules: {
'react-refresh/only-export-components': 'warn',
},
}

24
ui/.gitignore vendored Normal file
View File

@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

35
ui/package.json Normal file
View File

@ -0,0 +1,35 @@
{
"name": "ui",
"private": true,
"version": "0.0.0",
"scripts": {
"dev": "vite",
"build": "tsc && yarn build:extension",
"build:extension": "vite build --config vite.extension.ts",
"lint": "eslint src --ext ts,tsx --report-unused-disable-directives --max-warnings 0",
"preview": "vite preview"
},
"dependencies": {
"ag-grid-community": "^29.3.5",
"notyf": "^3.10.0",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"rxjs": "^7.8.1"
},
"devDependencies": {
"@types/node": "^18",
"@types/react": "^18.0.37",
"@types/react-dom": "^18.0.11",
"@typescript-eslint/eslint-plugin": "^5.59.0",
"@typescript-eslint/parser": "^5.59.0",
"@vitejs/plugin-react": "^4.0.0",
"autoprefixer": "^10.4.14",
"eslint": "^8.38.0",
"eslint-plugin-react-hooks": "^4.6.0",
"eslint-plugin-react-refresh": "^0.3.4",
"postcss": "^8.4.24",
"tailwindcss": "^3.3.2",
"typescript": "^5.0.2",
"vite": "^4.3.9"
}
}

6
ui/postcss.config.js Normal file
View File

@ -0,0 +1,6 @@
module.exports = {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
}

160
ui/src/extension/index.css Normal file
View File

@ -0,0 +1,160 @@
@tailwind components;
@tailwind utilities;
@layer components {
.ts-search {
@apply relative w-full max-w-xs ml-auto;
}
.ts-search-input {
@apply bg-gray-50 border border-gray-300 text-gray-900 text-sm !rounded-md focus:ring-blue-500 focus:border-blue-500 block !pl-10 !p-2 w-full dark:!bg-gray-700 dark:!border-gray-600 dark:!placeholder-gray-400 dark:!text-white dark:focus:ring-blue-500 dark:focus:border-blue-500;
}
.ts-search-icon {
@apply absolute inset-y-0 left-0 flex items-center dark:text-white pl-3 pointer-events-none;
}
.ts-btn-action {
@apply inline-flex items-center !px-2 !py-1 !m-0 text-sm font-medium border focus:z-10 focus:ring-2 disabled:opacity-50 disabled:hover:!bg-transparent disabled:cursor-not-allowed;
}
.ts-btn-run {
@apply !text-green-500 hover:!text-white border-green-500 hover:bg-green-600 rounded-l-md focus:ring-green-400 dark:border-green-500 dark:hover:bg-green-600 dark:focus:ring-green-900 disabled:hover:!text-green-500;
}
.ts-btn-delete {
@apply !text-red-500 hover:!text-white border-red-600 hover:bg-red-600 rounded-r-md focus:ring-red-300 dark:border-red-500 dark:hover:bg-red-600 dark:focus:ring-red-900;
}
@keyframes blink {
from,
to {
opacity: 0;
}
50% {
opacity: 1;
}
}
.task-running {
@apply !text-green-500;
animation: 1s blink ease infinite;
}
}
/* ========================================================================= */
#agent_scheduler_pending_tasks_wrapper {
gap: var(--layout-gap);
border: none;
box-shadow: none;
border-width: 0;
}
@media (max-width: 1024px) {
#agent_scheduler_pending_tasks_wrapper {
flex-wrap: wrap;
}
}
#agent_scheduler_pending_tasks_wrapper > div:last-child {
width: 100%;
max-width: 400px;
}
@media (min-width: 1920px) {
#agent_scheduler_pending_tasks_wrapper > div:last-child {
max-width: 512px;
}
}
#agent_scheduler_current_task_images {
width: 100%;
padding-top: 100%;
position: relative;
}
#agent_scheduler_current_task_images > div {
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
height: 100%;
}
#agent_scheduler_pending_tasks_wrapper {
justify-content: flex-end;
gap: var(--layout-gap);
padding: 0 var(--layout-gap) var(--layout-gap) var(--layout-gap);
}
#agent_scheduler_pending_tasks_wrapper > button {
flex: 0 0 auto;
}
#agent_scheduler_actions {
display: flex;
gap: var(--layout-gap);
}
#agent_scheduler_actions > button {
border-radius: var(--radius-lg) !important;
}
.ag-theme-alpine,
.ag-theme-alpine-dark {
--ag-row-height: 45px;
--ag-header-height: 45px;
/* --ag-grid-size: 6px; */
--ag-cell-horizontal-padding: calc(var(--ag-grid-size) * 2);
--body-text-color: 'inherit';
}
.cell-span {
border-bottom-color: var(--ag-border-color);
}
.cell-not-span {
opacity: 0;
}
.ag-row-hover .ag-cell {
background-color: transparent;
}
#txt2img_enqueue_wrapper,
#img2img_enqueue_wrapper {
min-width: 210px;
display: flex;
flex-direction: column;
gap: calc(var(--layout-gap) / 2);
}
#txt2img_enqueue_wrapper > div:first-child,
#img2img_enqueue_wrapper > div:first-child {
flex-direction: row;
flex-wrap: nowrap;
align-items: flex-start;
flex: 0 0 auto;
}
#txt2img_enqueue_wrapper .gradio-button,
#img2img_enqueue_wrapper .gradio-button,
#txt2img_enqueue_wrapper .gradio-dropdown .wrap-inner,
#img2img_enqueue_wrapper .gradio-dropdown .wrap-inner {
min-height: 36px;
max-height: 42px;
}
#img2img_toprow .interrogate-col.has-queue-button {
min-width: unset !important;
flex-direction: row !important;
gap: calc(var(--layout-gap) / 2) !important;
}
#img2img_toprow .interrogate-col.has-queue-button button {
margin: 0;
}
#agent_scheduler_current_task_progress .livePreview {
margin: 0;
}

470
ui/src/extension/index.ts Normal file
View File

@ -0,0 +1,470 @@
import * as rxjs from 'rxjs';
import type { Observer } from 'rxjs';
import { Grid, GridOptions } from 'ag-grid-community';
import { Notyf } from 'notyf';
import 'ag-grid-community/styles/ag-grid.css';
import 'ag-grid-community/styles/ag-theme-alpine.css';
import 'notyf/notyf.min.css';
import './index.css';
declare global {
var country: string;
function gradioApp(): HTMLElement;
function randomId(): string;
function get_tab_index(name: string): number;
function create_submit_args(args: IArguments): any[];
function requestProgress(
id: string,
progressContainer: HTMLElement,
imagesContainer: HTMLElement,
onDone: () => void,
): void;
function onUiLoaded(callback: () => void): void;
function submit_enqueue(): any[];
function submit_enqueue_img2img(): any[];
}
type Task = {
id: string;
api_task_id: string;
type: string;
status: string;
params: Record<string, any>;
priority: number;
result: string;
};
type AppState = {
current_task_id: string | null;
total_pending_tasks: number;
pending_tasks: Task[];
paused: boolean;
};
function initTaskScheduler() {
const notyf = new Notyf();
const subject = new rxjs.Subject<AppState>();
const store = {
subject,
subscribe: (callback: Partial<Observer<[AppState, AppState]>>) => {
return store.subject.pipe(rxjs.pairwise()).subscribe(callback);
},
refresh: async () => {
return fetch('/agent-scheduler/v1/queue?limit=1000')
.then((response) => response.json())
.then((data: AppState) => {
const pending_tasks = data.pending_tasks.map((item) => ({
...item,
params: JSON.parse(item.params as any),
status: item.id === data.current_task_id ? 'running' : 'pending',
}));
store.subject.next({
...data,
pending_tasks,
});
});
},
pauseQueue: async () => {
return fetch('/agent-scheduler/v1/pause', { method: 'POST' })
.then((response) => response.json())
.then((data) => {
if (data.success) {
notyf.success(data.message);
} else {
notyf.error(data.message);
}
return store.refresh();
});
},
resumeQueue: async () => {
return fetch('/agent-scheduler/v1/resume', { method: 'POST' })
.then((response) => response.json())
.then((data) => {
if (data.success) {
notyf.success(data.message);
} else {
notyf.error(data.message);
}
return store.refresh();
});
},
runTask: async (id: string) => {
return fetch(`/agent-scheduler/v1/run/${id}`, { method: 'POST' })
.then((response) => response.json())
.then((data) => {
if (data.success) {
notyf.success(data.message);
} else {
notyf.error(data.message);
}
return store.refresh();
});
},
deleteTask: async (id: string) => {
return fetch(`/agent-scheduler/v1/delete/${id}`, { method: 'POST' })
.then((response) => response.json())
.then((data) => {
if (data.success) {
notyf.success(data.message);
} else {
notyf.error(data.message);
}
return store.refresh();
});
},
moveTask: async (id: string, overId: string) => {
return fetch(`/agent-scheduler/v1/move/${id}/${overId}`, { method: 'POST' })
.then((response) => response.json())
.then((data) => {
if (data.success) {
notyf.success(data.message);
} else {
notyf.error(data.message);
}
return store.refresh();
});
},
};
store.subject.next({
current_task_id: null,
total_pending_tasks: 0,
pending_tasks: [],
paused: false,
});
window.submit_enqueue = function submit_enqueue() {
var id = randomId();
var res = create_submit_args(arguments);
res[0] = id;
const btnEnqueue = document.querySelector('#txt2img_enqueue');
if (btnEnqueue) {
btnEnqueue.innerHTML = 'Queued';
setTimeout(() => {
btnEnqueue.innerHTML = 'Enqueue';
store.refresh();
}, 1000);
}
return res;
};
window.submit_enqueue_img2img = function submit_enqueue_img2img() {
var id = randomId();
var res = create_submit_args(arguments);
res[0] = id;
res[1] = get_tab_index('mode_img2img');
const btnEnqueue = document.querySelector('#img2img_enqueue');
if (btnEnqueue) {
btnEnqueue.innerHTML = 'Queued';
setTimeout(() => {
btnEnqueue.innerHTML = 'Enqueue';
store.refresh();
}, 1000);
}
return res;
};
// detect queue button placement
const interrogateCol: HTMLDivElement = gradioApp().querySelector('.interrogate-col')!;
if (interrogateCol.childElementCount > 2) {
interrogateCol.classList.add('has-queue-button');
}
// watch for tab activation
const observer = new MutationObserver(function (mutationsList) {
const styleChange = mutationsList.find((mutation) => mutation.attributeName === 'style');
if (styleChange) {
const tab = styleChange.target as HTMLElement;
if (tab.style.display === 'block') {
store.refresh();
}
}
});
observer.observe(document.getElementById('tab_agent_scheduler')!, { attributes: true });
// init actions
const refreshButton = gradioApp().querySelector('#agent_scheduler_action_refresh')!;
const pauseButton = gradioApp().querySelector('#agent_scheduler_action_pause')!;
const resumeButton = gradioApp().querySelector('#agent_scheduler_action_resume')!;
refreshButton.addEventListener('click', store.refresh);
pauseButton.addEventListener('click', store.pauseQueue);
resumeButton.addEventListener('click', store.resumeQueue);
const searchContainer = gradioApp().querySelector('#agent_scheduler_action_search')!;
searchContainer.className = 'ts-search';
searchContainer.innerHTML = `
<div class="ts-search-icon">
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<path d="M10 10m-7 0a7 7 0 1 0 14 0a7 7 0 1 0 -14 0"/>
<path d="M21 21l-6 -6"/>
</svg>
</div>
<input type="text" id="agent_scheduler_search_input" class="ts-search-input" placeholder="Search" required>
`;
// watch for current task id change
const onTaskIdChange = (id: string | null) => {
if (id) {
requestProgress(
id,
gradioApp().querySelector('#agent_scheduler_current_task_progress')!,
gradioApp().querySelector('#agent_scheduler_current_task_images')!,
() => {
setTimeout(() => {
store.refresh();
}, 1000);
},
);
}
};
store.subscribe({
next: ([prev, curr]) => {
if (prev.current_task_id !== curr.current_task_id) {
onTaskIdChange(curr.current_task_id);
}
if (curr.paused) {
pauseButton.classList.add('hide');
resumeButton.classList.remove('hide');
} else {
pauseButton.classList.remove('hide');
resumeButton.classList.add('hide');
}
},
});
// init grid
const deleteIcon = `
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<path d="M4 7l16 0"/>
<path d="M10 11l0 6"/>
<path d="M14 11l0 6"/>
<path d="M5 7l1 12a2 2 0 0 0 2 2h8a2 2 0 0 0 2 -2l1 -12"/>
<path d="M9 7v-3a1 1 0 0 1 1 -1h4a1 1 0 0 1 1 1v3"/>
</svg>`;
const cancelIcon = `
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<path d="M18 6l-12 12"/>
<path d="M6 6l12 12"/>
</svg>
`;
const pendingTasksGridOptions: GridOptions<Task> = {
// domLayout: 'autoHeight',
// default col def properties get applied to all columns
defaultColDef: {
sortable: false,
filter: true,
resizable: true,
suppressMenu: true,
},
// each entry here represents one column
columnDefs: [
{
field: 'id',
headerName: 'Task Id',
minWidth: 240,
maxWidth: 240,
pinned: 'left',
rowDrag: true,
cellClass: ({ data }) => [
data?.status === 'running' ? 'task-running' : '',
],
},
{
field: 'type',
headerName: 'Type',
minWidth: 80,
maxWidth: 80,
},
{
field: 'priority',
headerName: 'Priority',
hide: true,
},
{
headerName: 'Params',
children: [
{
field: 'params.prompt',
headerName: 'Prompt',
minWidth: 400,
autoHeight: true,
wrapText: true,
cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' },
},
{
field: 'params.negative_prompt',
headerName: 'Negative Prompt',
minWidth: 400,
autoHeight: true,
wrapText: true,
cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' },
},
{
field: 'params.checkpoint',
headerName: 'Checkpoint',
minWidth: 150,
maxWidth: 300,
valueFormatter: ({ value }) => value || 'System',
},
{
field: 'params.sampler_name',
headerName: 'Sampler',
width: 150,
minWidth: 150,
},
{
field: 'params.steps',
headerName: 'Steps',
minWidth: 80,
maxWidth: 80,
filter: 'agNumberColumnFilter',
},
{
field: 'params.cfg_scale',
headerName: 'CFG Scale',
width: 100,
minWidth: 100,
filter: 'agNumberColumnFilter',
},
{
field: 'params.size',
headerName: 'Size',
minWidth: 110,
maxWidth: 110,
valueGetter: ({ data }) => (data ? `${data.params.width}x${data.params.height}` : ''),
},
{
field: 'params.batch',
headerName: 'Batching',
minWidth: 100,
maxWidth: 100,
valueGetter: ({ data }) =>
data ? `${data.params.n_iter}x${data.params.batch_size}` : '1x1',
},
],
},
{ field: 'created_at', headerName: 'Date', minWidth: 200 },
{
headerName: 'Action',
pinned: 'right',
minWidth: 110,
maxWidth: 110,
resizable: false,
valueGetter: ({ data }) => data?.id,
cellRenderer: ({ api, value, data }: any) => {
const html = `
<div class="inline-flex rounded-md shadow-sm mt-1.5" role="group">
<button type="button" ${
data.status === 'running' ? 'disabled' : ''
} class="ts-btn-action ts-btn-run">
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<path d="M7 4v16l13 -8z"/>
</svg>
</button>
<button type="button" class="ts-btn-action ts-btn-delete">
${data.status === 'pending' ? deleteIcon : cancelIcon}
</button>
</div>
`;
const placeholder = document.createElement('div');
placeholder.innerHTML = html;
const node = placeholder.firstElementChild!;
const btnRun = node.querySelector('button.ts-btn-run')!;
btnRun.addEventListener('click', () => {
console.log('run', value);
api.showLoadingOverlay();
store.runTask(value).then(() => api.hideOverlay());
});
const btnDelete = node.querySelector('button.ts-btn-delete')!;
btnDelete.addEventListener('click', () => {
console.log('delete', value);
api.showLoadingOverlay();
store.deleteTask(value).then(() => api.hideOverlay());
});
return node;
},
},
],
getRowId: ({ data }) => data.id,
rowData: [],
rowSelection: 'single', // allow rows to be selected
animateRows: true, // have rows animate to new positions when sorted
pagination: true,
paginationPageSize: 10,
suppressCopyRowsToClipboard: true,
suppressRowTransform: true,
suppressRowClickSelection: true,
enableBrowserTooltips: true,
onGridReady: ({ api }) => {
// init quick search input
const searchInput: HTMLInputElement = searchContainer.querySelector(
'input#agent_scheduler_search_input',
)!;
rxjs
.fromEvent(searchInput, 'keyup')
.pipe(rxjs.debounce(() => rxjs.interval(200)))
.subscribe((e) => {
api.setQuickFilter((e.target as HTMLInputElement).value);
});
store.subscribe({
next: ([_, newState]) => {
api.setRowData(newState.pending_tasks);
if (newState.current_task_id) {
const node = api.getRowNode(newState.current_task_id);
if (node) {
api.refreshCells({ rowNodes: [node], force: true });
}
}
api.sizeColumnsToFit();
},
});
// refresh the state
store.refresh();
},
onRowDragEnd: ({ api, node, overNode }) => {
const id = node.data?.id;
const overId = overNode?.data?.id;
if (id && overId && id !== overId) {
api.showLoadingOverlay();
store.moveTask(id, overId).then(() => api.hideOverlay());
}
},
};
const eGridDiv = gradioApp().querySelector<HTMLDivElement>(
'#agent_scheduler_pending_tasks_grid',
)!;
if (document.querySelector('.dark')) {
eGridDiv.className = 'ag-theme-alpine-dark';
}
eGridDiv.style.height = window.innerHeight - 240 + 'px';
new Grid(eGridDiv, pendingTasksGridOptions);
store.refresh();
}
onUiLoaded(initTaskScheduler);

1
ui/src/vite-env.d.ts vendored Normal file
View File

@ -0,0 +1 @@
/// <reference types="vite/client" />

8
ui/tailwind.config.js Normal file
View File

@ -0,0 +1,8 @@
module.exports = {
content: ["./src/**/*.{js,ts,tsx}"],
darkMode: 'class',
theme: {
extend: {},
},
plugins: [],
}

25
ui/tsconfig.json Normal file
View File

@ -0,0 +1,25 @@
{
"compilerOptions": {
"target": "ES2020",
"useDefineForClassFields": true,
"lib": ["ES2020", "DOM", "DOM.Iterable"],
"module": "ESNext",
"skipLibCheck": true,
/* Bundler mode */
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"resolveJsonModule": true,
"isolatedModules": true,
"noEmit": true,
"jsx": "react-jsx",
/* Linting */
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"noFallthroughCasesInSwitch": true
},
"include": ["src"],
"references": [{ "path": "./tsconfig.node.json" }]
}

10
ui/tsconfig.node.json Normal file
View File

@ -0,0 +1,10 @@
{
"compilerOptions": {
"composite": true,
"skipLibCheck": true,
"module": "ESNext",
"moduleResolution": "bundler",
"allowSyntheticDefaultImports": true
},
"include": ["vite.config.ts"]
}

21
ui/vite.config.ts Normal file
View File

@ -0,0 +1,21 @@
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';
// https://vitejs.dev/config/
export default defineConfig({
plugins: [react()],
build: {
outDir: '../',
rollupOptions: {
input: {
main: 'index.html',
},
},
lib: {
name: 'agent-scheduler',
entry: 'src/extension/index.ts',
fileName: 'javascript/extension',
formats: ['es']
},
},
});

15
ui/vite.extension.ts Normal file
View File

@ -0,0 +1,15 @@
import { defineConfig } from 'vite';
// https://vitejs.dev/config/
export default defineConfig({
build: {
outDir: '../',
copyPublicDir: false,
lib: {
name: 'agent-scheduler',
entry: 'src/extension/index.ts',
fileName: 'javascript/extension',
formats: ['es']
},
},
});

1904
ui/yarn.lock Normal file

File diff suppressed because it is too large Load Diff