chore: feature updates
- add a flag to enable/disable queue auto processing - add queue button placement setting - add a flag to hide the custom checkpoint select - rewrite frontend code in typescript - extract serialization logic to task_helpers - bugs fixingpull/7/head
parent
21d150a0bd
commit
4f7339468c
|
|
@ -37,4 +37,3 @@ notification.mp3
|
|||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
tailwind.*
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -5,7 +5,7 @@ from gradio.routes import App
|
|||
import modules.shared as shared
|
||||
from modules import progress, script_callbacks, sd_samplers
|
||||
|
||||
from scripts.db import TaskStatus, AppStateKey, task_manager, state_manager
|
||||
from scripts.db import TaskStatus, task_manager
|
||||
from scripts.models import QueueStatusResponse
|
||||
from scripts.task_runner import TaskRunner, get_instance
|
||||
from scripts.helpers import log
|
||||
|
|
@ -29,8 +29,8 @@ def regsiter_apis(app: App):
|
|||
task_args = TaskRunner.instance.parse_task_args(
|
||||
task.params, task.script_params, deserialization=False
|
||||
)
|
||||
named_args = task_args["named_args"]
|
||||
named_args["checkpoint"] = task_args["checkpoint"]
|
||||
named_args = task_args.named_args
|
||||
named_args["checkpoint"] = task_args.checkpoint
|
||||
sampler_index = named_args.get("sampler_index", None)
|
||||
if sampler_index is not None:
|
||||
named_args["sampler_name"] = sd_samplers.samplers[
|
||||
|
|
@ -103,21 +103,24 @@ def regsiter_apis(app: App):
|
|||
|
||||
@app.post("/agent-scheduler/v1/pause")
|
||||
def pause_queue():
|
||||
state_manager.set_value(AppStateKey.QueueState, "paused")
|
||||
# state_manager.set_value(AppStateKey.QueueState, "paused")
|
||||
shared.opts.queue_paused = True
|
||||
return {"success": True, "message": f"Queue is paused"}
|
||||
|
||||
@app.post("/agent-scheduler/v1/resume")
|
||||
def resume_queue():
|
||||
state_manager.set_value(AppStateKey.QueueState, "running")
|
||||
# state_manager.set_value(AppStateKey.QueueState, "running")
|
||||
shared.opts.queue_paused = False
|
||||
TaskRunner.instance.execute_pending_tasks_threading()
|
||||
return {"success": True, "message": f"Queue is resumed"}
|
||||
|
||||
|
||||
def on_app_started(block, app: App):
|
||||
global task_runner
|
||||
task_runner = get_instance(block)
|
||||
if block is not None:
|
||||
global task_runner
|
||||
task_runner = get_instance(block)
|
||||
|
||||
regsiter_apis(app)
|
||||
regsiter_apis(app)
|
||||
|
||||
|
||||
script_callbacks.on_app_started(on_app_started)
|
||||
|
|
|
|||
|
|
@ -27,11 +27,15 @@ def init():
|
|||
|
||||
inspector = inspect(engine)
|
||||
with engine.connect() as conn:
|
||||
# check if table task has column result and add it if not
|
||||
task_columns = inspector.get_columns("task")
|
||||
# add result column
|
||||
if not any(col["name"] == "result" for col in task_columns):
|
||||
conn.execute(text("ALTER TABLE task ADD COLUMN result TEXT"))
|
||||
|
||||
# add api_task_id column
|
||||
if not any(col["name"] == "api_task_id" for col in task_columns):
|
||||
conn.execute(text("ALTER TABLE task ADD COLUMN api_task_id VARCHAR(64)"))
|
||||
|
||||
params_column = next(col for col in task_columns if col["name"] == "params")
|
||||
if version > "1" and not isinstance(params_column["type"], Text):
|
||||
transaction = conn.begin()
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ class Task(TaskModel):
|
|||
def __init__(
|
||||
self,
|
||||
id: str = "",
|
||||
api_task_id: str = None,
|
||||
type: str = "unknown",
|
||||
params: str = "",
|
||||
script_params: bytes = b"",
|
||||
|
|
@ -35,6 +36,7 @@ class Task(TaskModel):
|
|||
|
||||
super().__init__(
|
||||
id=id,
|
||||
api_task_id=api_task_id,
|
||||
type=type,
|
||||
params=params,
|
||||
status=status,
|
||||
|
|
@ -44,6 +46,7 @@ class Task(TaskModel):
|
|||
updated_at=created_at,
|
||||
)
|
||||
self.id: str = id
|
||||
self.api_task_id: str = api_task_id
|
||||
self.type: str = type
|
||||
self.params: str = params
|
||||
self.script_params: bytes = script_params
|
||||
|
|
@ -60,6 +63,7 @@ class Task(TaskModel):
|
|||
def from_table(table: "TaskTable"):
|
||||
return Task(
|
||||
id=table.id,
|
||||
api_task_id=table.api_task_id,
|
||||
type=table.type,
|
||||
params=table.params,
|
||||
script_params=table.script_params,
|
||||
|
|
@ -72,6 +76,7 @@ class Task(TaskModel):
|
|||
def to_table(self):
|
||||
return TaskTable(
|
||||
id=self.id,
|
||||
api_task_id=self.api_task_id,
|
||||
type=self.type,
|
||||
params=self.params,
|
||||
script_params=self.script_params,
|
||||
|
|
@ -84,6 +89,7 @@ class TaskTable(Base):
|
|||
__tablename__ = "task"
|
||||
|
||||
id = Column(String(64), primary_key=True)
|
||||
api_task_id = Column(String(64), nullable=True)
|
||||
type = Column(String(20), nullable=False) # txt2img or img2txt
|
||||
params = Column(Text, nullable=False) # task args
|
||||
script_params = Column(LargeBinary, nullable=False) # script args
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class QueueStatusAPI(BaseModel):
|
|||
|
||||
class TaskModel(BaseModel):
|
||||
id: str = Field(title="Task Id")
|
||||
api_task_id: Optional[str] = Field(title="API Task Id", default=None)
|
||||
type: str = Field(title="Task Type", description="Either txt2img or img2img")
|
||||
status: str = Field(title="Task Status", description="Either pending, running, done or failed")
|
||||
params: str = Field(title="Task Parameters", description="The parameters of the task in JSON format")
|
||||
|
|
@ -34,6 +35,7 @@ class TaskModel(BaseModel):
|
|||
datetime: convert_datetime_to_iso_8601_with_z_suffix
|
||||
}
|
||||
|
||||
|
||||
class QueueStatusResponse(BaseModel):
|
||||
current_task_id: Optional[str] = Field(title="Current Task Id", description="The on progress task id")
|
||||
pending_tasks: List[TaskModel] = Field(title="Pending Tasks", description="The pending tasks in the queue")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,441 @@
|
|||
import os
|
||||
import io
|
||||
import zlib
|
||||
import base64
|
||||
import inspect
|
||||
import requests
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
from PIL import Image, ImageOps, ImageChops, ImageEnhance, ImageFilter
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from modules import sd_samplers, scripts
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.sd_models import CheckpointInfo, get_closet_checkpoint_match
|
||||
from modules.txt2img import txt2img
|
||||
from modules.img2img import img2img
|
||||
from modules.api.models import (
|
||||
StableDiffusionTxt2ImgProcessingAPI,
|
||||
StableDiffusionImg2ImgProcessingAPI,
|
||||
)
|
||||
|
||||
from scripts.helpers import log
|
||||
|
||||
img2img_image_args_by_mode: dict[int, list[list[str]]] = {
|
||||
0: [["init_img"]],
|
||||
1: [["sketch"]],
|
||||
2: [["init_img_with_mask", "image"], ["init_img_with_mask", "mask"]],
|
||||
3: [["inpaint_color_sketch"], ["inpaint_color_sketch_orig"]],
|
||||
4: [["init_img_inpaint"], ["init_mask_inpaint"]],
|
||||
}
|
||||
|
||||
|
||||
class ControlNetImage(BaseModel):
|
||||
image: str # base64 or url
|
||||
mask: Optional[str] = None # base64 or url
|
||||
|
||||
|
||||
class ControlNetUnit(BaseModel):
|
||||
enabled: Optional[bool] = True
|
||||
module: Optional[str] = "none"
|
||||
model: Optional[str] = None
|
||||
image: ControlNetImage = None
|
||||
weight: Optional[float] = 1.0
|
||||
resize_mode: Optional[str] = None
|
||||
low_vram: Optional[bool] = False
|
||||
processor_res: Optional[int] = 512
|
||||
threshold_a: Optional[float] = 64
|
||||
threshold_b: Optional[float] = 64
|
||||
guidance_start: Optional[float] = 0.0
|
||||
guidance_end: Optional[float] = 1.0
|
||||
pixel_perfect: Optional[bool] = False
|
||||
control_mode: Optional[str] = "Balanced"
|
||||
|
||||
|
||||
class BaseApiTaskArgs(BaseModel):
|
||||
task_id: str = Field(exclude=True)
|
||||
model_hash: str = Field(exclude=True)
|
||||
prompt: Optional[str] = ""
|
||||
styles: Optional[List[str]] = []
|
||||
negative_prompt: Optional[str] = ""
|
||||
seed: Optional[int] = -1
|
||||
subseed: Optional[int] = 1
|
||||
subseed_strength: Optional[int] = 0
|
||||
seed_resize_from_h: Optional[int] = -1
|
||||
seed_resize_from_w: Optional[int] = -1
|
||||
sampler_name: Optional[str] = "DPM++ 2M Karras"
|
||||
n_iter: Optional[int] = 1
|
||||
batch_size: Optional[int] = 1
|
||||
steps: Optional[int] = 20
|
||||
cfg_scale: Optional[int] = 7.0
|
||||
restore_faces: Optional[bool] = False
|
||||
tiling: Optional[bool] = False
|
||||
width: Optional[int] = 512
|
||||
height: Optional[int] = 512
|
||||
script_name: Optional[str] = None
|
||||
controlnet_args: Optional[List[ControlNetUnit]] = Field(exclude=True, default=[])
|
||||
override_settings: Optional[dict] = Field(default={})
|
||||
|
||||
|
||||
class Txt2ImgApiTaskArgs(BaseApiTaskArgs):
|
||||
enable_hr: Optional[bool] = False
|
||||
denoising_strength: Optional[int] = 0
|
||||
hr_scale: Optional[int] = 1
|
||||
hr_upscaler: Optional[str] = "Latent"
|
||||
hr_second_pass_steps: Optional[int] = 0
|
||||
hr_resize_x: Optional[int] = 0
|
||||
hr_resize_y: Optional[int] = 0
|
||||
|
||||
|
||||
class Img2ImgApiTaskArgs(BaseApiTaskArgs):
|
||||
init_images: List[str]
|
||||
mask: Optional[str] = None
|
||||
resize_mode: Optional[int] = 0
|
||||
denoising_strength: Optional[int] = 0.75
|
||||
mask_blur: Optional[int] = 4
|
||||
inpainting_fill: Optional[int] = 0
|
||||
inpaint_full_res: Optional[bool] = True
|
||||
inpaint_full_res_padding: Optional[int] = 0
|
||||
inpainting_mask_invert: Optional[int] = 0
|
||||
initial_noise_multiplier: Optional[float] = 0.0
|
||||
|
||||
|
||||
def load_image_from_url(url: str):
|
||||
try:
|
||||
response = requests.get(url)
|
||||
buffer = io.BytesIO(response.content)
|
||||
return Image.open(buffer)
|
||||
except Exception as e:
|
||||
log.error(f"[AgentScheduler] Error downloading image from url: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def load_image(image: str):
|
||||
if not isinstance(image, str):
|
||||
return image
|
||||
|
||||
pil_image = None
|
||||
if os.path.exists(image):
|
||||
pil_image = Image.open(image)
|
||||
elif image.startswith(("http://", "https://")):
|
||||
pil_image = load_image_from_url(image)
|
||||
|
||||
return pil_image
|
||||
|
||||
|
||||
def load_image_to_nparray(image: str):
|
||||
pil_image = load_image(image)
|
||||
|
||||
return (
|
||||
np.array(pil_image).astype("uint8")
|
||||
if isinstance(pil_image, Image.Image)
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
def encode_pil_to_base64(image: Image.Image):
|
||||
with io.BytesIO() as output_bytes:
|
||||
image.save(output_bytes, format="PNG")
|
||||
bytes_data = output_bytes.getvalue()
|
||||
return base64.b64encode(bytes_data).decode("utf-8")
|
||||
|
||||
|
||||
def load_image_to_base64(image: str):
|
||||
pil_image = load_image(image)
|
||||
|
||||
if not isinstance(pil_image, Image.Image):
|
||||
return image
|
||||
|
||||
return encode_pil_to_base64(pil_image)
|
||||
|
||||
|
||||
def __serialize_image(image):
|
||||
if isinstance(image, np.ndarray):
|
||||
shape = image.shape
|
||||
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
|
||||
return {"shape": shape, "data": data, "cls": "ndarray"}
|
||||
elif isinstance(image, Image.Image):
|
||||
size = image.size
|
||||
mode = image.mode
|
||||
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
|
||||
return {
|
||||
"size": size,
|
||||
"mode": mode,
|
||||
"data": data,
|
||||
"cls": "Image",
|
||||
}
|
||||
else:
|
||||
return image
|
||||
|
||||
|
||||
def __deserialize_image(image_str):
|
||||
if isinstance(image_str, dict) and image_str.get("cls", None):
|
||||
cls = image_str["cls"]
|
||||
data = zlib.decompress(base64.b64decode(image_str["data"]))
|
||||
|
||||
if cls == "ndarray":
|
||||
shape = tuple(image_str["shape"])
|
||||
image = np.frombuffer(data, dtype=np.uint8)
|
||||
return image.reshape(shape)
|
||||
else:
|
||||
size = tuple(image_str["size"])
|
||||
mode = image_str["mode"]
|
||||
return Image.frombytes(mode, size, data)
|
||||
else:
|
||||
return image_str
|
||||
|
||||
|
||||
def serialize_img2img_image_args(args: dict):
|
||||
for mode, image_args in img2img_image_args_by_mode.items():
|
||||
for keys in image_args:
|
||||
if mode != args["mode"]:
|
||||
# set None to unused image args to save space
|
||||
args[keys[0]] = None
|
||||
elif len(keys) == 1:
|
||||
image = args.get(keys[0], None)
|
||||
args[keys[0]] = __serialize_image(image)
|
||||
else:
|
||||
value = args.get(keys[0], {})
|
||||
image = value.get(keys[1], None)
|
||||
value[keys[1]] = __serialize_image(image)
|
||||
args[keys[0]] = value
|
||||
|
||||
|
||||
def deserialize_img2img_image_args(args: dict):
|
||||
for mode, image_args in img2img_image_args_by_mode.items():
|
||||
if mode != args["mode"]:
|
||||
continue
|
||||
|
||||
for keys in image_args:
|
||||
if len(keys) == 1:
|
||||
image = args.get(keys[0], None)
|
||||
args[keys[0]] = __deserialize_image(image)
|
||||
else:
|
||||
value = args.get(keys[0], {})
|
||||
image = value.get(keys[1], None)
|
||||
value[keys[1]] = __deserialize_image(image)
|
||||
args[keys[0]] = value
|
||||
|
||||
|
||||
def serialize_controlnet_args(cnet_unit):
|
||||
args: dict = cnet_unit.__dict__
|
||||
args["is_cnet"] = True
|
||||
for k, v in args.items():
|
||||
if k == "image" and v is not None:
|
||||
args[k] = {
|
||||
"image": __serialize_image(v["image"]),
|
||||
"mask": __serialize_image(v["mask"])
|
||||
if v.get("mask", None) is not None
|
||||
else None,
|
||||
}
|
||||
if isinstance(v, Enum):
|
||||
args[k] = v.value
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def deserialize_controlnet_args(args: dict):
|
||||
# args.pop("is_cnet", None)
|
||||
for k, v in args.items():
|
||||
if k == "image" and v is not None:
|
||||
args[k] = {
|
||||
"image": __deserialize_image(v["image"]),
|
||||
"mask": __deserialize_image(v["mask"])
|
||||
if v.get("mask", None) is not None
|
||||
else None,
|
||||
}
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def map_ui_task_args_list_to_named_args(
|
||||
args: list, is_img2img: bool, checkpoint: str = None
|
||||
):
|
||||
args_name = []
|
||||
if is_img2img:
|
||||
args_name = inspect.getfullargspec(img2img).args
|
||||
else:
|
||||
args_name = inspect.getfullargspec(txt2img).args
|
||||
|
||||
named_args = dict(zip(args_name, args[0 : len(args_name)]))
|
||||
script_args = args[len(args_name) :]
|
||||
if checkpoint is not None:
|
||||
override_settings_texts = named_args.get("override_settings_texts", [])
|
||||
override_settings_texts.append("Model hash: " + checkpoint)
|
||||
named_args["override_settings_texts"] = override_settings_texts
|
||||
|
||||
return (
|
||||
named_args,
|
||||
script_args,
|
||||
)
|
||||
|
||||
|
||||
def map_ui_task_args_to_api_task_args(
|
||||
named_args: dict, script_args: list, is_img2img: bool
|
||||
):
|
||||
api_task_args: dict = named_args.copy()
|
||||
|
||||
prompt_styles = api_task_args.pop("prompt_styles", [])
|
||||
api_task_args["styles"] = prompt_styles
|
||||
|
||||
sampler_index = api_task_args.pop("sampler_index", 0)
|
||||
api_task_args["sampler_name"] = sd_samplers.samplers[sampler_index].name
|
||||
|
||||
override_settings_texts = api_task_args.pop("override_settings_texts", [])
|
||||
api_task_args["override_settings"] = create_override_settings_dict(
|
||||
override_settings_texts
|
||||
)
|
||||
|
||||
if is_img2img:
|
||||
mode = api_task_args.pop("mode", 0)
|
||||
for arg_mode, image_args in img2img_image_args_by_mode.items():
|
||||
if mode != arg_mode:
|
||||
for keys in image_args:
|
||||
api_task_args.pop(keys[0], None)
|
||||
|
||||
# the logic below is copied from modules/img2img.py
|
||||
if mode == 0:
|
||||
image = api_task_args.pop("init_img").convert("RGB")
|
||||
mask = None
|
||||
elif mode == 1:
|
||||
image = api_task_args.pop("sketch").convert("RGB")
|
||||
mask = None
|
||||
elif mode == 2:
|
||||
init_img_with_mask: dict = api_task_args.pop("init_img_with_mask")
|
||||
image = init_img_with_mask.get("image").convert("RGB")
|
||||
mask = init_img_with_mask.get("mask")
|
||||
alpha_mask = (
|
||||
ImageOps.invert(image.split()[-1])
|
||||
.convert("L")
|
||||
.point(lambda x: 255 if x > 0 else 0, mode="1")
|
||||
)
|
||||
mask = ImageChops.lighter(alpha_mask, mask.convert("L")).convert("L")
|
||||
elif mode == 3:
|
||||
image = api_task_args.pop("inpaint_color_sketch")
|
||||
orig = api_task_args.pop("inpaint_color_sketch_orig") or image
|
||||
mask_alpha = api_task_args.pop("mask_alpha", 0)
|
||||
mask_blur = api_task_args.get("mask_blur", 4)
|
||||
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
||||
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||
image = image.convert("RGB")
|
||||
elif mode == 4:
|
||||
image = api_task_args.pop("init_img_inpaint")
|
||||
mask = api_task_args.pop("init_mask_inpaint")
|
||||
else:
|
||||
raise Exception(f"Batch mode is not supported yet")
|
||||
|
||||
image = ImageOps.exif_transpose(image)
|
||||
api_task_args["init_images"] = [encode_pil_to_base64(image)]
|
||||
api_task_args["mask"] = encode_pil_to_base64(mask) if mask is not None else None
|
||||
|
||||
selected_scale_tab = api_task_args.pop("selected_scale_tab", 0)
|
||||
scale_by = api_task_args.pop("scale_by", 1)
|
||||
if selected_scale_tab == 1:
|
||||
api_task_args["width"] = int(image.width * scale_by)
|
||||
api_task_args["height"] = int(image.height * scale_by)
|
||||
else:
|
||||
hr_sampler_index = api_task_args.pop("hr_sampler_index", 0)
|
||||
api_task_args["hr_sampler_name"] = (
|
||||
sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name
|
||||
if hr_sampler_index != 0
|
||||
else None
|
||||
)
|
||||
|
||||
# script
|
||||
script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
|
||||
script_id = script_args[0]
|
||||
if script_id == 0:
|
||||
api_task_args["script_name"] = None
|
||||
api_task_args["script_args"] = []
|
||||
else:
|
||||
script = script_runner.selectable_scripts[script_id - 1]
|
||||
api_task_args["script_name"] = script.title.lower()
|
||||
api_task_args["script_args"] = script_args[script.args_from : script.args_to]
|
||||
|
||||
# alwayson scripts
|
||||
alwayson_scripts = api_task_args.get("alwayson_scripts", None)
|
||||
if not alwayson_scripts or not isinstance(alwayson_scripts, dict):
|
||||
alwayson_scripts = {}
|
||||
api_task_args["alwayson_scripts"] = alwayson_scripts
|
||||
|
||||
for script in script_runner.alwayson_scripts:
|
||||
alwayson_script_args = script_args[script.args_from : script.args_to]
|
||||
if script.title.lower() == "controlnet":
|
||||
for i, cnet_args in enumerate(alwayson_script_args):
|
||||
alwayson_script_args[i] = serialize_controlnet_args(cnet_args)
|
||||
|
||||
alwayson_scripts[script.title.lower()] = {"args": alwayson_script_args}
|
||||
|
||||
return api_task_args
|
||||
|
||||
|
||||
def serialize_api_task_args(params: dict, is_img2img: bool):
|
||||
# pop out custom params
|
||||
model_hash = params.pop("model_hash", None)
|
||||
controlnet_args = params.pop("controlnet_args", None)
|
||||
|
||||
args = (
|
||||
StableDiffusionImg2ImgProcessingAPI(**params)
|
||||
if is_img2img
|
||||
else StableDiffusionTxt2ImgProcessingAPI(**params)
|
||||
)
|
||||
|
||||
if args.override_settings is None:
|
||||
args.override_settings = {}
|
||||
|
||||
if model_hash is None:
|
||||
model_hash = args.override_settings.get("sd_model_checkpoint", None)
|
||||
|
||||
if model_hash is None:
|
||||
log.error("[AgentScheduler] API task must supply model hash")
|
||||
return
|
||||
|
||||
checkpoint: CheckpointInfo = get_closet_checkpoint_match(model_hash)
|
||||
if not checkpoint:
|
||||
log.warn(f"[AgentScheduler] No checkpoint found for model hash {model_hash}")
|
||||
return
|
||||
args.override_settings["sd_model_checkpoint"] = checkpoint.title
|
||||
|
||||
# load images from url or file if needed
|
||||
if is_img2img:
|
||||
init_images = args.init_images
|
||||
for i, image in enumerate(init_images):
|
||||
init_images[i] = load_image_to_base64(image)
|
||||
|
||||
args.mask = load_image_to_base64(args.mask)
|
||||
|
||||
# handle custom controlnet args
|
||||
if controlnet_args is not None:
|
||||
if args.alwayson_scripts is None:
|
||||
args.alwayson_scripts = {}
|
||||
|
||||
controlnets = []
|
||||
for cnet in controlnet_args:
|
||||
enabled = cnet.get("enabled", True)
|
||||
cnet_image = cnet.get("image", None)
|
||||
|
||||
if not enabled:
|
||||
continue
|
||||
if not isinstance(cnet_image, dict):
|
||||
log.error(f"[AgentScheduler] Controlnet image is required")
|
||||
continue
|
||||
|
||||
image = cnet_image.get("image", None)
|
||||
mask = cnet_image.get("mask", None)
|
||||
if image is None:
|
||||
log.error(f"[AgentScheduler] Controlnet image is required")
|
||||
continue
|
||||
|
||||
# load controlnet images from url or file if needed
|
||||
cnet_image["image"] = load_image_to_base64(image)
|
||||
cnet_image["mask"] = load_image_to_base64(mask)
|
||||
controlnets.append(cnet)
|
||||
|
||||
if len(controlnets) > 0:
|
||||
args.alwayson_scripts["controlnet"] = {"args": controlnets}
|
||||
|
||||
return args.dict()
|
||||
|
|
@ -1,18 +1,13 @@
|
|||
import io
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import zlib
|
||||
import base64
|
||||
import pickle
|
||||
import inspect
|
||||
import traceback
|
||||
import threading
|
||||
import numpy as np
|
||||
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from PIL import Image, PngImagePlugin
|
||||
from typing import Any, Callable, Union
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
|
@ -20,22 +15,29 @@ from modules import progress, shared, script_callbacks
|
|||
from modules.call_queue import queue_lock, wrap_gradio_call
|
||||
from modules.txt2img import txt2img
|
||||
from modules.img2img import img2img
|
||||
from modules.api.api import Api, encode_pil_to_base64
|
||||
from modules.api.api import Api
|
||||
from modules.api.models import (
|
||||
StableDiffusionTxt2ImgProcessingAPI,
|
||||
StableDiffusionImg2ImgProcessingAPI,
|
||||
)
|
||||
|
||||
from scripts.db import TaskStatus, AppStateKey, Task, task_manager, state_manager
|
||||
from scripts.db import TaskStatus, Task, task_manager
|
||||
from scripts.helpers import log, detect_control_net, get_component_by_elem_id
|
||||
from scripts.task_helpers import (
|
||||
serialize_img2img_image_args,
|
||||
deserialize_img2img_image_args,
|
||||
serialize_controlnet_args,
|
||||
deserialize_controlnet_args,
|
||||
map_ui_task_args_list_to_named_args,
|
||||
)
|
||||
|
||||
img2img_image_args_by_mode: dict[int, list[list[str]]] = {
|
||||
0: [["init_img"]],
|
||||
1: [["sketch"]],
|
||||
2: [["init_img_with_mask", "image"], ["init_img_with_mask", "mask"]],
|
||||
3: [["inpaint_color_sketch"], ["inpaint_color_sketch_orig"]],
|
||||
4: [["init_img_inpaint"], ["init_mask_inpaint"]],
|
||||
}
|
||||
|
||||
class ParsedTaskArgs(BaseModel):
|
||||
args: list[Any]
|
||||
named_args: dict[str, Any]
|
||||
script_args: list[Any]
|
||||
checkpoint: str
|
||||
is_ui: bool
|
||||
|
||||
|
||||
class TaskRunner:
|
||||
|
|
@ -75,103 +77,22 @@ class TaskRunner:
|
|||
|
||||
@property
|
||||
def paused(self) -> bool:
|
||||
return state_manager.get_value(AppStateKey.QueueState) == "paused"
|
||||
|
||||
def __serialize_image(self, image):
|
||||
if isinstance(image, np.ndarray):
|
||||
shape = image.shape
|
||||
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
|
||||
return {"shape": shape, "data": data, "cls": "ndarray"}
|
||||
elif isinstance(image, Image.Image):
|
||||
size = image.size
|
||||
mode = image.mode
|
||||
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
|
||||
return {
|
||||
"size": size,
|
||||
"mode": mode,
|
||||
"data": data,
|
||||
"cls": "Image",
|
||||
}
|
||||
else:
|
||||
return image
|
||||
|
||||
def __deserialize_image(self, image_str):
|
||||
if isinstance(image_str, dict) and image_str.get("cls", None):
|
||||
cls = image_str["cls"]
|
||||
data = zlib.decompress(base64.b64decode(image_str["data"]))
|
||||
|
||||
if cls == "ndarray":
|
||||
shape = tuple(image_str["shape"])
|
||||
image = np.frombuffer(data, dtype=np.uint8)
|
||||
return image.reshape(shape)
|
||||
else:
|
||||
size = tuple(image_str["size"])
|
||||
mode = image_str["mode"]
|
||||
return Image.frombytes(mode, size, data)
|
||||
else:
|
||||
return image_str
|
||||
|
||||
def __serialize_img2img_images(self, args: dict, image_args: list):
|
||||
for keys in image_args:
|
||||
if len(keys) == 1:
|
||||
image = args.get(keys[0], None)
|
||||
args[keys[0]] = self.__serialize_image(image)
|
||||
else:
|
||||
value = args.get(keys[0], {})
|
||||
image = value.get(keys[1], None)
|
||||
value[keys[1]] = self.__serialize_image(image)
|
||||
args[keys[0]] = value
|
||||
|
||||
def __deserialize_img2img_images(self, args: dict, image_args: list):
|
||||
for keys in image_args:
|
||||
if len(keys) == 1:
|
||||
image = args.get(keys[0], None)
|
||||
args[keys[0]] = self.__deserialize_image(image)
|
||||
else:
|
||||
value = args.get(keys[0], {})
|
||||
image = value.get(keys[1], None)
|
||||
value[keys[1]] = self.__deserialize_image(image)
|
||||
args[keys[0]] = value
|
||||
return shared.opts.queue_paused
|
||||
|
||||
def __serialize_ui_task_args(self, is_img2img: bool, *args, checkpoint: str = None):
|
||||
args_name = []
|
||||
if is_img2img:
|
||||
args_name = inspect.getfullargspec(img2img).args
|
||||
else:
|
||||
args_name = inspect.getfullargspec(txt2img).args
|
||||
|
||||
args = list(args)
|
||||
named_args = dict(zip(args_name, args[0 : len(args_name)]))
|
||||
script_args = args[len(args_name) :]
|
||||
if checkpoint:
|
||||
override_settings_texts = named_args.get("override_settings_texts", [])
|
||||
override_settings_texts.append("Model hash: " + checkpoint)
|
||||
named_args["override_settings_texts"] = override_settings_texts
|
||||
named_args, script_args = map_ui_task_args_list_to_named_args(
|
||||
list(args), is_img2img, checkpoint=checkpoint
|
||||
)
|
||||
|
||||
# loop through named_args and serialize images
|
||||
if is_img2img:
|
||||
for mode, image_args in img2img_image_args_by_mode.items():
|
||||
if mode == named_args["mode"]:
|
||||
self.__serialize_img2img_images(named_args, image_args)
|
||||
else:
|
||||
# set None to unused image args to save space
|
||||
for keys in image_args:
|
||||
named_args[keys[0]] = None
|
||||
serialize_img2img_image_args(named_args)
|
||||
|
||||
# loop through script_args and serialize controlnets
|
||||
if self.UiControlNetUnit is not None:
|
||||
for i, a in enumerate(script_args):
|
||||
if isinstance(a, self.UiControlNetUnit):
|
||||
script_args[i] = a.__dict__
|
||||
script_args[i]["is_cnet"] = True
|
||||
for k, v in script_args[i].items():
|
||||
if k == "image" and v is not None:
|
||||
script_args[i][k] = {
|
||||
"image": self.__serialize_image(v["image"]),
|
||||
"mask": self.__serialize_image(v["mask"]),
|
||||
}
|
||||
if isinstance(v, Enum):
|
||||
script_args[i][k] = v.value
|
||||
script_args[i] = serialize_controlnet_args(a)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
|
|
@ -186,6 +107,7 @@ class TaskRunner:
|
|||
def __serialize_api_task_args(
|
||||
self, is_img2img: bool, script_args: list = [], **named_args
|
||||
):
|
||||
# serialization steps are done in task_helpers.register_api_task
|
||||
override_settings = named_args.get("override_settings", {})
|
||||
checkpoint = override_settings.get("sd_model_checkpoint", None)
|
||||
|
||||
|
|
@ -204,21 +126,17 @@ class TaskRunner:
|
|||
):
|
||||
# loop through image_args and deserialize images
|
||||
if is_img2img:
|
||||
for mode, image_args in img2img_image_args_by_mode.items():
|
||||
if mode == named_args["mode"]:
|
||||
self.__deserialize_img2img_images(named_args, image_args)
|
||||
deserialize_img2img_image_args(named_args)
|
||||
|
||||
# loop through script_args and deserialize controlnets
|
||||
if self.UiControlNetUnit is not None:
|
||||
for i, arg in enumerate(script_args):
|
||||
if isinstance(arg, dict) and arg.get("is_cnet", False):
|
||||
arg.pop("is_cnet")
|
||||
for k, v in arg.items():
|
||||
if k == "image" and v is not None:
|
||||
arg[k] = {
|
||||
"image": self.__deserialize_image(v["image"]),
|
||||
"mask": self.__deserialize_image(v["mask"]),
|
||||
}
|
||||
script_args[i] = deserialize_controlnet_args(arg)
|
||||
|
||||
def __deserialize_api_task_args(self, is_img2img: bool, named_args: dict):
|
||||
# API task use base64 images as input, no need to deserialize
|
||||
pass
|
||||
|
||||
def parse_task_args(
|
||||
self, params: str, script_params: bytes, deserialization: bool = True
|
||||
|
|
@ -237,16 +155,18 @@ class TaskRunner:
|
|||
|
||||
if is_ui and deserialization:
|
||||
self.__deserialize_ui_task_args(is_img2img, named_args, script_args)
|
||||
elif deserialization:
|
||||
self.__deserialize_api_task_args(is_img2img, named_args)
|
||||
|
||||
args = list(named_args.values()) + script_args
|
||||
|
||||
return {
|
||||
"args": args,
|
||||
"named_args": named_args,
|
||||
"script_args": script_args,
|
||||
"checkpoint": checkpoint,
|
||||
"is_ui": is_ui,
|
||||
}
|
||||
return ParsedTaskArgs(
|
||||
args=args,
|
||||
named_args=named_args,
|
||||
script_args=script_args,
|
||||
checkpoint=checkpoint,
|
||||
is_ui=is_ui,
|
||||
)
|
||||
|
||||
def register_ui_task(
|
||||
self, task_id: str, is_img2img: bool, *args, checkpoint: str = None
|
||||
|
|
@ -263,13 +183,19 @@ class TaskRunner:
|
|||
)
|
||||
self.__total_pending_tasks += 1
|
||||
|
||||
def register_api_task(self, task_id: str, is_img2img: bool, args: dict):
|
||||
def register_api_task(
|
||||
self, task_id: str, api_task_id: str, is_img2img: bool, args: dict
|
||||
):
|
||||
progress.add_task_to_queue(task_id)
|
||||
|
||||
args = args.copy()
|
||||
args.update({"save_images": True, "send_images": False})
|
||||
params = self.__serialize_api_task_args(is_img2img, **args)
|
||||
|
||||
task_type = "img2img" if is_img2img else "txt2img"
|
||||
task_manager.add_task(Task(id=task_id, type=task_type, params=params))
|
||||
task_manager.add_task(
|
||||
Task(id=task_id, api_task_id=api_task_id, type=task_type, params=params)
|
||||
)
|
||||
|
||||
self.__run_callbacks(
|
||||
"task_registered", task_id, is_img2img=is_img2img, is_ui=False, args=params
|
||||
|
|
@ -279,7 +205,11 @@ class TaskRunner:
|
|||
def execute_task(self, task: Task, get_next_task: Callable):
|
||||
while True:
|
||||
if self.dispose:
|
||||
sys.exit(0)
|
||||
break
|
||||
|
||||
if self.paused:
|
||||
log.info("[AgentScheduler] Runner is paused")
|
||||
break
|
||||
|
||||
if progress.current_task is None:
|
||||
task_id = task.id
|
||||
|
|
@ -290,7 +220,11 @@ class TaskRunner:
|
|||
task.params,
|
||||
task.script_params,
|
||||
)
|
||||
task_meta = {"is_img2img": is_img2img, "is_ui": task_args["is_ui"]}
|
||||
task_meta = {
|
||||
"is_img2img": is_img2img,
|
||||
"is_ui": task_args.is_ui,
|
||||
"api_task_id": task.api_task_id,
|
||||
}
|
||||
|
||||
self.__saved_images_path = []
|
||||
self.__run_callbacks("task_started", task_id, **task_meta)
|
||||
|
|
@ -333,7 +267,7 @@ class TaskRunner:
|
|||
|
||||
task = get_next_task()
|
||||
if not task:
|
||||
sys.exit(0)
|
||||
break
|
||||
|
||||
def execute_pending_tasks_threading(self):
|
||||
if self.paused:
|
||||
|
|
@ -366,19 +300,19 @@ class TaskRunner:
|
|||
return [
|
||||
task.id,
|
||||
task.type,
|
||||
json.dumps(task_args["named_args"]),
|
||||
json.dumps(task_args.named_args),
|
||||
task.created_at.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
]
|
||||
|
||||
def __execute_task(self, task_id: str, is_img2img: bool, task_args: dict):
|
||||
if task_args["is_ui"]:
|
||||
return self.__execute_ui_task(task_id, is_img2img, *task_args["args"])
|
||||
def __execute_task(self, task_id: str, is_img2img: bool, task_args: ParsedTaskArgs):
|
||||
if task_args.is_ui:
|
||||
return self.__execute_ui_task(task_id, is_img2img, *task_args.args)
|
||||
else:
|
||||
return self.__execute_api_task(
|
||||
task_id,
|
||||
is_img2img,
|
||||
script_args=task_args["script_args"],
|
||||
**task_args["named_args"],
|
||||
script_args=task_args.script_args,
|
||||
**task_args.named_args,
|
||||
)
|
||||
|
||||
def __execute_ui_task(self, task_id: str, is_img2img: bool, *args):
|
||||
|
|
@ -483,9 +417,12 @@ class TaskRunner:
|
|||
|
||||
def get_instance(block) -> TaskRunner:
|
||||
if TaskRunner.instance is None:
|
||||
txt2img_submit_button = get_component_by_elem_id(block, "txt2img_generate")
|
||||
UiControlNetUnit = detect_control_net(block, txt2img_submit_button)
|
||||
TaskRunner(UiControlNetUnit)
|
||||
if block is not None:
|
||||
txt2img_submit_button = get_component_by_elem_id(block, "txt2img_generate")
|
||||
UiControlNetUnit = detect_control_net(block, txt2img_submit_button)
|
||||
TaskRunner(UiControlNetUnit)
|
||||
else:
|
||||
TaskRunner()
|
||||
|
||||
def on_before_reload():
|
||||
# Tell old instance to stop
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from modules.ui import create_refresh_button
|
|||
|
||||
from scripts.task_runner import TaskRunner, get_instance
|
||||
from scripts.helpers import compare_components_with_ids, get_components_by_ids
|
||||
from scripts.db import init, state_manager, AppStateKey
|
||||
from scripts.db import init
|
||||
|
||||
task_runner: TaskRunner = None
|
||||
initialized = False
|
||||
|
|
@ -14,6 +14,10 @@ initialized = False
|
|||
checkpoint_current = "Current Checkpoint"
|
||||
checkpoint_runtime = "Runtime Checkpoint"
|
||||
|
||||
placement_above_generate = "Above Generate button"
|
||||
placement_under_generate = "Under Generate button"
|
||||
placement_between_prompt_and_generate = "Between Prompt and Generate button"
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
def __init__(self):
|
||||
|
|
@ -77,10 +81,11 @@ class Script(scripts.Script):
|
|||
outputs=fn_block.outputs,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
id_part = "img2img" if is_img2img else "txt2img"
|
||||
with root:
|
||||
with generate.parent:
|
||||
id_part = "img2img" if is_img2img else "txt2img"
|
||||
with gr.Row(elem_id=f"{id_part}_enqueue_wrapper"):
|
||||
with gr.Row(elem_id=f"{id_part}_enqueue_wrapper") as row:
|
||||
if not shared.opts.queue_button_hide_checkpoint:
|
||||
checkpoint = gr.Dropdown(
|
||||
choices=get_checkpoint_choices(),
|
||||
value=checkpoint_current,
|
||||
|
|
@ -93,13 +98,36 @@ class Script(scripts.Script):
|
|||
lambda: {"choices": get_checkpoint_choices()},
|
||||
f"refresh_{id_part}_checkpoint",
|
||||
)
|
||||
submit = gr.Button(
|
||||
"Enqueue", elem_id=f"{id_part}_enqueue", variant="primary"
|
||||
checkpoint.change(
|
||||
fn=self.on_checkpoint_changed, inputs=[checkpoint]
|
||||
)
|
||||
|
||||
checkpoint.change(fn=self.on_checkpoint_changed, inputs=[checkpoint])
|
||||
submit = gr.Button(
|
||||
"Enqueue", elem_id=f"{id_part}_enqueue", variant="primary"
|
||||
)
|
||||
submit.click(**args)
|
||||
|
||||
# relocation the enqueue button
|
||||
root.children.pop()
|
||||
if shared.opts.queue_button_placement == placement_between_prompt_and_generate:
|
||||
if is_img2img:
|
||||
# add to the iterrogate div
|
||||
parent = generate.parent.parent.parent.children[1]
|
||||
parent.add(row)
|
||||
else:
|
||||
# insert after the prompts
|
||||
parent = generate.parent.parent.parent
|
||||
row.parent = parent
|
||||
parent.children.insert(1, row)
|
||||
elif shared.opts.queue_button_placement == placement_under_generate:
|
||||
# insert after the tools div
|
||||
parent = generate.parent.parent
|
||||
parent.children.insert(2, row)
|
||||
else:
|
||||
# insert after before the generate button
|
||||
parent = generate.parent.parent
|
||||
parent.children.insert(0, row)
|
||||
|
||||
if cnet_dependency is not None:
|
||||
cnet_fn_block = next(
|
||||
fn
|
||||
|
|
@ -147,10 +175,6 @@ def get_checkpoint_choices():
|
|||
return choices
|
||||
|
||||
|
||||
def is_queue_paused():
|
||||
return state_manager.get_value(AppStateKey.QueueState) == "paused"
|
||||
|
||||
|
||||
def on_ui_tab(**_kwargs):
|
||||
global initialized
|
||||
if not initialized:
|
||||
|
|
@ -158,18 +182,25 @@ def on_ui_tab(**_kwargs):
|
|||
init()
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as scheduler_tab:
|
||||
gr.Textbox(
|
||||
shared.opts.queue_button_placement,
|
||||
elem_id="agent_scheduler_queue_button_placement",
|
||||
show_label=False,
|
||||
visible=False,
|
||||
interactive=False,
|
||||
)
|
||||
with gr.Row(elem_id="agent_scheduler_pending_tasks_wrapper"):
|
||||
with gr.Column(scale=1):
|
||||
with gr.Group(elem_id="agent_scheduler_actions"):
|
||||
paused = is_queue_paused()
|
||||
paused = shared.opts.queue_paused
|
||||
|
||||
pause = gr.Button(
|
||||
gr.Button(
|
||||
"Pause",
|
||||
elem_id="agent_scheduler_action_pause",
|
||||
variant="stop",
|
||||
visible=not paused,
|
||||
)
|
||||
resume = gr.Button(
|
||||
gr.Button(
|
||||
"Resume",
|
||||
elem_id="agent_scheduler_action_resume",
|
||||
variant="primary",
|
||||
|
|
@ -195,11 +226,53 @@ def on_ui_tab(**_kwargs):
|
|||
return [(scheduler_tab, "Agent Scheduler", "agent_scheduler")]
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
section = ("agent_scheduler", "Agent Scheduler")
|
||||
shared.opts.add_option(
|
||||
"queue_paused",
|
||||
shared.OptionInfo(
|
||||
False,
|
||||
"Disable queue auto processing",
|
||||
gr.Checkbox,
|
||||
{"interactive": True},
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"queue_button_placement",
|
||||
shared.OptionInfo(
|
||||
placement_above_generate,
|
||||
"Queue button placement",
|
||||
gr.Radio,
|
||||
lambda: {
|
||||
"choices": [
|
||||
placement_above_generate,
|
||||
placement_under_generate,
|
||||
placement_between_prompt_and_generate,
|
||||
]
|
||||
},
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"queue_button_hide_checkpoint",
|
||||
shared.OptionInfo(
|
||||
True,
|
||||
"Hide the checkpoint dropdown",
|
||||
gr.Checkbox,
|
||||
{},
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def on_app_started(block, _):
|
||||
global task_runner
|
||||
task_runner = get_instance(block)
|
||||
task_runner.execute_pending_tasks_threading()
|
||||
if block is not None:
|
||||
global task_runner
|
||||
task_runner = get_instance(block)
|
||||
task_runner.execute_pending_tasks_threading()
|
||||
|
||||
|
||||
script_callbacks.on_ui_tabs(on_ui_tab)
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
script_callbacks.on_app_started(on_app_started)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
module.exports = {
|
||||
env: { browser: true, es2020: true },
|
||||
extends: [
|
||||
'eslint:recommended',
|
||||
'plugin:@typescript-eslint/recommended',
|
||||
'plugin:react-hooks/recommended',
|
||||
],
|
||||
parser: '@typescript-eslint/parser',
|
||||
parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
|
||||
plugins: ['react-refresh'],
|
||||
rules: {
|
||||
'react-refresh/only-export-components': 'warn',
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
{
|
||||
"name": "ui",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "tsc && yarn build:extension",
|
||||
"build:extension": "vite build --config vite.extension.ts",
|
||||
"lint": "eslint src --ext ts,tsx --report-unused-disable-directives --max-warnings 0",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"ag-grid-community": "^29.3.5",
|
||||
"notyf": "^3.10.0",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"rxjs": "^7.8.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^18",
|
||||
"@types/react": "^18.0.37",
|
||||
"@types/react-dom": "^18.0.11",
|
||||
"@typescript-eslint/eslint-plugin": "^5.59.0",
|
||||
"@typescript-eslint/parser": "^5.59.0",
|
||||
"@vitejs/plugin-react": "^4.0.0",
|
||||
"autoprefixer": "^10.4.14",
|
||||
"eslint": "^8.38.0",
|
||||
"eslint-plugin-react-hooks": "^4.6.0",
|
||||
"eslint-plugin-react-refresh": "^0.3.4",
|
||||
"postcss": "^8.4.24",
|
||||
"tailwindcss": "^3.3.2",
|
||||
"typescript": "^5.0.2",
|
||||
"vite": "^4.3.9"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
module.exports = {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
@tailwind components;
|
||||
@tailwind utilities;
|
||||
|
||||
@layer components {
|
||||
.ts-search {
|
||||
@apply relative w-full max-w-xs ml-auto;
|
||||
}
|
||||
.ts-search-input {
|
||||
@apply bg-gray-50 border border-gray-300 text-gray-900 text-sm !rounded-md focus:ring-blue-500 focus:border-blue-500 block !pl-10 !p-2 w-full dark:!bg-gray-700 dark:!border-gray-600 dark:!placeholder-gray-400 dark:!text-white dark:focus:ring-blue-500 dark:focus:border-blue-500;
|
||||
}
|
||||
|
||||
.ts-search-icon {
|
||||
@apply absolute inset-y-0 left-0 flex items-center dark:text-white pl-3 pointer-events-none;
|
||||
}
|
||||
|
||||
.ts-btn-action {
|
||||
@apply inline-flex items-center !px-2 !py-1 !m-0 text-sm font-medium border focus:z-10 focus:ring-2 disabled:opacity-50 disabled:hover:!bg-transparent disabled:cursor-not-allowed;
|
||||
}
|
||||
|
||||
.ts-btn-run {
|
||||
@apply !text-green-500 hover:!text-white border-green-500 hover:bg-green-600 rounded-l-md focus:ring-green-400 dark:border-green-500 dark:hover:bg-green-600 dark:focus:ring-green-900 disabled:hover:!text-green-500;
|
||||
}
|
||||
|
||||
.ts-btn-delete {
|
||||
@apply !text-red-500 hover:!text-white border-red-600 hover:bg-red-600 rounded-r-md focus:ring-red-300 dark:border-red-500 dark:hover:bg-red-600 dark:focus:ring-red-900;
|
||||
}
|
||||
|
||||
@keyframes blink {
|
||||
from,
|
||||
to {
|
||||
opacity: 0;
|
||||
}
|
||||
50% {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
.task-running {
|
||||
@apply !text-green-500;
|
||||
animation: 1s blink ease infinite;
|
||||
}
|
||||
}
|
||||
|
||||
/* ========================================================================= */
|
||||
|
||||
#agent_scheduler_pending_tasks_wrapper {
|
||||
gap: var(--layout-gap);
|
||||
border: none;
|
||||
box-shadow: none;
|
||||
border-width: 0;
|
||||
}
|
||||
|
||||
@media (max-width: 1024px) {
|
||||
#agent_scheduler_pending_tasks_wrapper {
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
}
|
||||
|
||||
#agent_scheduler_pending_tasks_wrapper > div:last-child {
|
||||
width: 100%;
|
||||
max-width: 400px;
|
||||
}
|
||||
|
||||
@media (min-width: 1920px) {
|
||||
#agent_scheduler_pending_tasks_wrapper > div:last-child {
|
||||
max-width: 512px;
|
||||
}
|
||||
}
|
||||
|
||||
#agent_scheduler_current_task_images {
|
||||
width: 100%;
|
||||
padding-top: 100%;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
#agent_scheduler_current_task_images > div {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
#agent_scheduler_pending_tasks_wrapper {
|
||||
justify-content: flex-end;
|
||||
gap: var(--layout-gap);
|
||||
padding: 0 var(--layout-gap) var(--layout-gap) var(--layout-gap);
|
||||
}
|
||||
|
||||
#agent_scheduler_pending_tasks_wrapper > button {
|
||||
flex: 0 0 auto;
|
||||
}
|
||||
|
||||
#agent_scheduler_actions {
|
||||
display: flex;
|
||||
gap: var(--layout-gap);
|
||||
}
|
||||
|
||||
#agent_scheduler_actions > button {
|
||||
border-radius: var(--radius-lg) !important;
|
||||
}
|
||||
|
||||
.ag-theme-alpine,
|
||||
.ag-theme-alpine-dark {
|
||||
--ag-row-height: 45px;
|
||||
--ag-header-height: 45px;
|
||||
/* --ag-grid-size: 6px; */
|
||||
--ag-cell-horizontal-padding: calc(var(--ag-grid-size) * 2);
|
||||
|
||||
--body-text-color: 'inherit';
|
||||
}
|
||||
|
||||
.cell-span {
|
||||
border-bottom-color: var(--ag-border-color);
|
||||
}
|
||||
.cell-not-span {
|
||||
opacity: 0;
|
||||
}
|
||||
.ag-row-hover .ag-cell {
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
#txt2img_enqueue_wrapper,
|
||||
#img2img_enqueue_wrapper {
|
||||
min-width: 210px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: calc(var(--layout-gap) / 2);
|
||||
}
|
||||
|
||||
#txt2img_enqueue_wrapper > div:first-child,
|
||||
#img2img_enqueue_wrapper > div:first-child {
|
||||
flex-direction: row;
|
||||
flex-wrap: nowrap;
|
||||
align-items: flex-start;
|
||||
flex: 0 0 auto;
|
||||
}
|
||||
|
||||
#txt2img_enqueue_wrapper .gradio-button,
|
||||
#img2img_enqueue_wrapper .gradio-button,
|
||||
#txt2img_enqueue_wrapper .gradio-dropdown .wrap-inner,
|
||||
#img2img_enqueue_wrapper .gradio-dropdown .wrap-inner {
|
||||
min-height: 36px;
|
||||
max-height: 42px;
|
||||
}
|
||||
|
||||
#img2img_toprow .interrogate-col.has-queue-button {
|
||||
min-width: unset !important;
|
||||
flex-direction: row !important;
|
||||
gap: calc(var(--layout-gap) / 2) !important;
|
||||
}
|
||||
|
||||
#img2img_toprow .interrogate-col.has-queue-button button {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
#agent_scheduler_current_task_progress .livePreview {
|
||||
margin: 0;
|
||||
}
|
||||
|
|
@ -0,0 +1,470 @@
|
|||
import * as rxjs from 'rxjs';
|
||||
import type { Observer } from 'rxjs';
|
||||
import { Grid, GridOptions } from 'ag-grid-community';
|
||||
import { Notyf } from 'notyf';
|
||||
|
||||
import 'ag-grid-community/styles/ag-grid.css';
|
||||
import 'ag-grid-community/styles/ag-theme-alpine.css';
|
||||
import 'notyf/notyf.min.css';
|
||||
import './index.css';
|
||||
|
||||
declare global {
|
||||
var country: string;
|
||||
function gradioApp(): HTMLElement;
|
||||
function randomId(): string;
|
||||
function get_tab_index(name: string): number;
|
||||
function create_submit_args(args: IArguments): any[];
|
||||
function requestProgress(
|
||||
id: string,
|
||||
progressContainer: HTMLElement,
|
||||
imagesContainer: HTMLElement,
|
||||
onDone: () => void,
|
||||
): void;
|
||||
function onUiLoaded(callback: () => void): void;
|
||||
function submit_enqueue(): any[];
|
||||
function submit_enqueue_img2img(): any[];
|
||||
}
|
||||
|
||||
type Task = {
|
||||
id: string;
|
||||
api_task_id: string;
|
||||
type: string;
|
||||
status: string;
|
||||
params: Record<string, any>;
|
||||
priority: number;
|
||||
result: string;
|
||||
};
|
||||
|
||||
type AppState = {
|
||||
current_task_id: string | null;
|
||||
total_pending_tasks: number;
|
||||
pending_tasks: Task[];
|
||||
paused: boolean;
|
||||
};
|
||||
|
||||
function initTaskScheduler() {
|
||||
const notyf = new Notyf();
|
||||
const subject = new rxjs.Subject<AppState>();
|
||||
|
||||
const store = {
|
||||
subject,
|
||||
subscribe: (callback: Partial<Observer<[AppState, AppState]>>) => {
|
||||
return store.subject.pipe(rxjs.pairwise()).subscribe(callback);
|
||||
},
|
||||
refresh: async () => {
|
||||
return fetch('/agent-scheduler/v1/queue?limit=1000')
|
||||
.then((response) => response.json())
|
||||
.then((data: AppState) => {
|
||||
const pending_tasks = data.pending_tasks.map((item) => ({
|
||||
...item,
|
||||
params: JSON.parse(item.params as any),
|
||||
status: item.id === data.current_task_id ? 'running' : 'pending',
|
||||
}));
|
||||
store.subject.next({
|
||||
...data,
|
||||
pending_tasks,
|
||||
});
|
||||
});
|
||||
},
|
||||
pauseQueue: async () => {
|
||||
return fetch('/agent-scheduler/v1/pause', { method: 'POST' })
|
||||
.then((response) => response.json())
|
||||
.then((data) => {
|
||||
if (data.success) {
|
||||
notyf.success(data.message);
|
||||
} else {
|
||||
notyf.error(data.message);
|
||||
}
|
||||
|
||||
return store.refresh();
|
||||
});
|
||||
},
|
||||
resumeQueue: async () => {
|
||||
return fetch('/agent-scheduler/v1/resume', { method: 'POST' })
|
||||
.then((response) => response.json())
|
||||
.then((data) => {
|
||||
if (data.success) {
|
||||
notyf.success(data.message);
|
||||
} else {
|
||||
notyf.error(data.message);
|
||||
}
|
||||
|
||||
return store.refresh();
|
||||
});
|
||||
},
|
||||
runTask: async (id: string) => {
|
||||
return fetch(`/agent-scheduler/v1/run/${id}`, { method: 'POST' })
|
||||
.then((response) => response.json())
|
||||
.then((data) => {
|
||||
if (data.success) {
|
||||
notyf.success(data.message);
|
||||
} else {
|
||||
notyf.error(data.message);
|
||||
}
|
||||
|
||||
return store.refresh();
|
||||
});
|
||||
},
|
||||
deleteTask: async (id: string) => {
|
||||
return fetch(`/agent-scheduler/v1/delete/${id}`, { method: 'POST' })
|
||||
.then((response) => response.json())
|
||||
.then((data) => {
|
||||
if (data.success) {
|
||||
notyf.success(data.message);
|
||||
} else {
|
||||
notyf.error(data.message);
|
||||
}
|
||||
|
||||
return store.refresh();
|
||||
});
|
||||
},
|
||||
moveTask: async (id: string, overId: string) => {
|
||||
return fetch(`/agent-scheduler/v1/move/${id}/${overId}`, { method: 'POST' })
|
||||
.then((response) => response.json())
|
||||
.then((data) => {
|
||||
if (data.success) {
|
||||
notyf.success(data.message);
|
||||
} else {
|
||||
notyf.error(data.message);
|
||||
}
|
||||
|
||||
return store.refresh();
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
store.subject.next({
|
||||
current_task_id: null,
|
||||
total_pending_tasks: 0,
|
||||
pending_tasks: [],
|
||||
paused: false,
|
||||
});
|
||||
|
||||
window.submit_enqueue = function submit_enqueue() {
|
||||
var id = randomId();
|
||||
var res = create_submit_args(arguments);
|
||||
res[0] = id;
|
||||
|
||||
const btnEnqueue = document.querySelector('#txt2img_enqueue');
|
||||
if (btnEnqueue) {
|
||||
btnEnqueue.innerHTML = 'Queued';
|
||||
setTimeout(() => {
|
||||
btnEnqueue.innerHTML = 'Enqueue';
|
||||
store.refresh();
|
||||
}, 1000);
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
window.submit_enqueue_img2img = function submit_enqueue_img2img() {
|
||||
var id = randomId();
|
||||
var res = create_submit_args(arguments);
|
||||
res[0] = id;
|
||||
res[1] = get_tab_index('mode_img2img');
|
||||
|
||||
const btnEnqueue = document.querySelector('#img2img_enqueue');
|
||||
if (btnEnqueue) {
|
||||
btnEnqueue.innerHTML = 'Queued';
|
||||
setTimeout(() => {
|
||||
btnEnqueue.innerHTML = 'Enqueue';
|
||||
store.refresh();
|
||||
}, 1000);
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
// detect queue button placement
|
||||
const interrogateCol: HTMLDivElement = gradioApp().querySelector('.interrogate-col')!;
|
||||
if (interrogateCol.childElementCount > 2) {
|
||||
interrogateCol.classList.add('has-queue-button');
|
||||
}
|
||||
|
||||
// watch for tab activation
|
||||
const observer = new MutationObserver(function (mutationsList) {
|
||||
const styleChange = mutationsList.find((mutation) => mutation.attributeName === 'style');
|
||||
if (styleChange) {
|
||||
const tab = styleChange.target as HTMLElement;
|
||||
if (tab.style.display === 'block') {
|
||||
store.refresh();
|
||||
}
|
||||
}
|
||||
});
|
||||
observer.observe(document.getElementById('tab_agent_scheduler')!, { attributes: true });
|
||||
|
||||
// init actions
|
||||
const refreshButton = gradioApp().querySelector('#agent_scheduler_action_refresh')!;
|
||||
const pauseButton = gradioApp().querySelector('#agent_scheduler_action_pause')!;
|
||||
const resumeButton = gradioApp().querySelector('#agent_scheduler_action_resume')!;
|
||||
refreshButton.addEventListener('click', store.refresh);
|
||||
pauseButton.addEventListener('click', store.pauseQueue);
|
||||
resumeButton.addEventListener('click', store.resumeQueue);
|
||||
|
||||
const searchContainer = gradioApp().querySelector('#agent_scheduler_action_search')!;
|
||||
searchContainer.className = 'ts-search';
|
||||
searchContainer.innerHTML = `
|
||||
<div class="ts-search-icon">
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<path d="M10 10m-7 0a7 7 0 1 0 14 0a7 7 0 1 0 -14 0"/>
|
||||
<path d="M21 21l-6 -6"/>
|
||||
</svg>
|
||||
</div>
|
||||
<input type="text" id="agent_scheduler_search_input" class="ts-search-input" placeholder="Search" required>
|
||||
`;
|
||||
|
||||
// watch for current task id change
|
||||
const onTaskIdChange = (id: string | null) => {
|
||||
if (id) {
|
||||
requestProgress(
|
||||
id,
|
||||
gradioApp().querySelector('#agent_scheduler_current_task_progress')!,
|
||||
gradioApp().querySelector('#agent_scheduler_current_task_images')!,
|
||||
() => {
|
||||
setTimeout(() => {
|
||||
store.refresh();
|
||||
}, 1000);
|
||||
},
|
||||
);
|
||||
}
|
||||
};
|
||||
store.subscribe({
|
||||
next: ([prev, curr]) => {
|
||||
if (prev.current_task_id !== curr.current_task_id) {
|
||||
onTaskIdChange(curr.current_task_id);
|
||||
}
|
||||
if (curr.paused) {
|
||||
pauseButton.classList.add('hide');
|
||||
resumeButton.classList.remove('hide');
|
||||
} else {
|
||||
pauseButton.classList.remove('hide');
|
||||
resumeButton.classList.add('hide');
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// init grid
|
||||
const deleteIcon = `
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<path d="M4 7l16 0"/>
|
||||
<path d="M10 11l0 6"/>
|
||||
<path d="M14 11l0 6"/>
|
||||
<path d="M5 7l1 12a2 2 0 0 0 2 2h8a2 2 0 0 0 2 -2l1 -12"/>
|
||||
<path d="M9 7v-3a1 1 0 0 1 1 -1h4a1 1 0 0 1 1 1v3"/>
|
||||
</svg>`;
|
||||
const cancelIcon = `
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<path d="M18 6l-12 12"/>
|
||||
<path d="M6 6l12 12"/>
|
||||
</svg>
|
||||
`;
|
||||
const pendingTasksGridOptions: GridOptions<Task> = {
|
||||
// domLayout: 'autoHeight',
|
||||
// default col def properties get applied to all columns
|
||||
defaultColDef: {
|
||||
sortable: false,
|
||||
filter: true,
|
||||
resizable: true,
|
||||
suppressMenu: true,
|
||||
},
|
||||
// each entry here represents one column
|
||||
columnDefs: [
|
||||
{
|
||||
field: 'id',
|
||||
headerName: 'Task Id',
|
||||
minWidth: 240,
|
||||
maxWidth: 240,
|
||||
pinned: 'left',
|
||||
rowDrag: true,
|
||||
cellClass: ({ data }) => [
|
||||
data?.status === 'running' ? 'task-running' : '',
|
||||
],
|
||||
},
|
||||
{
|
||||
field: 'type',
|
||||
headerName: 'Type',
|
||||
minWidth: 80,
|
||||
maxWidth: 80,
|
||||
},
|
||||
{
|
||||
field: 'priority',
|
||||
headerName: 'Priority',
|
||||
hide: true,
|
||||
},
|
||||
{
|
||||
headerName: 'Params',
|
||||
children: [
|
||||
{
|
||||
field: 'params.prompt',
|
||||
headerName: 'Prompt',
|
||||
minWidth: 400,
|
||||
autoHeight: true,
|
||||
wrapText: true,
|
||||
cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' },
|
||||
},
|
||||
{
|
||||
field: 'params.negative_prompt',
|
||||
headerName: 'Negative Prompt',
|
||||
minWidth: 400,
|
||||
autoHeight: true,
|
||||
wrapText: true,
|
||||
cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' },
|
||||
},
|
||||
{
|
||||
field: 'params.checkpoint',
|
||||
headerName: 'Checkpoint',
|
||||
minWidth: 150,
|
||||
maxWidth: 300,
|
||||
valueFormatter: ({ value }) => value || 'System',
|
||||
},
|
||||
{
|
||||
field: 'params.sampler_name',
|
||||
headerName: 'Sampler',
|
||||
width: 150,
|
||||
minWidth: 150,
|
||||
},
|
||||
{
|
||||
field: 'params.steps',
|
||||
headerName: 'Steps',
|
||||
minWidth: 80,
|
||||
maxWidth: 80,
|
||||
filter: 'agNumberColumnFilter',
|
||||
},
|
||||
{
|
||||
field: 'params.cfg_scale',
|
||||
headerName: 'CFG Scale',
|
||||
width: 100,
|
||||
minWidth: 100,
|
||||
filter: 'agNumberColumnFilter',
|
||||
},
|
||||
{
|
||||
field: 'params.size',
|
||||
headerName: 'Size',
|
||||
minWidth: 110,
|
||||
maxWidth: 110,
|
||||
valueGetter: ({ data }) => (data ? `${data.params.width}x${data.params.height}` : ''),
|
||||
},
|
||||
{
|
||||
field: 'params.batch',
|
||||
headerName: 'Batching',
|
||||
minWidth: 100,
|
||||
maxWidth: 100,
|
||||
valueGetter: ({ data }) =>
|
||||
data ? `${data.params.n_iter}x${data.params.batch_size}` : '1x1',
|
||||
},
|
||||
],
|
||||
},
|
||||
{ field: 'created_at', headerName: 'Date', minWidth: 200 },
|
||||
{
|
||||
headerName: 'Action',
|
||||
pinned: 'right',
|
||||
minWidth: 110,
|
||||
maxWidth: 110,
|
||||
resizable: false,
|
||||
valueGetter: ({ data }) => data?.id,
|
||||
cellRenderer: ({ api, value, data }: any) => {
|
||||
const html = `
|
||||
<div class="inline-flex rounded-md shadow-sm mt-1.5" role="group">
|
||||
<button type="button" ${
|
||||
data.status === 'running' ? 'disabled' : ''
|
||||
} class="ts-btn-action ts-btn-run">
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||
<path d="M7 4v16l13 -8z"/>
|
||||
</svg>
|
||||
</button>
|
||||
<button type="button" class="ts-btn-action ts-btn-delete">
|
||||
${data.status === 'pending' ? deleteIcon : cancelIcon}
|
||||
</button>
|
||||
</div>
|
||||
`;
|
||||
|
||||
const placeholder = document.createElement('div');
|
||||
placeholder.innerHTML = html;
|
||||
const node = placeholder.firstElementChild!;
|
||||
|
||||
const btnRun = node.querySelector('button.ts-btn-run')!;
|
||||
btnRun.addEventListener('click', () => {
|
||||
console.log('run', value);
|
||||
api.showLoadingOverlay();
|
||||
store.runTask(value).then(() => api.hideOverlay());
|
||||
});
|
||||
|
||||
const btnDelete = node.querySelector('button.ts-btn-delete')!;
|
||||
btnDelete.addEventListener('click', () => {
|
||||
console.log('delete', value);
|
||||
api.showLoadingOverlay();
|
||||
store.deleteTask(value).then(() => api.hideOverlay());
|
||||
});
|
||||
|
||||
return node;
|
||||
},
|
||||
},
|
||||
],
|
||||
getRowId: ({ data }) => data.id,
|
||||
|
||||
rowData: [],
|
||||
rowSelection: 'single', // allow rows to be selected
|
||||
animateRows: true, // have rows animate to new positions when sorted
|
||||
pagination: true,
|
||||
paginationPageSize: 10,
|
||||
suppressCopyRowsToClipboard: true,
|
||||
suppressRowTransform: true,
|
||||
suppressRowClickSelection: true,
|
||||
enableBrowserTooltips: true,
|
||||
onGridReady: ({ api }) => {
|
||||
// init quick search input
|
||||
const searchInput: HTMLInputElement = searchContainer.querySelector(
|
||||
'input#agent_scheduler_search_input',
|
||||
)!;
|
||||
rxjs
|
||||
.fromEvent(searchInput, 'keyup')
|
||||
.pipe(rxjs.debounce(() => rxjs.interval(200)))
|
||||
.subscribe((e) => {
|
||||
api.setQuickFilter((e.target as HTMLInputElement).value);
|
||||
});
|
||||
|
||||
store.subscribe({
|
||||
next: ([_, newState]) => {
|
||||
api.setRowData(newState.pending_tasks);
|
||||
|
||||
if (newState.current_task_id) {
|
||||
const node = api.getRowNode(newState.current_task_id);
|
||||
if (node) {
|
||||
api.refreshCells({ rowNodes: [node], force: true });
|
||||
}
|
||||
}
|
||||
|
||||
api.sizeColumnsToFit();
|
||||
},
|
||||
});
|
||||
|
||||
// refresh the state
|
||||
store.refresh();
|
||||
},
|
||||
onRowDragEnd: ({ api, node, overNode }) => {
|
||||
const id = node.data?.id;
|
||||
const overId = overNode?.data?.id;
|
||||
if (id && overId && id !== overId) {
|
||||
api.showLoadingOverlay();
|
||||
store.moveTask(id, overId).then(() => api.hideOverlay());
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
const eGridDiv = gradioApp().querySelector<HTMLDivElement>(
|
||||
'#agent_scheduler_pending_tasks_grid',
|
||||
)!;
|
||||
if (document.querySelector('.dark')) {
|
||||
eGridDiv.className = 'ag-theme-alpine-dark';
|
||||
}
|
||||
eGridDiv.style.height = window.innerHeight - 240 + 'px';
|
||||
new Grid(eGridDiv, pendingTasksGridOptions);
|
||||
|
||||
store.refresh();
|
||||
}
|
||||
|
||||
onUiLoaded(initTaskScheduler);
|
||||
|
|
@ -0,0 +1 @@
|
|||
/// <reference types="vite/client" />
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
module.exports = {
|
||||
content: ["./src/**/*.{js,ts,tsx}"],
|
||||
darkMode: 'class',
|
||||
theme: {
|
||||
extend: {},
|
||||
},
|
||||
plugins: [],
|
||||
}
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"useDefineForClassFields": true,
|
||||
"lib": ["ES2020", "DOM", "DOM.Iterable"],
|
||||
"module": "ESNext",
|
||||
"skipLibCheck": true,
|
||||
|
||||
/* Bundler mode */
|
||||
"moduleResolution": "bundler",
|
||||
"allowImportingTsExtensions": true,
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"noEmit": true,
|
||||
"jsx": "react-jsx",
|
||||
|
||||
/* Linting */
|
||||
"strict": true,
|
||||
"noUnusedLocals": true,
|
||||
"noUnusedParameters": true,
|
||||
"noFallthroughCasesInSwitch": true
|
||||
},
|
||||
"include": ["src"],
|
||||
"references": [{ "path": "./tsconfig.node.json" }]
|
||||
}
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"composite": true,
|
||||
"skipLibCheck": true,
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "bundler",
|
||||
"allowSyntheticDefaultImports": true
|
||||
},
|
||||
"include": ["vite.config.ts"]
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
import { defineConfig } from 'vite';
|
||||
import react from '@vitejs/plugin-react';
|
||||
|
||||
// https://vitejs.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
build: {
|
||||
outDir: '../',
|
||||
rollupOptions: {
|
||||
input: {
|
||||
main: 'index.html',
|
||||
},
|
||||
},
|
||||
lib: {
|
||||
name: 'agent-scheduler',
|
||||
entry: 'src/extension/index.ts',
|
||||
fileName: 'javascript/extension',
|
||||
formats: ['es']
|
||||
},
|
||||
},
|
||||
});
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
import { defineConfig } from 'vite';
|
||||
|
||||
// https://vitejs.dev/config/
|
||||
export default defineConfig({
|
||||
build: {
|
||||
outDir: '../',
|
||||
copyPublicDir: false,
|
||||
lib: {
|
||||
name: 'agent-scheduler',
|
||||
entry: 'src/extension/index.ts',
|
||||
fileName: 'javascript/extension',
|
||||
formats: ['es']
|
||||
},
|
||||
},
|
||||
});
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue