chore: feature updates
- add a flag to enable/disable queue auto processing - add queue button placement setting - add a flag to hide the custom checkpoint select - rewrite frontend code in typescript - extract serialization logic to task_helpers - bugs fixingpull/7/head
parent
21d150a0bd
commit
4f7339468c
|
|
@ -36,5 +36,4 @@ notification.mp3
|
||||||
*.sql
|
*.sql
|
||||||
*.db
|
*.db
|
||||||
*.sqlite
|
*.sqlite
|
||||||
*.sqlite3
|
*.sqlite3
|
||||||
tailwind.*
|
|
||||||
File diff suppressed because one or more lines are too long
|
|
@ -5,7 +5,7 @@ from gradio.routes import App
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import progress, script_callbacks, sd_samplers
|
from modules import progress, script_callbacks, sd_samplers
|
||||||
|
|
||||||
from scripts.db import TaskStatus, AppStateKey, task_manager, state_manager
|
from scripts.db import TaskStatus, task_manager
|
||||||
from scripts.models import QueueStatusResponse
|
from scripts.models import QueueStatusResponse
|
||||||
from scripts.task_runner import TaskRunner, get_instance
|
from scripts.task_runner import TaskRunner, get_instance
|
||||||
from scripts.helpers import log
|
from scripts.helpers import log
|
||||||
|
|
@ -29,8 +29,8 @@ def regsiter_apis(app: App):
|
||||||
task_args = TaskRunner.instance.parse_task_args(
|
task_args = TaskRunner.instance.parse_task_args(
|
||||||
task.params, task.script_params, deserialization=False
|
task.params, task.script_params, deserialization=False
|
||||||
)
|
)
|
||||||
named_args = task_args["named_args"]
|
named_args = task_args.named_args
|
||||||
named_args["checkpoint"] = task_args["checkpoint"]
|
named_args["checkpoint"] = task_args.checkpoint
|
||||||
sampler_index = named_args.get("sampler_index", None)
|
sampler_index = named_args.get("sampler_index", None)
|
||||||
if sampler_index is not None:
|
if sampler_index is not None:
|
||||||
named_args["sampler_name"] = sd_samplers.samplers[
|
named_args["sampler_name"] = sd_samplers.samplers[
|
||||||
|
|
@ -103,21 +103,24 @@ def regsiter_apis(app: App):
|
||||||
|
|
||||||
@app.post("/agent-scheduler/v1/pause")
|
@app.post("/agent-scheduler/v1/pause")
|
||||||
def pause_queue():
|
def pause_queue():
|
||||||
state_manager.set_value(AppStateKey.QueueState, "paused")
|
# state_manager.set_value(AppStateKey.QueueState, "paused")
|
||||||
|
shared.opts.queue_paused = True
|
||||||
return {"success": True, "message": f"Queue is paused"}
|
return {"success": True, "message": f"Queue is paused"}
|
||||||
|
|
||||||
@app.post("/agent-scheduler/v1/resume")
|
@app.post("/agent-scheduler/v1/resume")
|
||||||
def resume_queue():
|
def resume_queue():
|
||||||
state_manager.set_value(AppStateKey.QueueState, "running")
|
# state_manager.set_value(AppStateKey.QueueState, "running")
|
||||||
|
shared.opts.queue_paused = False
|
||||||
TaskRunner.instance.execute_pending_tasks_threading()
|
TaskRunner.instance.execute_pending_tasks_threading()
|
||||||
return {"success": True, "message": f"Queue is resumed"}
|
return {"success": True, "message": f"Queue is resumed"}
|
||||||
|
|
||||||
|
|
||||||
def on_app_started(block, app: App):
|
def on_app_started(block, app: App):
|
||||||
global task_runner
|
if block is not None:
|
||||||
task_runner = get_instance(block)
|
global task_runner
|
||||||
|
task_runner = get_instance(block)
|
||||||
|
|
||||||
regsiter_apis(app)
|
regsiter_apis(app)
|
||||||
|
|
||||||
|
|
||||||
script_callbacks.on_app_started(on_app_started)
|
script_callbacks.on_app_started(on_app_started)
|
||||||
|
|
|
||||||
|
|
@ -27,11 +27,15 @@ def init():
|
||||||
|
|
||||||
inspector = inspect(engine)
|
inspector = inspect(engine)
|
||||||
with engine.connect() as conn:
|
with engine.connect() as conn:
|
||||||
# check if table task has column result and add it if not
|
|
||||||
task_columns = inspector.get_columns("task")
|
task_columns = inspector.get_columns("task")
|
||||||
|
# add result column
|
||||||
if not any(col["name"] == "result" for col in task_columns):
|
if not any(col["name"] == "result" for col in task_columns):
|
||||||
conn.execute(text("ALTER TABLE task ADD COLUMN result TEXT"))
|
conn.execute(text("ALTER TABLE task ADD COLUMN result TEXT"))
|
||||||
|
|
||||||
|
# add api_task_id column
|
||||||
|
if not any(col["name"] == "api_task_id" for col in task_columns):
|
||||||
|
conn.execute(text("ALTER TABLE task ADD COLUMN api_task_id VARCHAR(64)"))
|
||||||
|
|
||||||
params_column = next(col for col in task_columns if col["name"] == "params")
|
params_column = next(col for col in task_columns if col["name"] == "params")
|
||||||
if version > "1" and not isinstance(params_column["type"], Text):
|
if version > "1" and not isinstance(params_column["type"], Text):
|
||||||
transaction = conn.begin()
|
transaction = conn.begin()
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ class Task(TaskModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
id: str = "",
|
id: str = "",
|
||||||
|
api_task_id: str = None,
|
||||||
type: str = "unknown",
|
type: str = "unknown",
|
||||||
params: str = "",
|
params: str = "",
|
||||||
script_params: bytes = b"",
|
script_params: bytes = b"",
|
||||||
|
|
@ -35,6 +36,7 @@ class Task(TaskModel):
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
id=id,
|
id=id,
|
||||||
|
api_task_id=api_task_id,
|
||||||
type=type,
|
type=type,
|
||||||
params=params,
|
params=params,
|
||||||
status=status,
|
status=status,
|
||||||
|
|
@ -44,6 +46,7 @@ class Task(TaskModel):
|
||||||
updated_at=created_at,
|
updated_at=created_at,
|
||||||
)
|
)
|
||||||
self.id: str = id
|
self.id: str = id
|
||||||
|
self.api_task_id: str = api_task_id
|
||||||
self.type: str = type
|
self.type: str = type
|
||||||
self.params: str = params
|
self.params: str = params
|
||||||
self.script_params: bytes = script_params
|
self.script_params: bytes = script_params
|
||||||
|
|
@ -60,6 +63,7 @@ class Task(TaskModel):
|
||||||
def from_table(table: "TaskTable"):
|
def from_table(table: "TaskTable"):
|
||||||
return Task(
|
return Task(
|
||||||
id=table.id,
|
id=table.id,
|
||||||
|
api_task_id=table.api_task_id,
|
||||||
type=table.type,
|
type=table.type,
|
||||||
params=table.params,
|
params=table.params,
|
||||||
script_params=table.script_params,
|
script_params=table.script_params,
|
||||||
|
|
@ -72,6 +76,7 @@ class Task(TaskModel):
|
||||||
def to_table(self):
|
def to_table(self):
|
||||||
return TaskTable(
|
return TaskTable(
|
||||||
id=self.id,
|
id=self.id,
|
||||||
|
api_task_id=self.api_task_id,
|
||||||
type=self.type,
|
type=self.type,
|
||||||
params=self.params,
|
params=self.params,
|
||||||
script_params=self.script_params,
|
script_params=self.script_params,
|
||||||
|
|
@ -84,6 +89,7 @@ class TaskTable(Base):
|
||||||
__tablename__ = "task"
|
__tablename__ = "task"
|
||||||
|
|
||||||
id = Column(String(64), primary_key=True)
|
id = Column(String(64), primary_key=True)
|
||||||
|
api_task_id = Column(String(64), nullable=True)
|
||||||
type = Column(String(20), nullable=False) # txt2img or img2txt
|
type = Column(String(20), nullable=False) # txt2img or img2txt
|
||||||
params = Column(Text, nullable=False) # task args
|
params = Column(Text, nullable=False) # task args
|
||||||
script_params = Column(LargeBinary, nullable=False) # script args
|
script_params = Column(LargeBinary, nullable=False) # script args
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ class QueueStatusAPI(BaseModel):
|
||||||
|
|
||||||
class TaskModel(BaseModel):
|
class TaskModel(BaseModel):
|
||||||
id: str = Field(title="Task Id")
|
id: str = Field(title="Task Id")
|
||||||
|
api_task_id: Optional[str] = Field(title="API Task Id", default=None)
|
||||||
type: str = Field(title="Task Type", description="Either txt2img or img2img")
|
type: str = Field(title="Task Type", description="Either txt2img or img2img")
|
||||||
status: str = Field(title="Task Status", description="Either pending, running, done or failed")
|
status: str = Field(title="Task Status", description="Either pending, running, done or failed")
|
||||||
params: str = Field(title="Task Parameters", description="The parameters of the task in JSON format")
|
params: str = Field(title="Task Parameters", description="The parameters of the task in JSON format")
|
||||||
|
|
@ -34,6 +35,7 @@ class TaskModel(BaseModel):
|
||||||
datetime: convert_datetime_to_iso_8601_with_z_suffix
|
datetime: convert_datetime_to_iso_8601_with_z_suffix
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class QueueStatusResponse(BaseModel):
|
class QueueStatusResponse(BaseModel):
|
||||||
current_task_id: Optional[str] = Field(title="Current Task Id", description="The on progress task id")
|
current_task_id: Optional[str] = Field(title="Current Task Id", description="The on progress task id")
|
||||||
pending_tasks: List[TaskModel] = Field(title="Pending Tasks", description="The pending tasks in the queue")
|
pending_tasks: List[TaskModel] = Field(title="Pending Tasks", description="The pending tasks in the queue")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,441 @@
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import zlib
|
||||||
|
import base64
|
||||||
|
import inspect
|
||||||
|
import requests
|
||||||
|
import numpy as np
|
||||||
|
from enum import Enum
|
||||||
|
from PIL import Image, ImageOps, ImageChops, ImageEnhance, ImageFilter
|
||||||
|
from typing import Optional, List
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from modules import sd_samplers, scripts
|
||||||
|
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||||
|
from modules.sd_models import CheckpointInfo, get_closet_checkpoint_match
|
||||||
|
from modules.txt2img import txt2img
|
||||||
|
from modules.img2img import img2img
|
||||||
|
from modules.api.models import (
|
||||||
|
StableDiffusionTxt2ImgProcessingAPI,
|
||||||
|
StableDiffusionImg2ImgProcessingAPI,
|
||||||
|
)
|
||||||
|
|
||||||
|
from scripts.helpers import log
|
||||||
|
|
||||||
|
img2img_image_args_by_mode: dict[int, list[list[str]]] = {
|
||||||
|
0: [["init_img"]],
|
||||||
|
1: [["sketch"]],
|
||||||
|
2: [["init_img_with_mask", "image"], ["init_img_with_mask", "mask"]],
|
||||||
|
3: [["inpaint_color_sketch"], ["inpaint_color_sketch_orig"]],
|
||||||
|
4: [["init_img_inpaint"], ["init_mask_inpaint"]],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetImage(BaseModel):
|
||||||
|
image: str # base64 or url
|
||||||
|
mask: Optional[str] = None # base64 or url
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetUnit(BaseModel):
|
||||||
|
enabled: Optional[bool] = True
|
||||||
|
module: Optional[str] = "none"
|
||||||
|
model: Optional[str] = None
|
||||||
|
image: ControlNetImage = None
|
||||||
|
weight: Optional[float] = 1.0
|
||||||
|
resize_mode: Optional[str] = None
|
||||||
|
low_vram: Optional[bool] = False
|
||||||
|
processor_res: Optional[int] = 512
|
||||||
|
threshold_a: Optional[float] = 64
|
||||||
|
threshold_b: Optional[float] = 64
|
||||||
|
guidance_start: Optional[float] = 0.0
|
||||||
|
guidance_end: Optional[float] = 1.0
|
||||||
|
pixel_perfect: Optional[bool] = False
|
||||||
|
control_mode: Optional[str] = "Balanced"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseApiTaskArgs(BaseModel):
|
||||||
|
task_id: str = Field(exclude=True)
|
||||||
|
model_hash: str = Field(exclude=True)
|
||||||
|
prompt: Optional[str] = ""
|
||||||
|
styles: Optional[List[str]] = []
|
||||||
|
negative_prompt: Optional[str] = ""
|
||||||
|
seed: Optional[int] = -1
|
||||||
|
subseed: Optional[int] = 1
|
||||||
|
subseed_strength: Optional[int] = 0
|
||||||
|
seed_resize_from_h: Optional[int] = -1
|
||||||
|
seed_resize_from_w: Optional[int] = -1
|
||||||
|
sampler_name: Optional[str] = "DPM++ 2M Karras"
|
||||||
|
n_iter: Optional[int] = 1
|
||||||
|
batch_size: Optional[int] = 1
|
||||||
|
steps: Optional[int] = 20
|
||||||
|
cfg_scale: Optional[int] = 7.0
|
||||||
|
restore_faces: Optional[bool] = False
|
||||||
|
tiling: Optional[bool] = False
|
||||||
|
width: Optional[int] = 512
|
||||||
|
height: Optional[int] = 512
|
||||||
|
script_name: Optional[str] = None
|
||||||
|
controlnet_args: Optional[List[ControlNetUnit]] = Field(exclude=True, default=[])
|
||||||
|
override_settings: Optional[dict] = Field(default={})
|
||||||
|
|
||||||
|
|
||||||
|
class Txt2ImgApiTaskArgs(BaseApiTaskArgs):
|
||||||
|
enable_hr: Optional[bool] = False
|
||||||
|
denoising_strength: Optional[int] = 0
|
||||||
|
hr_scale: Optional[int] = 1
|
||||||
|
hr_upscaler: Optional[str] = "Latent"
|
||||||
|
hr_second_pass_steps: Optional[int] = 0
|
||||||
|
hr_resize_x: Optional[int] = 0
|
||||||
|
hr_resize_y: Optional[int] = 0
|
||||||
|
|
||||||
|
|
||||||
|
class Img2ImgApiTaskArgs(BaseApiTaskArgs):
|
||||||
|
init_images: List[str]
|
||||||
|
mask: Optional[str] = None
|
||||||
|
resize_mode: Optional[int] = 0
|
||||||
|
denoising_strength: Optional[int] = 0.75
|
||||||
|
mask_blur: Optional[int] = 4
|
||||||
|
inpainting_fill: Optional[int] = 0
|
||||||
|
inpaint_full_res: Optional[bool] = True
|
||||||
|
inpaint_full_res_padding: Optional[int] = 0
|
||||||
|
inpainting_mask_invert: Optional[int] = 0
|
||||||
|
initial_noise_multiplier: Optional[float] = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def load_image_from_url(url: str):
|
||||||
|
try:
|
||||||
|
response = requests.get(url)
|
||||||
|
buffer = io.BytesIO(response.content)
|
||||||
|
return Image.open(buffer)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"[AgentScheduler] Error downloading image from url: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(image: str):
|
||||||
|
if not isinstance(image, str):
|
||||||
|
return image
|
||||||
|
|
||||||
|
pil_image = None
|
||||||
|
if os.path.exists(image):
|
||||||
|
pil_image = Image.open(image)
|
||||||
|
elif image.startswith(("http://", "https://")):
|
||||||
|
pil_image = load_image_from_url(image)
|
||||||
|
|
||||||
|
return pil_image
|
||||||
|
|
||||||
|
|
||||||
|
def load_image_to_nparray(image: str):
|
||||||
|
pil_image = load_image(image)
|
||||||
|
|
||||||
|
return (
|
||||||
|
np.array(pil_image).astype("uint8")
|
||||||
|
if isinstance(pil_image, Image.Image)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_pil_to_base64(image: Image.Image):
|
||||||
|
with io.BytesIO() as output_bytes:
|
||||||
|
image.save(output_bytes, format="PNG")
|
||||||
|
bytes_data = output_bytes.getvalue()
|
||||||
|
return base64.b64encode(bytes_data).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def load_image_to_base64(image: str):
|
||||||
|
pil_image = load_image(image)
|
||||||
|
|
||||||
|
if not isinstance(pil_image, Image.Image):
|
||||||
|
return image
|
||||||
|
|
||||||
|
return encode_pil_to_base64(pil_image)
|
||||||
|
|
||||||
|
|
||||||
|
def __serialize_image(image):
|
||||||
|
if isinstance(image, np.ndarray):
|
||||||
|
shape = image.shape
|
||||||
|
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
|
||||||
|
return {"shape": shape, "data": data, "cls": "ndarray"}
|
||||||
|
elif isinstance(image, Image.Image):
|
||||||
|
size = image.size
|
||||||
|
mode = image.mode
|
||||||
|
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
|
||||||
|
return {
|
||||||
|
"size": size,
|
||||||
|
"mode": mode,
|
||||||
|
"data": data,
|
||||||
|
"cls": "Image",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def __deserialize_image(image_str):
|
||||||
|
if isinstance(image_str, dict) and image_str.get("cls", None):
|
||||||
|
cls = image_str["cls"]
|
||||||
|
data = zlib.decompress(base64.b64decode(image_str["data"]))
|
||||||
|
|
||||||
|
if cls == "ndarray":
|
||||||
|
shape = tuple(image_str["shape"])
|
||||||
|
image = np.frombuffer(data, dtype=np.uint8)
|
||||||
|
return image.reshape(shape)
|
||||||
|
else:
|
||||||
|
size = tuple(image_str["size"])
|
||||||
|
mode = image_str["mode"]
|
||||||
|
return Image.frombytes(mode, size, data)
|
||||||
|
else:
|
||||||
|
return image_str
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_img2img_image_args(args: dict):
|
||||||
|
for mode, image_args in img2img_image_args_by_mode.items():
|
||||||
|
for keys in image_args:
|
||||||
|
if mode != args["mode"]:
|
||||||
|
# set None to unused image args to save space
|
||||||
|
args[keys[0]] = None
|
||||||
|
elif len(keys) == 1:
|
||||||
|
image = args.get(keys[0], None)
|
||||||
|
args[keys[0]] = __serialize_image(image)
|
||||||
|
else:
|
||||||
|
value = args.get(keys[0], {})
|
||||||
|
image = value.get(keys[1], None)
|
||||||
|
value[keys[1]] = __serialize_image(image)
|
||||||
|
args[keys[0]] = value
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_img2img_image_args(args: dict):
|
||||||
|
for mode, image_args in img2img_image_args_by_mode.items():
|
||||||
|
if mode != args["mode"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for keys in image_args:
|
||||||
|
if len(keys) == 1:
|
||||||
|
image = args.get(keys[0], None)
|
||||||
|
args[keys[0]] = __deserialize_image(image)
|
||||||
|
else:
|
||||||
|
value = args.get(keys[0], {})
|
||||||
|
image = value.get(keys[1], None)
|
||||||
|
value[keys[1]] = __deserialize_image(image)
|
||||||
|
args[keys[0]] = value
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_controlnet_args(cnet_unit):
|
||||||
|
args: dict = cnet_unit.__dict__
|
||||||
|
args["is_cnet"] = True
|
||||||
|
for k, v in args.items():
|
||||||
|
if k == "image" and v is not None:
|
||||||
|
args[k] = {
|
||||||
|
"image": __serialize_image(v["image"]),
|
||||||
|
"mask": __serialize_image(v["mask"])
|
||||||
|
if v.get("mask", None) is not None
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
if isinstance(v, Enum):
|
||||||
|
args[k] = v.value
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_controlnet_args(args: dict):
|
||||||
|
# args.pop("is_cnet", None)
|
||||||
|
for k, v in args.items():
|
||||||
|
if k == "image" and v is not None:
|
||||||
|
args[k] = {
|
||||||
|
"image": __deserialize_image(v["image"]),
|
||||||
|
"mask": __deserialize_image(v["mask"])
|
||||||
|
if v.get("mask", None) is not None
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def map_ui_task_args_list_to_named_args(
|
||||||
|
args: list, is_img2img: bool, checkpoint: str = None
|
||||||
|
):
|
||||||
|
args_name = []
|
||||||
|
if is_img2img:
|
||||||
|
args_name = inspect.getfullargspec(img2img).args
|
||||||
|
else:
|
||||||
|
args_name = inspect.getfullargspec(txt2img).args
|
||||||
|
|
||||||
|
named_args = dict(zip(args_name, args[0 : len(args_name)]))
|
||||||
|
script_args = args[len(args_name) :]
|
||||||
|
if checkpoint is not None:
|
||||||
|
override_settings_texts = named_args.get("override_settings_texts", [])
|
||||||
|
override_settings_texts.append("Model hash: " + checkpoint)
|
||||||
|
named_args["override_settings_texts"] = override_settings_texts
|
||||||
|
|
||||||
|
return (
|
||||||
|
named_args,
|
||||||
|
script_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def map_ui_task_args_to_api_task_args(
|
||||||
|
named_args: dict, script_args: list, is_img2img: bool
|
||||||
|
):
|
||||||
|
api_task_args: dict = named_args.copy()
|
||||||
|
|
||||||
|
prompt_styles = api_task_args.pop("prompt_styles", [])
|
||||||
|
api_task_args["styles"] = prompt_styles
|
||||||
|
|
||||||
|
sampler_index = api_task_args.pop("sampler_index", 0)
|
||||||
|
api_task_args["sampler_name"] = sd_samplers.samplers[sampler_index].name
|
||||||
|
|
||||||
|
override_settings_texts = api_task_args.pop("override_settings_texts", [])
|
||||||
|
api_task_args["override_settings"] = create_override_settings_dict(
|
||||||
|
override_settings_texts
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_img2img:
|
||||||
|
mode = api_task_args.pop("mode", 0)
|
||||||
|
for arg_mode, image_args in img2img_image_args_by_mode.items():
|
||||||
|
if mode != arg_mode:
|
||||||
|
for keys in image_args:
|
||||||
|
api_task_args.pop(keys[0], None)
|
||||||
|
|
||||||
|
# the logic below is copied from modules/img2img.py
|
||||||
|
if mode == 0:
|
||||||
|
image = api_task_args.pop("init_img").convert("RGB")
|
||||||
|
mask = None
|
||||||
|
elif mode == 1:
|
||||||
|
image = api_task_args.pop("sketch").convert("RGB")
|
||||||
|
mask = None
|
||||||
|
elif mode == 2:
|
||||||
|
init_img_with_mask: dict = api_task_args.pop("init_img_with_mask")
|
||||||
|
image = init_img_with_mask.get("image").convert("RGB")
|
||||||
|
mask = init_img_with_mask.get("mask")
|
||||||
|
alpha_mask = (
|
||||||
|
ImageOps.invert(image.split()[-1])
|
||||||
|
.convert("L")
|
||||||
|
.point(lambda x: 255 if x > 0 else 0, mode="1")
|
||||||
|
)
|
||||||
|
mask = ImageChops.lighter(alpha_mask, mask.convert("L")).convert("L")
|
||||||
|
elif mode == 3:
|
||||||
|
image = api_task_args.pop("inpaint_color_sketch")
|
||||||
|
orig = api_task_args.pop("inpaint_color_sketch_orig") or image
|
||||||
|
mask_alpha = api_task_args.pop("mask_alpha", 0)
|
||||||
|
mask_blur = api_task_args.get("mask_blur", 4)
|
||||||
|
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
||||||
|
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
||||||
|
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||||
|
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||||
|
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||||
|
image = image.convert("RGB")
|
||||||
|
elif mode == 4:
|
||||||
|
image = api_task_args.pop("init_img_inpaint")
|
||||||
|
mask = api_task_args.pop("init_mask_inpaint")
|
||||||
|
else:
|
||||||
|
raise Exception(f"Batch mode is not supported yet")
|
||||||
|
|
||||||
|
image = ImageOps.exif_transpose(image)
|
||||||
|
api_task_args["init_images"] = [encode_pil_to_base64(image)]
|
||||||
|
api_task_args["mask"] = encode_pil_to_base64(mask) if mask is not None else None
|
||||||
|
|
||||||
|
selected_scale_tab = api_task_args.pop("selected_scale_tab", 0)
|
||||||
|
scale_by = api_task_args.pop("scale_by", 1)
|
||||||
|
if selected_scale_tab == 1:
|
||||||
|
api_task_args["width"] = int(image.width * scale_by)
|
||||||
|
api_task_args["height"] = int(image.height * scale_by)
|
||||||
|
else:
|
||||||
|
hr_sampler_index = api_task_args.pop("hr_sampler_index", 0)
|
||||||
|
api_task_args["hr_sampler_name"] = (
|
||||||
|
sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name
|
||||||
|
if hr_sampler_index != 0
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# script
|
||||||
|
script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
|
||||||
|
script_id = script_args[0]
|
||||||
|
if script_id == 0:
|
||||||
|
api_task_args["script_name"] = None
|
||||||
|
api_task_args["script_args"] = []
|
||||||
|
else:
|
||||||
|
script = script_runner.selectable_scripts[script_id - 1]
|
||||||
|
api_task_args["script_name"] = script.title.lower()
|
||||||
|
api_task_args["script_args"] = script_args[script.args_from : script.args_to]
|
||||||
|
|
||||||
|
# alwayson scripts
|
||||||
|
alwayson_scripts = api_task_args.get("alwayson_scripts", None)
|
||||||
|
if not alwayson_scripts or not isinstance(alwayson_scripts, dict):
|
||||||
|
alwayson_scripts = {}
|
||||||
|
api_task_args["alwayson_scripts"] = alwayson_scripts
|
||||||
|
|
||||||
|
for script in script_runner.alwayson_scripts:
|
||||||
|
alwayson_script_args = script_args[script.args_from : script.args_to]
|
||||||
|
if script.title.lower() == "controlnet":
|
||||||
|
for i, cnet_args in enumerate(alwayson_script_args):
|
||||||
|
alwayson_script_args[i] = serialize_controlnet_args(cnet_args)
|
||||||
|
|
||||||
|
alwayson_scripts[script.title.lower()] = {"args": alwayson_script_args}
|
||||||
|
|
||||||
|
return api_task_args
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_api_task_args(params: dict, is_img2img: bool):
|
||||||
|
# pop out custom params
|
||||||
|
model_hash = params.pop("model_hash", None)
|
||||||
|
controlnet_args = params.pop("controlnet_args", None)
|
||||||
|
|
||||||
|
args = (
|
||||||
|
StableDiffusionImg2ImgProcessingAPI(**params)
|
||||||
|
if is_img2img
|
||||||
|
else StableDiffusionTxt2ImgProcessingAPI(**params)
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.override_settings is None:
|
||||||
|
args.override_settings = {}
|
||||||
|
|
||||||
|
if model_hash is None:
|
||||||
|
model_hash = args.override_settings.get("sd_model_checkpoint", None)
|
||||||
|
|
||||||
|
if model_hash is None:
|
||||||
|
log.error("[AgentScheduler] API task must supply model hash")
|
||||||
|
return
|
||||||
|
|
||||||
|
checkpoint: CheckpointInfo = get_closet_checkpoint_match(model_hash)
|
||||||
|
if not checkpoint:
|
||||||
|
log.warn(f"[AgentScheduler] No checkpoint found for model hash {model_hash}")
|
||||||
|
return
|
||||||
|
args.override_settings["sd_model_checkpoint"] = checkpoint.title
|
||||||
|
|
||||||
|
# load images from url or file if needed
|
||||||
|
if is_img2img:
|
||||||
|
init_images = args.init_images
|
||||||
|
for i, image in enumerate(init_images):
|
||||||
|
init_images[i] = load_image_to_base64(image)
|
||||||
|
|
||||||
|
args.mask = load_image_to_base64(args.mask)
|
||||||
|
|
||||||
|
# handle custom controlnet args
|
||||||
|
if controlnet_args is not None:
|
||||||
|
if args.alwayson_scripts is None:
|
||||||
|
args.alwayson_scripts = {}
|
||||||
|
|
||||||
|
controlnets = []
|
||||||
|
for cnet in controlnet_args:
|
||||||
|
enabled = cnet.get("enabled", True)
|
||||||
|
cnet_image = cnet.get("image", None)
|
||||||
|
|
||||||
|
if not enabled:
|
||||||
|
continue
|
||||||
|
if not isinstance(cnet_image, dict):
|
||||||
|
log.error(f"[AgentScheduler] Controlnet image is required")
|
||||||
|
continue
|
||||||
|
|
||||||
|
image = cnet_image.get("image", None)
|
||||||
|
mask = cnet_image.get("mask", None)
|
||||||
|
if image is None:
|
||||||
|
log.error(f"[AgentScheduler] Controlnet image is required")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# load controlnet images from url or file if needed
|
||||||
|
cnet_image["image"] = load_image_to_base64(image)
|
||||||
|
cnet_image["mask"] = load_image_to_base64(mask)
|
||||||
|
controlnets.append(cnet)
|
||||||
|
|
||||||
|
if len(controlnets) > 0:
|
||||||
|
args.alwayson_scripts["controlnet"] = {"args": controlnets}
|
||||||
|
|
||||||
|
return args.dict()
|
||||||
|
|
@ -1,18 +1,13 @@
|
||||||
import io
|
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import zlib
|
|
||||||
import base64
|
|
||||||
import pickle
|
import pickle
|
||||||
import inspect
|
import inspect
|
||||||
import traceback
|
import traceback
|
||||||
import threading
|
import threading
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from enum import Enum
|
|
||||||
from PIL import Image, PngImagePlugin
|
|
||||||
from typing import Any, Callable, Union
|
from typing import Any, Callable, Union
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
|
@ -20,22 +15,29 @@ from modules import progress, shared, script_callbacks
|
||||||
from modules.call_queue import queue_lock, wrap_gradio_call
|
from modules.call_queue import queue_lock, wrap_gradio_call
|
||||||
from modules.txt2img import txt2img
|
from modules.txt2img import txt2img
|
||||||
from modules.img2img import img2img
|
from modules.img2img import img2img
|
||||||
from modules.api.api import Api, encode_pil_to_base64
|
from modules.api.api import Api
|
||||||
from modules.api.models import (
|
from modules.api.models import (
|
||||||
StableDiffusionTxt2ImgProcessingAPI,
|
StableDiffusionTxt2ImgProcessingAPI,
|
||||||
StableDiffusionImg2ImgProcessingAPI,
|
StableDiffusionImg2ImgProcessingAPI,
|
||||||
)
|
)
|
||||||
|
|
||||||
from scripts.db import TaskStatus, AppStateKey, Task, task_manager, state_manager
|
from scripts.db import TaskStatus, Task, task_manager
|
||||||
from scripts.helpers import log, detect_control_net, get_component_by_elem_id
|
from scripts.helpers import log, detect_control_net, get_component_by_elem_id
|
||||||
|
from scripts.task_helpers import (
|
||||||
|
serialize_img2img_image_args,
|
||||||
|
deserialize_img2img_image_args,
|
||||||
|
serialize_controlnet_args,
|
||||||
|
deserialize_controlnet_args,
|
||||||
|
map_ui_task_args_list_to_named_args,
|
||||||
|
)
|
||||||
|
|
||||||
img2img_image_args_by_mode: dict[int, list[list[str]]] = {
|
|
||||||
0: [["init_img"]],
|
class ParsedTaskArgs(BaseModel):
|
||||||
1: [["sketch"]],
|
args: list[Any]
|
||||||
2: [["init_img_with_mask", "image"], ["init_img_with_mask", "mask"]],
|
named_args: dict[str, Any]
|
||||||
3: [["inpaint_color_sketch"], ["inpaint_color_sketch_orig"]],
|
script_args: list[Any]
|
||||||
4: [["init_img_inpaint"], ["init_mask_inpaint"]],
|
checkpoint: str
|
||||||
}
|
is_ui: bool
|
||||||
|
|
||||||
|
|
||||||
class TaskRunner:
|
class TaskRunner:
|
||||||
|
|
@ -75,103 +77,22 @@ class TaskRunner:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def paused(self) -> bool:
|
def paused(self) -> bool:
|
||||||
return state_manager.get_value(AppStateKey.QueueState) == "paused"
|
return shared.opts.queue_paused
|
||||||
|
|
||||||
def __serialize_image(self, image):
|
|
||||||
if isinstance(image, np.ndarray):
|
|
||||||
shape = image.shape
|
|
||||||
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
|
|
||||||
return {"shape": shape, "data": data, "cls": "ndarray"}
|
|
||||||
elif isinstance(image, Image.Image):
|
|
||||||
size = image.size
|
|
||||||
mode = image.mode
|
|
||||||
data = base64.b64encode(zlib.compress(image.tobytes())).decode()
|
|
||||||
return {
|
|
||||||
"size": size,
|
|
||||||
"mode": mode,
|
|
||||||
"data": data,
|
|
||||||
"cls": "Image",
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return image
|
|
||||||
|
|
||||||
def __deserialize_image(self, image_str):
|
|
||||||
if isinstance(image_str, dict) and image_str.get("cls", None):
|
|
||||||
cls = image_str["cls"]
|
|
||||||
data = zlib.decompress(base64.b64decode(image_str["data"]))
|
|
||||||
|
|
||||||
if cls == "ndarray":
|
|
||||||
shape = tuple(image_str["shape"])
|
|
||||||
image = np.frombuffer(data, dtype=np.uint8)
|
|
||||||
return image.reshape(shape)
|
|
||||||
else:
|
|
||||||
size = tuple(image_str["size"])
|
|
||||||
mode = image_str["mode"]
|
|
||||||
return Image.frombytes(mode, size, data)
|
|
||||||
else:
|
|
||||||
return image_str
|
|
||||||
|
|
||||||
def __serialize_img2img_images(self, args: dict, image_args: list):
|
|
||||||
for keys in image_args:
|
|
||||||
if len(keys) == 1:
|
|
||||||
image = args.get(keys[0], None)
|
|
||||||
args[keys[0]] = self.__serialize_image(image)
|
|
||||||
else:
|
|
||||||
value = args.get(keys[0], {})
|
|
||||||
image = value.get(keys[1], None)
|
|
||||||
value[keys[1]] = self.__serialize_image(image)
|
|
||||||
args[keys[0]] = value
|
|
||||||
|
|
||||||
def __deserialize_img2img_images(self, args: dict, image_args: list):
|
|
||||||
for keys in image_args:
|
|
||||||
if len(keys) == 1:
|
|
||||||
image = args.get(keys[0], None)
|
|
||||||
args[keys[0]] = self.__deserialize_image(image)
|
|
||||||
else:
|
|
||||||
value = args.get(keys[0], {})
|
|
||||||
image = value.get(keys[1], None)
|
|
||||||
value[keys[1]] = self.__deserialize_image(image)
|
|
||||||
args[keys[0]] = value
|
|
||||||
|
|
||||||
def __serialize_ui_task_args(self, is_img2img: bool, *args, checkpoint: str = None):
|
def __serialize_ui_task_args(self, is_img2img: bool, *args, checkpoint: str = None):
|
||||||
args_name = []
|
named_args, script_args = map_ui_task_args_list_to_named_args(
|
||||||
if is_img2img:
|
list(args), is_img2img, checkpoint=checkpoint
|
||||||
args_name = inspect.getfullargspec(img2img).args
|
)
|
||||||
else:
|
|
||||||
args_name = inspect.getfullargspec(txt2img).args
|
|
||||||
|
|
||||||
args = list(args)
|
|
||||||
named_args = dict(zip(args_name, args[0 : len(args_name)]))
|
|
||||||
script_args = args[len(args_name) :]
|
|
||||||
if checkpoint:
|
|
||||||
override_settings_texts = named_args.get("override_settings_texts", [])
|
|
||||||
override_settings_texts.append("Model hash: " + checkpoint)
|
|
||||||
named_args["override_settings_texts"] = override_settings_texts
|
|
||||||
|
|
||||||
# loop through named_args and serialize images
|
# loop through named_args and serialize images
|
||||||
if is_img2img:
|
if is_img2img:
|
||||||
for mode, image_args in img2img_image_args_by_mode.items():
|
serialize_img2img_image_args(named_args)
|
||||||
if mode == named_args["mode"]:
|
|
||||||
self.__serialize_img2img_images(named_args, image_args)
|
|
||||||
else:
|
|
||||||
# set None to unused image args to save space
|
|
||||||
for keys in image_args:
|
|
||||||
named_args[keys[0]] = None
|
|
||||||
|
|
||||||
# loop through script_args and serialize controlnets
|
# loop through script_args and serialize controlnets
|
||||||
if self.UiControlNetUnit is not None:
|
if self.UiControlNetUnit is not None:
|
||||||
for i, a in enumerate(script_args):
|
for i, a in enumerate(script_args):
|
||||||
if isinstance(a, self.UiControlNetUnit):
|
if isinstance(a, self.UiControlNetUnit):
|
||||||
script_args[i] = a.__dict__
|
script_args[i] = serialize_controlnet_args(a)
|
||||||
script_args[i]["is_cnet"] = True
|
|
||||||
for k, v in script_args[i].items():
|
|
||||||
if k == "image" and v is not None:
|
|
||||||
script_args[i][k] = {
|
|
||||||
"image": self.__serialize_image(v["image"]),
|
|
||||||
"mask": self.__serialize_image(v["mask"]),
|
|
||||||
}
|
|
||||||
if isinstance(v, Enum):
|
|
||||||
script_args[i][k] = v.value
|
|
||||||
|
|
||||||
return json.dumps(
|
return json.dumps(
|
||||||
{
|
{
|
||||||
|
|
@ -186,6 +107,7 @@ class TaskRunner:
|
||||||
def __serialize_api_task_args(
|
def __serialize_api_task_args(
|
||||||
self, is_img2img: bool, script_args: list = [], **named_args
|
self, is_img2img: bool, script_args: list = [], **named_args
|
||||||
):
|
):
|
||||||
|
# serialization steps are done in task_helpers.register_api_task
|
||||||
override_settings = named_args.get("override_settings", {})
|
override_settings = named_args.get("override_settings", {})
|
||||||
checkpoint = override_settings.get("sd_model_checkpoint", None)
|
checkpoint = override_settings.get("sd_model_checkpoint", None)
|
||||||
|
|
||||||
|
|
@ -204,21 +126,17 @@ class TaskRunner:
|
||||||
):
|
):
|
||||||
# loop through image_args and deserialize images
|
# loop through image_args and deserialize images
|
||||||
if is_img2img:
|
if is_img2img:
|
||||||
for mode, image_args in img2img_image_args_by_mode.items():
|
deserialize_img2img_image_args(named_args)
|
||||||
if mode == named_args["mode"]:
|
|
||||||
self.__deserialize_img2img_images(named_args, image_args)
|
|
||||||
|
|
||||||
# loop through script_args and deserialize controlnets
|
# loop through script_args and deserialize controlnets
|
||||||
if self.UiControlNetUnit is not None:
|
if self.UiControlNetUnit is not None:
|
||||||
for i, arg in enumerate(script_args):
|
for i, arg in enumerate(script_args):
|
||||||
if isinstance(arg, dict) and arg.get("is_cnet", False):
|
if isinstance(arg, dict) and arg.get("is_cnet", False):
|
||||||
arg.pop("is_cnet")
|
script_args[i] = deserialize_controlnet_args(arg)
|
||||||
for k, v in arg.items():
|
|
||||||
if k == "image" and v is not None:
|
def __deserialize_api_task_args(self, is_img2img: bool, named_args: dict):
|
||||||
arg[k] = {
|
# API task use base64 images as input, no need to deserialize
|
||||||
"image": self.__deserialize_image(v["image"]),
|
pass
|
||||||
"mask": self.__deserialize_image(v["mask"]),
|
|
||||||
}
|
|
||||||
|
|
||||||
def parse_task_args(
|
def parse_task_args(
|
||||||
self, params: str, script_params: bytes, deserialization: bool = True
|
self, params: str, script_params: bytes, deserialization: bool = True
|
||||||
|
|
@ -237,16 +155,18 @@ class TaskRunner:
|
||||||
|
|
||||||
if is_ui and deserialization:
|
if is_ui and deserialization:
|
||||||
self.__deserialize_ui_task_args(is_img2img, named_args, script_args)
|
self.__deserialize_ui_task_args(is_img2img, named_args, script_args)
|
||||||
|
elif deserialization:
|
||||||
|
self.__deserialize_api_task_args(is_img2img, named_args)
|
||||||
|
|
||||||
args = list(named_args.values()) + script_args
|
args = list(named_args.values()) + script_args
|
||||||
|
|
||||||
return {
|
return ParsedTaskArgs(
|
||||||
"args": args,
|
args=args,
|
||||||
"named_args": named_args,
|
named_args=named_args,
|
||||||
"script_args": script_args,
|
script_args=script_args,
|
||||||
"checkpoint": checkpoint,
|
checkpoint=checkpoint,
|
||||||
"is_ui": is_ui,
|
is_ui=is_ui,
|
||||||
}
|
)
|
||||||
|
|
||||||
def register_ui_task(
|
def register_ui_task(
|
||||||
self, task_id: str, is_img2img: bool, *args, checkpoint: str = None
|
self, task_id: str, is_img2img: bool, *args, checkpoint: str = None
|
||||||
|
|
@ -263,13 +183,19 @@ class TaskRunner:
|
||||||
)
|
)
|
||||||
self.__total_pending_tasks += 1
|
self.__total_pending_tasks += 1
|
||||||
|
|
||||||
def register_api_task(self, task_id: str, is_img2img: bool, args: dict):
|
def register_api_task(
|
||||||
|
self, task_id: str, api_task_id: str, is_img2img: bool, args: dict
|
||||||
|
):
|
||||||
progress.add_task_to_queue(task_id)
|
progress.add_task_to_queue(task_id)
|
||||||
|
|
||||||
|
args = args.copy()
|
||||||
|
args.update({"save_images": True, "send_images": False})
|
||||||
params = self.__serialize_api_task_args(is_img2img, **args)
|
params = self.__serialize_api_task_args(is_img2img, **args)
|
||||||
|
|
||||||
task_type = "img2img" if is_img2img else "txt2img"
|
task_type = "img2img" if is_img2img else "txt2img"
|
||||||
task_manager.add_task(Task(id=task_id, type=task_type, params=params))
|
task_manager.add_task(
|
||||||
|
Task(id=task_id, api_task_id=api_task_id, type=task_type, params=params)
|
||||||
|
)
|
||||||
|
|
||||||
self.__run_callbacks(
|
self.__run_callbacks(
|
||||||
"task_registered", task_id, is_img2img=is_img2img, is_ui=False, args=params
|
"task_registered", task_id, is_img2img=is_img2img, is_ui=False, args=params
|
||||||
|
|
@ -279,7 +205,11 @@ class TaskRunner:
|
||||||
def execute_task(self, task: Task, get_next_task: Callable):
|
def execute_task(self, task: Task, get_next_task: Callable):
|
||||||
while True:
|
while True:
|
||||||
if self.dispose:
|
if self.dispose:
|
||||||
sys.exit(0)
|
break
|
||||||
|
|
||||||
|
if self.paused:
|
||||||
|
log.info("[AgentScheduler] Runner is paused")
|
||||||
|
break
|
||||||
|
|
||||||
if progress.current_task is None:
|
if progress.current_task is None:
|
||||||
task_id = task.id
|
task_id = task.id
|
||||||
|
|
@ -290,7 +220,11 @@ class TaskRunner:
|
||||||
task.params,
|
task.params,
|
||||||
task.script_params,
|
task.script_params,
|
||||||
)
|
)
|
||||||
task_meta = {"is_img2img": is_img2img, "is_ui": task_args["is_ui"]}
|
task_meta = {
|
||||||
|
"is_img2img": is_img2img,
|
||||||
|
"is_ui": task_args.is_ui,
|
||||||
|
"api_task_id": task.api_task_id,
|
||||||
|
}
|
||||||
|
|
||||||
self.__saved_images_path = []
|
self.__saved_images_path = []
|
||||||
self.__run_callbacks("task_started", task_id, **task_meta)
|
self.__run_callbacks("task_started", task_id, **task_meta)
|
||||||
|
|
@ -333,7 +267,7 @@ class TaskRunner:
|
||||||
|
|
||||||
task = get_next_task()
|
task = get_next_task()
|
||||||
if not task:
|
if not task:
|
||||||
sys.exit(0)
|
break
|
||||||
|
|
||||||
def execute_pending_tasks_threading(self):
|
def execute_pending_tasks_threading(self):
|
||||||
if self.paused:
|
if self.paused:
|
||||||
|
|
@ -366,19 +300,19 @@ class TaskRunner:
|
||||||
return [
|
return [
|
||||||
task.id,
|
task.id,
|
||||||
task.type,
|
task.type,
|
||||||
json.dumps(task_args["named_args"]),
|
json.dumps(task_args.named_args),
|
||||||
task.created_at.strftime("%Y-%m-%d %H:%M:%S"),
|
task.created_at.strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __execute_task(self, task_id: str, is_img2img: bool, task_args: dict):
|
def __execute_task(self, task_id: str, is_img2img: bool, task_args: ParsedTaskArgs):
|
||||||
if task_args["is_ui"]:
|
if task_args.is_ui:
|
||||||
return self.__execute_ui_task(task_id, is_img2img, *task_args["args"])
|
return self.__execute_ui_task(task_id, is_img2img, *task_args.args)
|
||||||
else:
|
else:
|
||||||
return self.__execute_api_task(
|
return self.__execute_api_task(
|
||||||
task_id,
|
task_id,
|
||||||
is_img2img,
|
is_img2img,
|
||||||
script_args=task_args["script_args"],
|
script_args=task_args.script_args,
|
||||||
**task_args["named_args"],
|
**task_args.named_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __execute_ui_task(self, task_id: str, is_img2img: bool, *args):
|
def __execute_ui_task(self, task_id: str, is_img2img: bool, *args):
|
||||||
|
|
@ -483,9 +417,12 @@ class TaskRunner:
|
||||||
|
|
||||||
def get_instance(block) -> TaskRunner:
|
def get_instance(block) -> TaskRunner:
|
||||||
if TaskRunner.instance is None:
|
if TaskRunner.instance is None:
|
||||||
txt2img_submit_button = get_component_by_elem_id(block, "txt2img_generate")
|
if block is not None:
|
||||||
UiControlNetUnit = detect_control_net(block, txt2img_submit_button)
|
txt2img_submit_button = get_component_by_elem_id(block, "txt2img_generate")
|
||||||
TaskRunner(UiControlNetUnit)
|
UiControlNetUnit = detect_control_net(block, txt2img_submit_button)
|
||||||
|
TaskRunner(UiControlNetUnit)
|
||||||
|
else:
|
||||||
|
TaskRunner()
|
||||||
|
|
||||||
def on_before_reload():
|
def on_before_reload():
|
||||||
# Tell old instance to stop
|
# Tell old instance to stop
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from modules.ui import create_refresh_button
|
||||||
|
|
||||||
from scripts.task_runner import TaskRunner, get_instance
|
from scripts.task_runner import TaskRunner, get_instance
|
||||||
from scripts.helpers import compare_components_with_ids, get_components_by_ids
|
from scripts.helpers import compare_components_with_ids, get_components_by_ids
|
||||||
from scripts.db import init, state_manager, AppStateKey
|
from scripts.db import init
|
||||||
|
|
||||||
task_runner: TaskRunner = None
|
task_runner: TaskRunner = None
|
||||||
initialized = False
|
initialized = False
|
||||||
|
|
@ -14,6 +14,10 @@ initialized = False
|
||||||
checkpoint_current = "Current Checkpoint"
|
checkpoint_current = "Current Checkpoint"
|
||||||
checkpoint_runtime = "Runtime Checkpoint"
|
checkpoint_runtime = "Runtime Checkpoint"
|
||||||
|
|
||||||
|
placement_above_generate = "Above Generate button"
|
||||||
|
placement_under_generate = "Under Generate button"
|
||||||
|
placement_between_prompt_and_generate = "Between Prompt and Generate button"
|
||||||
|
|
||||||
|
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -77,10 +81,11 @@ class Script(scripts.Script):
|
||||||
outputs=fn_block.outputs,
|
outputs=fn_block.outputs,
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
id_part = "img2img" if is_img2img else "txt2img"
|
||||||
with root:
|
with root:
|
||||||
with generate.parent:
|
with gr.Row(elem_id=f"{id_part}_enqueue_wrapper") as row:
|
||||||
id_part = "img2img" if is_img2img else "txt2img"
|
if not shared.opts.queue_button_hide_checkpoint:
|
||||||
with gr.Row(elem_id=f"{id_part}_enqueue_wrapper"):
|
|
||||||
checkpoint = gr.Dropdown(
|
checkpoint = gr.Dropdown(
|
||||||
choices=get_checkpoint_choices(),
|
choices=get_checkpoint_choices(),
|
||||||
value=checkpoint_current,
|
value=checkpoint_current,
|
||||||
|
|
@ -93,13 +98,36 @@ class Script(scripts.Script):
|
||||||
lambda: {"choices": get_checkpoint_choices()},
|
lambda: {"choices": get_checkpoint_choices()},
|
||||||
f"refresh_{id_part}_checkpoint",
|
f"refresh_{id_part}_checkpoint",
|
||||||
)
|
)
|
||||||
submit = gr.Button(
|
checkpoint.change(
|
||||||
"Enqueue", elem_id=f"{id_part}_enqueue", variant="primary"
|
fn=self.on_checkpoint_changed, inputs=[checkpoint]
|
||||||
)
|
)
|
||||||
|
|
||||||
checkpoint.change(fn=self.on_checkpoint_changed, inputs=[checkpoint])
|
submit = gr.Button(
|
||||||
|
"Enqueue", elem_id=f"{id_part}_enqueue", variant="primary"
|
||||||
|
)
|
||||||
submit.click(**args)
|
submit.click(**args)
|
||||||
|
|
||||||
|
# relocation the enqueue button
|
||||||
|
root.children.pop()
|
||||||
|
if shared.opts.queue_button_placement == placement_between_prompt_and_generate:
|
||||||
|
if is_img2img:
|
||||||
|
# add to the iterrogate div
|
||||||
|
parent = generate.parent.parent.parent.children[1]
|
||||||
|
parent.add(row)
|
||||||
|
else:
|
||||||
|
# insert after the prompts
|
||||||
|
parent = generate.parent.parent.parent
|
||||||
|
row.parent = parent
|
||||||
|
parent.children.insert(1, row)
|
||||||
|
elif shared.opts.queue_button_placement == placement_under_generate:
|
||||||
|
# insert after the tools div
|
||||||
|
parent = generate.parent.parent
|
||||||
|
parent.children.insert(2, row)
|
||||||
|
else:
|
||||||
|
# insert after before the generate button
|
||||||
|
parent = generate.parent.parent
|
||||||
|
parent.children.insert(0, row)
|
||||||
|
|
||||||
if cnet_dependency is not None:
|
if cnet_dependency is not None:
|
||||||
cnet_fn_block = next(
|
cnet_fn_block = next(
|
||||||
fn
|
fn
|
||||||
|
|
@ -147,10 +175,6 @@ def get_checkpoint_choices():
|
||||||
return choices
|
return choices
|
||||||
|
|
||||||
|
|
||||||
def is_queue_paused():
|
|
||||||
return state_manager.get_value(AppStateKey.QueueState) == "paused"
|
|
||||||
|
|
||||||
|
|
||||||
def on_ui_tab(**_kwargs):
|
def on_ui_tab(**_kwargs):
|
||||||
global initialized
|
global initialized
|
||||||
if not initialized:
|
if not initialized:
|
||||||
|
|
@ -158,18 +182,25 @@ def on_ui_tab(**_kwargs):
|
||||||
init()
|
init()
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as scheduler_tab:
|
with gr.Blocks(analytics_enabled=False) as scheduler_tab:
|
||||||
|
gr.Textbox(
|
||||||
|
shared.opts.queue_button_placement,
|
||||||
|
elem_id="agent_scheduler_queue_button_placement",
|
||||||
|
show_label=False,
|
||||||
|
visible=False,
|
||||||
|
interactive=False,
|
||||||
|
)
|
||||||
with gr.Row(elem_id="agent_scheduler_pending_tasks_wrapper"):
|
with gr.Row(elem_id="agent_scheduler_pending_tasks_wrapper"):
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
with gr.Group(elem_id="agent_scheduler_actions"):
|
with gr.Group(elem_id="agent_scheduler_actions"):
|
||||||
paused = is_queue_paused()
|
paused = shared.opts.queue_paused
|
||||||
|
|
||||||
pause = gr.Button(
|
gr.Button(
|
||||||
"Pause",
|
"Pause",
|
||||||
elem_id="agent_scheduler_action_pause",
|
elem_id="agent_scheduler_action_pause",
|
||||||
variant="stop",
|
variant="stop",
|
||||||
visible=not paused,
|
visible=not paused,
|
||||||
)
|
)
|
||||||
resume = gr.Button(
|
gr.Button(
|
||||||
"Resume",
|
"Resume",
|
||||||
elem_id="agent_scheduler_action_resume",
|
elem_id="agent_scheduler_action_resume",
|
||||||
variant="primary",
|
variant="primary",
|
||||||
|
|
@ -195,11 +226,53 @@ def on_ui_tab(**_kwargs):
|
||||||
return [(scheduler_tab, "Agent Scheduler", "agent_scheduler")]
|
return [(scheduler_tab, "Agent Scheduler", "agent_scheduler")]
|
||||||
|
|
||||||
|
|
||||||
|
def on_ui_settings():
|
||||||
|
section = ("agent_scheduler", "Agent Scheduler")
|
||||||
|
shared.opts.add_option(
|
||||||
|
"queue_paused",
|
||||||
|
shared.OptionInfo(
|
||||||
|
False,
|
||||||
|
"Disable queue auto processing",
|
||||||
|
gr.Checkbox,
|
||||||
|
{"interactive": True},
|
||||||
|
section=section,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
shared.opts.add_option(
|
||||||
|
"queue_button_placement",
|
||||||
|
shared.OptionInfo(
|
||||||
|
placement_above_generate,
|
||||||
|
"Queue button placement",
|
||||||
|
gr.Radio,
|
||||||
|
lambda: {
|
||||||
|
"choices": [
|
||||||
|
placement_above_generate,
|
||||||
|
placement_under_generate,
|
||||||
|
placement_between_prompt_and_generate,
|
||||||
|
]
|
||||||
|
},
|
||||||
|
section=section,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
shared.opts.add_option(
|
||||||
|
"queue_button_hide_checkpoint",
|
||||||
|
shared.OptionInfo(
|
||||||
|
True,
|
||||||
|
"Hide the checkpoint dropdown",
|
||||||
|
gr.Checkbox,
|
||||||
|
{},
|
||||||
|
section=section,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def on_app_started(block, _):
|
def on_app_started(block, _):
|
||||||
global task_runner
|
if block is not None:
|
||||||
task_runner = get_instance(block)
|
global task_runner
|
||||||
task_runner.execute_pending_tasks_threading()
|
task_runner = get_instance(block)
|
||||||
|
task_runner.execute_pending_tasks_threading()
|
||||||
|
|
||||||
|
|
||||||
script_callbacks.on_ui_tabs(on_ui_tab)
|
script_callbacks.on_ui_tabs(on_ui_tab)
|
||||||
|
script_callbacks.on_ui_settings(on_ui_settings)
|
||||||
script_callbacks.on_app_started(on_app_started)
|
script_callbacks.on_app_started(on_app_started)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
module.exports = {
|
||||||
|
env: { browser: true, es2020: true },
|
||||||
|
extends: [
|
||||||
|
'eslint:recommended',
|
||||||
|
'plugin:@typescript-eslint/recommended',
|
||||||
|
'plugin:react-hooks/recommended',
|
||||||
|
],
|
||||||
|
parser: '@typescript-eslint/parser',
|
||||||
|
parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
|
||||||
|
plugins: ['react-refresh'],
|
||||||
|
rules: {
|
||||||
|
'react-refresh/only-export-components': 'warn',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
# Logs
|
||||||
|
logs
|
||||||
|
*.log
|
||||||
|
npm-debug.log*
|
||||||
|
yarn-debug.log*
|
||||||
|
yarn-error.log*
|
||||||
|
pnpm-debug.log*
|
||||||
|
lerna-debug.log*
|
||||||
|
|
||||||
|
node_modules
|
||||||
|
dist
|
||||||
|
dist-ssr
|
||||||
|
*.local
|
||||||
|
|
||||||
|
# Editor directories and files
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/extensions.json
|
||||||
|
.idea
|
||||||
|
.DS_Store
|
||||||
|
*.suo
|
||||||
|
*.ntvs*
|
||||||
|
*.njsproj
|
||||||
|
*.sln
|
||||||
|
*.sw?
|
||||||
|
|
@ -0,0 +1,35 @@
|
||||||
|
{
|
||||||
|
"name": "ui",
|
||||||
|
"private": true,
|
||||||
|
"version": "0.0.0",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "tsc && yarn build:extension",
|
||||||
|
"build:extension": "vite build --config vite.extension.ts",
|
||||||
|
"lint": "eslint src --ext ts,tsx --report-unused-disable-directives --max-warnings 0",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"ag-grid-community": "^29.3.5",
|
||||||
|
"notyf": "^3.10.0",
|
||||||
|
"react": "^18.2.0",
|
||||||
|
"react-dom": "^18.2.0",
|
||||||
|
"rxjs": "^7.8.1"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/node": "^18",
|
||||||
|
"@types/react": "^18.0.37",
|
||||||
|
"@types/react-dom": "^18.0.11",
|
||||||
|
"@typescript-eslint/eslint-plugin": "^5.59.0",
|
||||||
|
"@typescript-eslint/parser": "^5.59.0",
|
||||||
|
"@vitejs/plugin-react": "^4.0.0",
|
||||||
|
"autoprefixer": "^10.4.14",
|
||||||
|
"eslint": "^8.38.0",
|
||||||
|
"eslint-plugin-react-hooks": "^4.6.0",
|
||||||
|
"eslint-plugin-react-refresh": "^0.3.4",
|
||||||
|
"postcss": "^8.4.24",
|
||||||
|
"tailwindcss": "^3.3.2",
|
||||||
|
"typescript": "^5.0.2",
|
||||||
|
"vite": "^4.3.9"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
module.exports = {
|
||||||
|
plugins: {
|
||||||
|
tailwindcss: {},
|
||||||
|
autoprefixer: {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,160 @@
|
||||||
|
@tailwind components;
|
||||||
|
@tailwind utilities;
|
||||||
|
|
||||||
|
@layer components {
|
||||||
|
.ts-search {
|
||||||
|
@apply relative w-full max-w-xs ml-auto;
|
||||||
|
}
|
||||||
|
.ts-search-input {
|
||||||
|
@apply bg-gray-50 border border-gray-300 text-gray-900 text-sm !rounded-md focus:ring-blue-500 focus:border-blue-500 block !pl-10 !p-2 w-full dark:!bg-gray-700 dark:!border-gray-600 dark:!placeholder-gray-400 dark:!text-white dark:focus:ring-blue-500 dark:focus:border-blue-500;
|
||||||
|
}
|
||||||
|
|
||||||
|
.ts-search-icon {
|
||||||
|
@apply absolute inset-y-0 left-0 flex items-center dark:text-white pl-3 pointer-events-none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.ts-btn-action {
|
||||||
|
@apply inline-flex items-center !px-2 !py-1 !m-0 text-sm font-medium border focus:z-10 focus:ring-2 disabled:opacity-50 disabled:hover:!bg-transparent disabled:cursor-not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.ts-btn-run {
|
||||||
|
@apply !text-green-500 hover:!text-white border-green-500 hover:bg-green-600 rounded-l-md focus:ring-green-400 dark:border-green-500 dark:hover:bg-green-600 dark:focus:ring-green-900 disabled:hover:!text-green-500;
|
||||||
|
}
|
||||||
|
|
||||||
|
.ts-btn-delete {
|
||||||
|
@apply !text-red-500 hover:!text-white border-red-600 hover:bg-red-600 rounded-r-md focus:ring-red-300 dark:border-red-500 dark:hover:bg-red-600 dark:focus:ring-red-900;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes blink {
|
||||||
|
from,
|
||||||
|
to {
|
||||||
|
opacity: 0;
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.task-running {
|
||||||
|
@apply !text-green-500;
|
||||||
|
animation: 1s blink ease infinite;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ========================================================================= */
|
||||||
|
|
||||||
|
#agent_scheduler_pending_tasks_wrapper {
|
||||||
|
gap: var(--layout-gap);
|
||||||
|
border: none;
|
||||||
|
box-shadow: none;
|
||||||
|
border-width: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (max-width: 1024px) {
|
||||||
|
#agent_scheduler_pending_tasks_wrapper {
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_scheduler_pending_tasks_wrapper > div:last-child {
|
||||||
|
width: 100%;
|
||||||
|
max-width: 400px;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (min-width: 1920px) {
|
||||||
|
#agent_scheduler_pending_tasks_wrapper > div:last-child {
|
||||||
|
max-width: 512px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_scheduler_current_task_images {
|
||||||
|
width: 100%;
|
||||||
|
padding-top: 100%;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_scheduler_current_task_images > div {
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
right: 0;
|
||||||
|
bottom: 0;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_scheduler_pending_tasks_wrapper {
|
||||||
|
justify-content: flex-end;
|
||||||
|
gap: var(--layout-gap);
|
||||||
|
padding: 0 var(--layout-gap) var(--layout-gap) var(--layout-gap);
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_scheduler_pending_tasks_wrapper > button {
|
||||||
|
flex: 0 0 auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_scheduler_actions {
|
||||||
|
display: flex;
|
||||||
|
gap: var(--layout-gap);
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_scheduler_actions > button {
|
||||||
|
border-radius: var(--radius-lg) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.ag-theme-alpine,
|
||||||
|
.ag-theme-alpine-dark {
|
||||||
|
--ag-row-height: 45px;
|
||||||
|
--ag-header-height: 45px;
|
||||||
|
/* --ag-grid-size: 6px; */
|
||||||
|
--ag-cell-horizontal-padding: calc(var(--ag-grid-size) * 2);
|
||||||
|
|
||||||
|
--body-text-color: 'inherit';
|
||||||
|
}
|
||||||
|
|
||||||
|
.cell-span {
|
||||||
|
border-bottom-color: var(--ag-border-color);
|
||||||
|
}
|
||||||
|
.cell-not-span {
|
||||||
|
opacity: 0;
|
||||||
|
}
|
||||||
|
.ag-row-hover .ag-cell {
|
||||||
|
background-color: transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
#txt2img_enqueue_wrapper,
|
||||||
|
#img2img_enqueue_wrapper {
|
||||||
|
min-width: 210px;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: calc(var(--layout-gap) / 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#txt2img_enqueue_wrapper > div:first-child,
|
||||||
|
#img2img_enqueue_wrapper > div:first-child {
|
||||||
|
flex-direction: row;
|
||||||
|
flex-wrap: nowrap;
|
||||||
|
align-items: flex-start;
|
||||||
|
flex: 0 0 auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
#txt2img_enqueue_wrapper .gradio-button,
|
||||||
|
#img2img_enqueue_wrapper .gradio-button,
|
||||||
|
#txt2img_enqueue_wrapper .gradio-dropdown .wrap-inner,
|
||||||
|
#img2img_enqueue_wrapper .gradio-dropdown .wrap-inner {
|
||||||
|
min-height: 36px;
|
||||||
|
max-height: 42px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#img2img_toprow .interrogate-col.has-queue-button {
|
||||||
|
min-width: unset !important;
|
||||||
|
flex-direction: row !important;
|
||||||
|
gap: calc(var(--layout-gap) / 2) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
#img2img_toprow .interrogate-col.has-queue-button button {
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_scheduler_current_task_progress .livePreview {
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,470 @@
|
||||||
|
import * as rxjs from 'rxjs';
|
||||||
|
import type { Observer } from 'rxjs';
|
||||||
|
import { Grid, GridOptions } from 'ag-grid-community';
|
||||||
|
import { Notyf } from 'notyf';
|
||||||
|
|
||||||
|
import 'ag-grid-community/styles/ag-grid.css';
|
||||||
|
import 'ag-grid-community/styles/ag-theme-alpine.css';
|
||||||
|
import 'notyf/notyf.min.css';
|
||||||
|
import './index.css';
|
||||||
|
|
||||||
|
declare global {
|
||||||
|
var country: string;
|
||||||
|
function gradioApp(): HTMLElement;
|
||||||
|
function randomId(): string;
|
||||||
|
function get_tab_index(name: string): number;
|
||||||
|
function create_submit_args(args: IArguments): any[];
|
||||||
|
function requestProgress(
|
||||||
|
id: string,
|
||||||
|
progressContainer: HTMLElement,
|
||||||
|
imagesContainer: HTMLElement,
|
||||||
|
onDone: () => void,
|
||||||
|
): void;
|
||||||
|
function onUiLoaded(callback: () => void): void;
|
||||||
|
function submit_enqueue(): any[];
|
||||||
|
function submit_enqueue_img2img(): any[];
|
||||||
|
}
|
||||||
|
|
||||||
|
type Task = {
|
||||||
|
id: string;
|
||||||
|
api_task_id: string;
|
||||||
|
type: string;
|
||||||
|
status: string;
|
||||||
|
params: Record<string, any>;
|
||||||
|
priority: number;
|
||||||
|
result: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
type AppState = {
|
||||||
|
current_task_id: string | null;
|
||||||
|
total_pending_tasks: number;
|
||||||
|
pending_tasks: Task[];
|
||||||
|
paused: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
function initTaskScheduler() {
|
||||||
|
const notyf = new Notyf();
|
||||||
|
const subject = new rxjs.Subject<AppState>();
|
||||||
|
|
||||||
|
const store = {
|
||||||
|
subject,
|
||||||
|
subscribe: (callback: Partial<Observer<[AppState, AppState]>>) => {
|
||||||
|
return store.subject.pipe(rxjs.pairwise()).subscribe(callback);
|
||||||
|
},
|
||||||
|
refresh: async () => {
|
||||||
|
return fetch('/agent-scheduler/v1/queue?limit=1000')
|
||||||
|
.then((response) => response.json())
|
||||||
|
.then((data: AppState) => {
|
||||||
|
const pending_tasks = data.pending_tasks.map((item) => ({
|
||||||
|
...item,
|
||||||
|
params: JSON.parse(item.params as any),
|
||||||
|
status: item.id === data.current_task_id ? 'running' : 'pending',
|
||||||
|
}));
|
||||||
|
store.subject.next({
|
||||||
|
...data,
|
||||||
|
pending_tasks,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
},
|
||||||
|
pauseQueue: async () => {
|
||||||
|
return fetch('/agent-scheduler/v1/pause', { method: 'POST' })
|
||||||
|
.then((response) => response.json())
|
||||||
|
.then((data) => {
|
||||||
|
if (data.success) {
|
||||||
|
notyf.success(data.message);
|
||||||
|
} else {
|
||||||
|
notyf.error(data.message);
|
||||||
|
}
|
||||||
|
|
||||||
|
return store.refresh();
|
||||||
|
});
|
||||||
|
},
|
||||||
|
resumeQueue: async () => {
|
||||||
|
return fetch('/agent-scheduler/v1/resume', { method: 'POST' })
|
||||||
|
.then((response) => response.json())
|
||||||
|
.then((data) => {
|
||||||
|
if (data.success) {
|
||||||
|
notyf.success(data.message);
|
||||||
|
} else {
|
||||||
|
notyf.error(data.message);
|
||||||
|
}
|
||||||
|
|
||||||
|
return store.refresh();
|
||||||
|
});
|
||||||
|
},
|
||||||
|
runTask: async (id: string) => {
|
||||||
|
return fetch(`/agent-scheduler/v1/run/${id}`, { method: 'POST' })
|
||||||
|
.then((response) => response.json())
|
||||||
|
.then((data) => {
|
||||||
|
if (data.success) {
|
||||||
|
notyf.success(data.message);
|
||||||
|
} else {
|
||||||
|
notyf.error(data.message);
|
||||||
|
}
|
||||||
|
|
||||||
|
return store.refresh();
|
||||||
|
});
|
||||||
|
},
|
||||||
|
deleteTask: async (id: string) => {
|
||||||
|
return fetch(`/agent-scheduler/v1/delete/${id}`, { method: 'POST' })
|
||||||
|
.then((response) => response.json())
|
||||||
|
.then((data) => {
|
||||||
|
if (data.success) {
|
||||||
|
notyf.success(data.message);
|
||||||
|
} else {
|
||||||
|
notyf.error(data.message);
|
||||||
|
}
|
||||||
|
|
||||||
|
return store.refresh();
|
||||||
|
});
|
||||||
|
},
|
||||||
|
moveTask: async (id: string, overId: string) => {
|
||||||
|
return fetch(`/agent-scheduler/v1/move/${id}/${overId}`, { method: 'POST' })
|
||||||
|
.then((response) => response.json())
|
||||||
|
.then((data) => {
|
||||||
|
if (data.success) {
|
||||||
|
notyf.success(data.message);
|
||||||
|
} else {
|
||||||
|
notyf.error(data.message);
|
||||||
|
}
|
||||||
|
|
||||||
|
return store.refresh();
|
||||||
|
});
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
store.subject.next({
|
||||||
|
current_task_id: null,
|
||||||
|
total_pending_tasks: 0,
|
||||||
|
pending_tasks: [],
|
||||||
|
paused: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
window.submit_enqueue = function submit_enqueue() {
|
||||||
|
var id = randomId();
|
||||||
|
var res = create_submit_args(arguments);
|
||||||
|
res[0] = id;
|
||||||
|
|
||||||
|
const btnEnqueue = document.querySelector('#txt2img_enqueue');
|
||||||
|
if (btnEnqueue) {
|
||||||
|
btnEnqueue.innerHTML = 'Queued';
|
||||||
|
setTimeout(() => {
|
||||||
|
btnEnqueue.innerHTML = 'Enqueue';
|
||||||
|
store.refresh();
|
||||||
|
}, 1000);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
};
|
||||||
|
|
||||||
|
window.submit_enqueue_img2img = function submit_enqueue_img2img() {
|
||||||
|
var id = randomId();
|
||||||
|
var res = create_submit_args(arguments);
|
||||||
|
res[0] = id;
|
||||||
|
res[1] = get_tab_index('mode_img2img');
|
||||||
|
|
||||||
|
const btnEnqueue = document.querySelector('#img2img_enqueue');
|
||||||
|
if (btnEnqueue) {
|
||||||
|
btnEnqueue.innerHTML = 'Queued';
|
||||||
|
setTimeout(() => {
|
||||||
|
btnEnqueue.innerHTML = 'Enqueue';
|
||||||
|
store.refresh();
|
||||||
|
}, 1000);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
};
|
||||||
|
|
||||||
|
// detect queue button placement
|
||||||
|
const interrogateCol: HTMLDivElement = gradioApp().querySelector('.interrogate-col')!;
|
||||||
|
if (interrogateCol.childElementCount > 2) {
|
||||||
|
interrogateCol.classList.add('has-queue-button');
|
||||||
|
}
|
||||||
|
|
||||||
|
// watch for tab activation
|
||||||
|
const observer = new MutationObserver(function (mutationsList) {
|
||||||
|
const styleChange = mutationsList.find((mutation) => mutation.attributeName === 'style');
|
||||||
|
if (styleChange) {
|
||||||
|
const tab = styleChange.target as HTMLElement;
|
||||||
|
if (tab.style.display === 'block') {
|
||||||
|
store.refresh();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
observer.observe(document.getElementById('tab_agent_scheduler')!, { attributes: true });
|
||||||
|
|
||||||
|
// init actions
|
||||||
|
const refreshButton = gradioApp().querySelector('#agent_scheduler_action_refresh')!;
|
||||||
|
const pauseButton = gradioApp().querySelector('#agent_scheduler_action_pause')!;
|
||||||
|
const resumeButton = gradioApp().querySelector('#agent_scheduler_action_resume')!;
|
||||||
|
refreshButton.addEventListener('click', store.refresh);
|
||||||
|
pauseButton.addEventListener('click', store.pauseQueue);
|
||||||
|
resumeButton.addEventListener('click', store.resumeQueue);
|
||||||
|
|
||||||
|
const searchContainer = gradioApp().querySelector('#agent_scheduler_action_search')!;
|
||||||
|
searchContainer.className = 'ts-search';
|
||||||
|
searchContainer.innerHTML = `
|
||||||
|
<div class="ts-search-icon">
|
||||||
|
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
|
||||||
|
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||||
|
<path d="M10 10m-7 0a7 7 0 1 0 14 0a7 7 0 1 0 -14 0"/>
|
||||||
|
<path d="M21 21l-6 -6"/>
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
<input type="text" id="agent_scheduler_search_input" class="ts-search-input" placeholder="Search" required>
|
||||||
|
`;
|
||||||
|
|
||||||
|
// watch for current task id change
|
||||||
|
const onTaskIdChange = (id: string | null) => {
|
||||||
|
if (id) {
|
||||||
|
requestProgress(
|
||||||
|
id,
|
||||||
|
gradioApp().querySelector('#agent_scheduler_current_task_progress')!,
|
||||||
|
gradioApp().querySelector('#agent_scheduler_current_task_images')!,
|
||||||
|
() => {
|
||||||
|
setTimeout(() => {
|
||||||
|
store.refresh();
|
||||||
|
}, 1000);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
store.subscribe({
|
||||||
|
next: ([prev, curr]) => {
|
||||||
|
if (prev.current_task_id !== curr.current_task_id) {
|
||||||
|
onTaskIdChange(curr.current_task_id);
|
||||||
|
}
|
||||||
|
if (curr.paused) {
|
||||||
|
pauseButton.classList.add('hide');
|
||||||
|
resumeButton.classList.remove('hide');
|
||||||
|
} else {
|
||||||
|
pauseButton.classList.remove('hide');
|
||||||
|
resumeButton.classList.add('hide');
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// init grid
|
||||||
|
const deleteIcon = `
|
||||||
|
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
|
||||||
|
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||||
|
<path d="M4 7l16 0"/>
|
||||||
|
<path d="M10 11l0 6"/>
|
||||||
|
<path d="M14 11l0 6"/>
|
||||||
|
<path d="M5 7l1 12a2 2 0 0 0 2 2h8a2 2 0 0 0 2 -2l1 -12"/>
|
||||||
|
<path d="M9 7v-3a1 1 0 0 1 1 -1h4a1 1 0 0 1 1 1v3"/>
|
||||||
|
</svg>`;
|
||||||
|
const cancelIcon = `
|
||||||
|
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
|
||||||
|
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||||
|
<path d="M18 6l-12 12"/>
|
||||||
|
<path d="M6 6l12 12"/>
|
||||||
|
</svg>
|
||||||
|
`;
|
||||||
|
const pendingTasksGridOptions: GridOptions<Task> = {
|
||||||
|
// domLayout: 'autoHeight',
|
||||||
|
// default col def properties get applied to all columns
|
||||||
|
defaultColDef: {
|
||||||
|
sortable: false,
|
||||||
|
filter: true,
|
||||||
|
resizable: true,
|
||||||
|
suppressMenu: true,
|
||||||
|
},
|
||||||
|
// each entry here represents one column
|
||||||
|
columnDefs: [
|
||||||
|
{
|
||||||
|
field: 'id',
|
||||||
|
headerName: 'Task Id',
|
||||||
|
minWidth: 240,
|
||||||
|
maxWidth: 240,
|
||||||
|
pinned: 'left',
|
||||||
|
rowDrag: true,
|
||||||
|
cellClass: ({ data }) => [
|
||||||
|
data?.status === 'running' ? 'task-running' : '',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
field: 'type',
|
||||||
|
headerName: 'Type',
|
||||||
|
minWidth: 80,
|
||||||
|
maxWidth: 80,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
field: 'priority',
|
||||||
|
headerName: 'Priority',
|
||||||
|
hide: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
headerName: 'Params',
|
||||||
|
children: [
|
||||||
|
{
|
||||||
|
field: 'params.prompt',
|
||||||
|
headerName: 'Prompt',
|
||||||
|
minWidth: 400,
|
||||||
|
autoHeight: true,
|
||||||
|
wrapText: true,
|
||||||
|
cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
field: 'params.negative_prompt',
|
||||||
|
headerName: 'Negative Prompt',
|
||||||
|
minWidth: 400,
|
||||||
|
autoHeight: true,
|
||||||
|
wrapText: true,
|
||||||
|
cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
field: 'params.checkpoint',
|
||||||
|
headerName: 'Checkpoint',
|
||||||
|
minWidth: 150,
|
||||||
|
maxWidth: 300,
|
||||||
|
valueFormatter: ({ value }) => value || 'System',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
field: 'params.sampler_name',
|
||||||
|
headerName: 'Sampler',
|
||||||
|
width: 150,
|
||||||
|
minWidth: 150,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
field: 'params.steps',
|
||||||
|
headerName: 'Steps',
|
||||||
|
minWidth: 80,
|
||||||
|
maxWidth: 80,
|
||||||
|
filter: 'agNumberColumnFilter',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
field: 'params.cfg_scale',
|
||||||
|
headerName: 'CFG Scale',
|
||||||
|
width: 100,
|
||||||
|
minWidth: 100,
|
||||||
|
filter: 'agNumberColumnFilter',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
field: 'params.size',
|
||||||
|
headerName: 'Size',
|
||||||
|
minWidth: 110,
|
||||||
|
maxWidth: 110,
|
||||||
|
valueGetter: ({ data }) => (data ? `${data.params.width}x${data.params.height}` : ''),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
field: 'params.batch',
|
||||||
|
headerName: 'Batching',
|
||||||
|
minWidth: 100,
|
||||||
|
maxWidth: 100,
|
||||||
|
valueGetter: ({ data }) =>
|
||||||
|
data ? `${data.params.n_iter}x${data.params.batch_size}` : '1x1',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{ field: 'created_at', headerName: 'Date', minWidth: 200 },
|
||||||
|
{
|
||||||
|
headerName: 'Action',
|
||||||
|
pinned: 'right',
|
||||||
|
minWidth: 110,
|
||||||
|
maxWidth: 110,
|
||||||
|
resizable: false,
|
||||||
|
valueGetter: ({ data }) => data?.id,
|
||||||
|
cellRenderer: ({ api, value, data }: any) => {
|
||||||
|
const html = `
|
||||||
|
<div class="inline-flex rounded-md shadow-sm mt-1.5" role="group">
|
||||||
|
<button type="button" ${
|
||||||
|
data.status === 'running' ? 'disabled' : ''
|
||||||
|
} class="ts-btn-action ts-btn-run">
|
||||||
|
<svg width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
|
||||||
|
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
|
||||||
|
<path d="M7 4v16l13 -8z"/>
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
<button type="button" class="ts-btn-action ts-btn-delete">
|
||||||
|
${data.status === 'pending' ? deleteIcon : cancelIcon}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
|
||||||
|
const placeholder = document.createElement('div');
|
||||||
|
placeholder.innerHTML = html;
|
||||||
|
const node = placeholder.firstElementChild!;
|
||||||
|
|
||||||
|
const btnRun = node.querySelector('button.ts-btn-run')!;
|
||||||
|
btnRun.addEventListener('click', () => {
|
||||||
|
console.log('run', value);
|
||||||
|
api.showLoadingOverlay();
|
||||||
|
store.runTask(value).then(() => api.hideOverlay());
|
||||||
|
});
|
||||||
|
|
||||||
|
const btnDelete = node.querySelector('button.ts-btn-delete')!;
|
||||||
|
btnDelete.addEventListener('click', () => {
|
||||||
|
console.log('delete', value);
|
||||||
|
api.showLoadingOverlay();
|
||||||
|
store.deleteTask(value).then(() => api.hideOverlay());
|
||||||
|
});
|
||||||
|
|
||||||
|
return node;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
getRowId: ({ data }) => data.id,
|
||||||
|
|
||||||
|
rowData: [],
|
||||||
|
rowSelection: 'single', // allow rows to be selected
|
||||||
|
animateRows: true, // have rows animate to new positions when sorted
|
||||||
|
pagination: true,
|
||||||
|
paginationPageSize: 10,
|
||||||
|
suppressCopyRowsToClipboard: true,
|
||||||
|
suppressRowTransform: true,
|
||||||
|
suppressRowClickSelection: true,
|
||||||
|
enableBrowserTooltips: true,
|
||||||
|
onGridReady: ({ api }) => {
|
||||||
|
// init quick search input
|
||||||
|
const searchInput: HTMLInputElement = searchContainer.querySelector(
|
||||||
|
'input#agent_scheduler_search_input',
|
||||||
|
)!;
|
||||||
|
rxjs
|
||||||
|
.fromEvent(searchInput, 'keyup')
|
||||||
|
.pipe(rxjs.debounce(() => rxjs.interval(200)))
|
||||||
|
.subscribe((e) => {
|
||||||
|
api.setQuickFilter((e.target as HTMLInputElement).value);
|
||||||
|
});
|
||||||
|
|
||||||
|
store.subscribe({
|
||||||
|
next: ([_, newState]) => {
|
||||||
|
api.setRowData(newState.pending_tasks);
|
||||||
|
|
||||||
|
if (newState.current_task_id) {
|
||||||
|
const node = api.getRowNode(newState.current_task_id);
|
||||||
|
if (node) {
|
||||||
|
api.refreshCells({ rowNodes: [node], force: true });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
api.sizeColumnsToFit();
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// refresh the state
|
||||||
|
store.refresh();
|
||||||
|
},
|
||||||
|
onRowDragEnd: ({ api, node, overNode }) => {
|
||||||
|
const id = node.data?.id;
|
||||||
|
const overId = overNode?.data?.id;
|
||||||
|
if (id && overId && id !== overId) {
|
||||||
|
api.showLoadingOverlay();
|
||||||
|
store.moveTask(id, overId).then(() => api.hideOverlay());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const eGridDiv = gradioApp().querySelector<HTMLDivElement>(
|
||||||
|
'#agent_scheduler_pending_tasks_grid',
|
||||||
|
)!;
|
||||||
|
if (document.querySelector('.dark')) {
|
||||||
|
eGridDiv.className = 'ag-theme-alpine-dark';
|
||||||
|
}
|
||||||
|
eGridDiv.style.height = window.innerHeight - 240 + 'px';
|
||||||
|
new Grid(eGridDiv, pendingTasksGridOptions);
|
||||||
|
|
||||||
|
store.refresh();
|
||||||
|
}
|
||||||
|
|
||||||
|
onUiLoaded(initTaskScheduler);
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
/// <reference types="vite/client" />
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
module.exports = {
|
||||||
|
content: ["./src/**/*.{js,ts,tsx}"],
|
||||||
|
darkMode: 'class',
|
||||||
|
theme: {
|
||||||
|
extend: {},
|
||||||
|
},
|
||||||
|
plugins: [],
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"target": "ES2020",
|
||||||
|
"useDefineForClassFields": true,
|
||||||
|
"lib": ["ES2020", "DOM", "DOM.Iterable"],
|
||||||
|
"module": "ESNext",
|
||||||
|
"skipLibCheck": true,
|
||||||
|
|
||||||
|
/* Bundler mode */
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"allowImportingTsExtensions": true,
|
||||||
|
"resolveJsonModule": true,
|
||||||
|
"isolatedModules": true,
|
||||||
|
"noEmit": true,
|
||||||
|
"jsx": "react-jsx",
|
||||||
|
|
||||||
|
/* Linting */
|
||||||
|
"strict": true,
|
||||||
|
"noUnusedLocals": true,
|
||||||
|
"noUnusedParameters": true,
|
||||||
|
"noFallthroughCasesInSwitch": true
|
||||||
|
},
|
||||||
|
"include": ["src"],
|
||||||
|
"references": [{ "path": "./tsconfig.node.json" }]
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"composite": true,
|
||||||
|
"skipLibCheck": true,
|
||||||
|
"module": "ESNext",
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"allowSyntheticDefaultImports": true
|
||||||
|
},
|
||||||
|
"include": ["vite.config.ts"]
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
import { defineConfig } from 'vite';
|
||||||
|
import react from '@vitejs/plugin-react';
|
||||||
|
|
||||||
|
// https://vitejs.dev/config/
|
||||||
|
export default defineConfig({
|
||||||
|
plugins: [react()],
|
||||||
|
build: {
|
||||||
|
outDir: '../',
|
||||||
|
rollupOptions: {
|
||||||
|
input: {
|
||||||
|
main: 'index.html',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
lib: {
|
||||||
|
name: 'agent-scheduler',
|
||||||
|
entry: 'src/extension/index.ts',
|
||||||
|
fileName: 'javascript/extension',
|
||||||
|
formats: ['es']
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
import { defineConfig } from 'vite';
|
||||||
|
|
||||||
|
// https://vitejs.dev/config/
|
||||||
|
export default defineConfig({
|
||||||
|
build: {
|
||||||
|
outDir: '../',
|
||||||
|
copyPublicDir: false,
|
||||||
|
lib: {
|
||||||
|
name: 'agent-scheduler',
|
||||||
|
entry: 'src/extension/index.ts',
|
||||||
|
fileName: 'javascript/extension',
|
||||||
|
formats: ['es']
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue