feature update
* Add api to download task's generated images * Add setting to render extension UI below the main UI * Display task datetime in local timezone * Persist the grid state (columns order, sorting) for next session * Bugs fixing and code improvementspull/71/head
parent
45dfe5977e
commit
ce2b9b4eb9
28
README.md
28
README.md
|
|
@ -103,7 +103,9 @@ By default, queued tasks use the currently loaded checkpoint. However, changing
|
|||
|
||||
All the functionality of this extension can be accessed through HTTP APIs. You can access the API documentation via `http://127.0.0.1:7860/docs`. Remember to include `--api` in your startup arguments.
|
||||
|
||||

|
||||

|
||||
|
||||
#### Queue Task
|
||||
|
||||
The two apis `/agent-scheduler/v1/queue/txt2img` and `/agent-scheduler/v1/queue/img2img` support all the parameters of the original webui apis. These apis response the task id, which can be used to perform updates later.
|
||||
|
||||
|
|
@ -113,6 +115,30 @@ The two apis `/agent-scheduler/v1/queue/txt2img` and `/agent-scheduler/v1/queue/
|
|||
}
|
||||
```
|
||||
|
||||
#### Download Results
|
||||
|
||||
Use api `/agent-scheduler/v1/results/{id}` to get the generated images. The api supports two response format:
|
||||
|
||||
- json with base64 encoded
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [
|
||||
{
|
||||
"image": "data:image/png;base64,iVBORw0KGgoAAAAN...",
|
||||
"infotext": "1girl\nNegative prompt: EasyNegative, badhandv4..."
|
||||
},
|
||||
{
|
||||
"image": "data:image/png;base64,iVBORw0KGgoAAAAN...",
|
||||
"infotext": "1girl\nNegative prompt: EasyNegative, badhandv4..."
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
- zip file with querystring `zip=true`
|
||||
|
||||
## Road Map
|
||||
|
||||
To list possible feature upgrades for this extension
|
||||
|
|
|
|||
|
|
@ -1,6 +1,12 @@
|
|||
import io
|
||||
import json
|
||||
import threading
|
||||
from uuid import uuid4
|
||||
from zipfile import ZipFile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from gradio.routes import App
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from modules import shared, progress
|
||||
|
||||
|
|
@ -15,7 +21,7 @@ from .models import (
|
|||
)
|
||||
from .task_runner import TaskRunner
|
||||
from .helpers import log
|
||||
from .task_helpers import serialize_api_task_args
|
||||
from .task_helpers import encode_image_to_base64
|
||||
|
||||
|
||||
def regsiter_apis(app: App, task_runner: TaskRunner):
|
||||
|
|
@ -23,16 +29,15 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
|
|||
|
||||
@app.post("/agent-scheduler/v1/queue/txt2img", response_model=QueueTaskResponse)
|
||||
def queue_txt2img(body: Txt2ImgApiTaskArgs):
|
||||
params = body.dict()
|
||||
task_id = str(uuid4())
|
||||
checkpoint = params.pop("model_hash", None)
|
||||
task_args = serialize_api_task_args(
|
||||
params,
|
||||
is_img2img=False,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
args = body.dict()
|
||||
checkpoint = args.pop("model_hash", None)
|
||||
task_runner.register_api_task(
|
||||
task_id, api_task_id=False, is_img2img=False, args=task_args
|
||||
task_id,
|
||||
api_task_id=None,
|
||||
is_img2img=False,
|
||||
args=args,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
task_runner.execute_pending_tasks_threading()
|
||||
|
||||
|
|
@ -40,21 +45,30 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
|
|||
|
||||
@app.post("/agent-scheduler/v1/queue/img2img", response_model=QueueTaskResponse)
|
||||
def queue_img2img(body: Img2ImgApiTaskArgs):
|
||||
params = body.dict()
|
||||
task_id = str(uuid4())
|
||||
checkpoint = params.pop("model_hash", None)
|
||||
task_args = serialize_api_task_args(
|
||||
params,
|
||||
is_img2img=True,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
args = body.dict()
|
||||
checkpoint = args.pop("model_hash", None)
|
||||
task_runner.register_api_task(
|
||||
task_id, api_task_id=False, is_img2img=True, args=task_args
|
||||
task_id,
|
||||
api_task_id=None,
|
||||
is_img2img=True,
|
||||
args=args,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
task_runner.execute_pending_tasks_threading()
|
||||
|
||||
return QueueTaskResponse(task_id=task_id)
|
||||
|
||||
def format_task_args(task):
|
||||
task_args = TaskRunner.instance.parse_task_args(task, deserialization=False)
|
||||
named_args = task_args.named_args
|
||||
named_args["checkpoint"] = task_args.checkpoint
|
||||
# remove unused args to reduce payload size
|
||||
named_args.pop("alwayson_scripts", None)
|
||||
named_args.pop("script_args", None)
|
||||
named_args.pop("init_images", None)
|
||||
return named_args
|
||||
|
||||
@app.get("/agent-scheduler/v1/queue", response_model=QueueStatusResponse)
|
||||
def queue_status_api(limit: int = 20, offset: int = 0):
|
||||
current_task_id = progress.current_task
|
||||
|
|
@ -64,14 +78,9 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
|
|||
)
|
||||
parsed_tasks = []
|
||||
for task in pending_tasks:
|
||||
task_args = TaskRunner.instance.parse_task_args(
|
||||
task.params, task.script_params, deserialization=False
|
||||
)
|
||||
named_args = task_args.named_args
|
||||
named_args["checkpoint"] = task_args.checkpoint
|
||||
|
||||
params = format_task_args(task)
|
||||
task_data = task.dict()
|
||||
task_data["params"] = named_args
|
||||
task_data["params"] = params
|
||||
if task.id == current_task_id:
|
||||
task_data["status"] = "running"
|
||||
|
||||
|
|
@ -104,14 +113,9 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
|
|||
)
|
||||
parsed_tasks = []
|
||||
for task in tasks:
|
||||
task_args = TaskRunner.instance.parse_task_args(
|
||||
task.params, task.script_params, deserialization=False
|
||||
)
|
||||
named_args = task_args.named_args
|
||||
named_args["checkpoint"] = task_args.checkpoint
|
||||
|
||||
params = format_task_args(task)
|
||||
task_data = task.dict()
|
||||
task_data["params"] = named_args
|
||||
task_data["params"] = params
|
||||
parsed_tasks.append(TaskModel(**task_data))
|
||||
|
||||
return HistoryResponse(
|
||||
|
|
@ -220,6 +224,50 @@ def regsiter_apis(app: App, task_runner: TaskRunner):
|
|||
task_manager.update_task(id, name=name)
|
||||
return {"success": True, "message": f"Task {id} is renamed"}
|
||||
|
||||
@app.get("/agent-scheduler/v1/results/{id}")
|
||||
def get_task_results(id: str, zip: Optional[bool] = False):
|
||||
task = task_manager.get_task(id)
|
||||
if task is None:
|
||||
return {"success": False, "message": f"Task not found"}
|
||||
|
||||
if task.status != TaskStatus.DONE:
|
||||
return {"success": False, "message": f"Task is {task.status.value}"}
|
||||
|
||||
if task.result is None:
|
||||
return {"success": False, "message": f"Task result is not available"}
|
||||
|
||||
result: dict = json.loads(task.result)
|
||||
infotexts = result["infotexts"]
|
||||
|
||||
if zip:
|
||||
zip_buffer = io.BytesIO()
|
||||
|
||||
# Create a new zip file in the in-memory buffer
|
||||
with ZipFile(zip_buffer, "w") as zip_file:
|
||||
# Loop through the files in the directory and add them to the zip file
|
||||
for image in result["images"]:
|
||||
if Path(image).is_file():
|
||||
zip_file.write(Path(image), Path(image).name)
|
||||
|
||||
# Reset the buffer position to the beginning to avoid truncation issues
|
||||
zip_buffer.seek(0)
|
||||
|
||||
# Return the in-memory buffer as a streaming response with the appropriate headers
|
||||
return StreamingResponse(
|
||||
zip_buffer,
|
||||
media_type="application/zip",
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=results-{id}.zip"
|
||||
},
|
||||
)
|
||||
else:
|
||||
data = [
|
||||
{"image": encode_image_to_base64(image), "infotext": infotexts[i]}
|
||||
for i, image in enumerate(result["images"])
|
||||
]
|
||||
|
||||
return {"success": True, "data": data}
|
||||
|
||||
@app.post("/agent-scheduler/v1/pause")
|
||||
def pause_queue():
|
||||
shared.opts.queue_paused = True
|
||||
|
|
|
|||
|
|
@ -31,35 +31,8 @@ class Task(TaskModel):
|
|||
script_params: bytes = None
|
||||
params: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str = "",
|
||||
api_task_id: str = None,
|
||||
name: str = None,
|
||||
type: str = "unknown",
|
||||
params: str = "",
|
||||
priority: int = None,
|
||||
status: str = TaskStatus.PENDING.value,
|
||||
result: str = None,
|
||||
bookmarked: bool = False,
|
||||
created_at: Optional[datetime] = None,
|
||||
updated_at: Optional[datetime] = None,
|
||||
):
|
||||
priority = priority if priority else int(datetime.utcnow().timestamp() * 1000)
|
||||
|
||||
super().__init__(
|
||||
id=id,
|
||||
api_task_id=api_task_id,
|
||||
name=name,
|
||||
type=type,
|
||||
params=params,
|
||||
status=status,
|
||||
priority=priority,
|
||||
result=result,
|
||||
bookmarked=bookmarked,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
def __init__(self, priority=int(datetime.utcnow().timestamp() * 1000), **kwargs):
|
||||
super().__init__(priority=priority, **kwargs)
|
||||
|
||||
class Config(TaskModel.__config__):
|
||||
exclude = ["script_params"]
|
||||
|
|
@ -104,7 +77,7 @@ class TaskTable(Base):
|
|||
type = Column(String(20), nullable=False) # txt2img or img2txt
|
||||
params = Column(Text, nullable=False) # task args
|
||||
script_params = Column(LargeBinary, nullable=False) # script args
|
||||
priority = Column(Integer, nullable=False, default=datetime.now)
|
||||
priority = Column(Integer, nullable=False)
|
||||
status = Column(
|
||||
String(20), nullable=False, default="pending"
|
||||
) # pending, running, done, failed
|
||||
|
|
@ -144,6 +117,7 @@ class TaskManager(BaseTableManager):
|
|||
type: str = None,
|
||||
status: Union[str, list[str]] = None,
|
||||
bookmarked: bool = None,
|
||||
api_task_id: str = None,
|
||||
limit: int = None,
|
||||
offset: int = None,
|
||||
order: str = "asc",
|
||||
|
|
@ -160,6 +134,9 @@ class TaskManager(BaseTableManager):
|
|||
else:
|
||||
query = query.filter(TaskTable.status == status)
|
||||
|
||||
if api_task_id:
|
||||
query = query.filter(TaskTable.api_task_id == api_task_id)
|
||||
|
||||
if bookmarked == True:
|
||||
query = query.filter(TaskTable.bookmarked == bookmarked)
|
||||
else:
|
||||
|
|
@ -189,6 +166,7 @@ class TaskManager(BaseTableManager):
|
|||
self,
|
||||
type: str = None,
|
||||
status: Union[str, list[str]] = None,
|
||||
api_task_id: str = None,
|
||||
) -> int:
|
||||
session = Session(self.engine)
|
||||
try:
|
||||
|
|
@ -202,6 +180,9 @@ class TaskManager(BaseTableManager):
|
|||
else:
|
||||
query = query.filter(TaskTable.status == status)
|
||||
|
||||
if api_task_id:
|
||||
query = query.filter(TaskTable.api_task_id == api_task_id)
|
||||
|
||||
return query.count()
|
||||
except Exception as e:
|
||||
print(f"Exception counting tasks from database: {e}")
|
||||
|
|
@ -262,7 +243,7 @@ class TaskManager(BaseTableManager):
|
|||
result = session.get(TaskTable, id)
|
||||
if result:
|
||||
if priority == 0:
|
||||
result.priority = self.__get_min_priority() - 1
|
||||
result.priority = self.__get_min_priority(status=TaskStatus.PENDING) - 1
|
||||
elif priority == -1:
|
||||
result.priority = int(datetime.utcnow().timestamp() * 1000)
|
||||
else:
|
||||
|
|
@ -316,10 +297,14 @@ class TaskManager(BaseTableManager):
|
|||
finally:
|
||||
session.close()
|
||||
|
||||
def __get_min_priority(self) -> int:
|
||||
def __get_min_priority(self, status: str = None) -> int:
|
||||
session = Session(self.engine)
|
||||
try:
|
||||
min_priority = session.query(func.min(TaskTable.priority)).scalar()
|
||||
query = session.query(func.min(TaskTable.priority))
|
||||
if status is not None:
|
||||
query = query.filter(TaskTable.status == status)
|
||||
|
||||
min_priority = query.scalar()
|
||||
return min_priority if min_priority else 0
|
||||
except Exception as e:
|
||||
print(f"Exception getting min priority from database: {e}")
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from gradio.blocks import Block, BlockContext
|
|||
|
||||
if not logging.getLogger().hasHandlers():
|
||||
# Logging is not set up
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
|
||||
log = logging.getLogger("sd")
|
||||
|
||||
|
|
@ -78,3 +78,43 @@ def detect_control_net(root: gr.Blocks, submit: gr.Button):
|
|||
UiControlNetUnit = type(output.value)
|
||||
|
||||
return UiControlNetUnit
|
||||
|
||||
|
||||
def get_dict_attribute(dict_inst: dict, name_string: str, default=None):
|
||||
nested_keys = name_string.split(".")
|
||||
value = dict_inst
|
||||
|
||||
for key in nested_keys:
|
||||
value = value.get(key, None)
|
||||
|
||||
if value is None:
|
||||
return default
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def set_dict_attribute(dict_inst: dict, name_string: str, value):
|
||||
"""
|
||||
Set an attribute to a dictionary using dot notation.
|
||||
If the attribute does not already exist, it will create a nested dictionary.
|
||||
|
||||
Parameters:
|
||||
- dict_inst: the dictionary instance to set the attribute
|
||||
- name_string: the attribute name in dot notation (ex: 'attribute.name')
|
||||
- value: the value to set for the attribute
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Split the attribute names by dot
|
||||
name_list = name_string.split(".")
|
||||
|
||||
# Traverse the dictionary and create a nested dictionary if necessary
|
||||
current_dict = dict_inst
|
||||
for name in name_list[:-1]:
|
||||
if name not in current_dict:
|
||||
current_dict[name] = {}
|
||||
current_dict = current_dict[name]
|
||||
|
||||
# Set the final attribute to its value
|
||||
current_dict[name_list[-1]] = value
|
||||
|
|
|
|||
|
|
@ -32,12 +32,14 @@ class TaskModel(BaseModel):
|
|||
name: Optional[str] = Field(title="Task Name")
|
||||
type: str = Field(title="Task Type", description="Either txt2img or img2img")
|
||||
status: str = Field(
|
||||
title="Task Status", description="Either pending, running, done or failed"
|
||||
"pending",
|
||||
title="Task Status",
|
||||
description="Either pending, running, done or failed",
|
||||
)
|
||||
params: dict[str, Any] = Field(
|
||||
title="Task Parameters", description="The parameters of the task in JSON format"
|
||||
)
|
||||
priority: int = Field(title="Task Priority")
|
||||
priority: Optional[int] = Field(title="Task Priority")
|
||||
result: Optional[str] = Field(
|
||||
title="Task Result", description="The result of the task in JSON format"
|
||||
)
|
||||
|
|
@ -53,12 +55,6 @@ class TaskModel(BaseModel):
|
|||
default=None,
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
# custom output conversion for datetime
|
||||
datetime: convert_datetime_to_iso_8601_with_z_suffix
|
||||
}
|
||||
|
||||
|
||||
class Txt2ImgApiTaskArgs(StableDiffusionTxt2ImgProcessingAPI):
|
||||
checkpoint: Optional[str] = Field(
|
||||
|
|
@ -67,9 +63,7 @@ class Txt2ImgApiTaskArgs(StableDiffusionTxt2ImgProcessingAPI):
|
|||
description="Custom checkpoint hash. If not specified, the latest checkpoint will be used.",
|
||||
)
|
||||
sampler_index: Optional[str] = Field(
|
||||
sd_samplers.samplers[0].name,
|
||||
title="Sampler name",
|
||||
alias="sampler_name"
|
||||
sd_samplers.samplers[0].name, title="Sampler name", alias="sampler_name"
|
||||
)
|
||||
|
||||
class Config(StableDiffusionTxt2ImgProcessingAPI.__config__):
|
||||
|
|
@ -87,9 +81,7 @@ class Img2ImgApiTaskArgs(StableDiffusionImg2ImgProcessingAPI):
|
|||
description="Custom checkpoint hash. If not specified, the latest checkpoint will be used.",
|
||||
)
|
||||
sampler_index: Optional[str] = Field(
|
||||
sd_samplers.samplers[0].name,
|
||||
title="Sampler name",
|
||||
alias="sampler_name"
|
||||
sd_samplers.samplers[0].name, title="Sampler name", alias="sampler_name"
|
||||
)
|
||||
|
||||
class Config(StableDiffusionImg2ImgProcessingAPI.__config__):
|
||||
|
|
@ -116,7 +108,13 @@ class QueueStatusResponse(BaseModel):
|
|||
)
|
||||
paused: bool = Field(title="Paused", description="Whether the queue is paused")
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda dt: int(dt.timestamp() * 1e3)}
|
||||
|
||||
|
||||
class HistoryResponse(BaseModel):
|
||||
tasks: List[TaskModel] = Field(title="Tasks")
|
||||
total: int = Field(title="Task count")
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda dt: int(dt.timestamp() * 1e3)}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import base64
|
|||
import inspect
|
||||
import requests
|
||||
import numpy as np
|
||||
from typing import Union
|
||||
from enum import Enum
|
||||
from PIL import Image, ImageOps, ImageChops, ImageEnhance, ImageFilter
|
||||
|
||||
|
|
@ -18,7 +19,7 @@ from modules.api.models import (
|
|||
StableDiffusionImg2ImgProcessingAPI,
|
||||
)
|
||||
|
||||
from .helpers import log
|
||||
from .helpers import log, get_dict_attribute
|
||||
|
||||
img2img_image_args_by_mode: dict[int, list[list[str]]] = {
|
||||
0: [["init_img"]],
|
||||
|
|
@ -29,6 +30,22 @@ img2img_image_args_by_mode: dict[int, list[list[str]]] = {
|
|||
}
|
||||
|
||||
|
||||
def get_script_by_name(
|
||||
script_name: str, is_img2img: bool = False, is_always_on: bool = False
|
||||
) -> scripts.Script:
|
||||
script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
|
||||
available_scripts = (
|
||||
script_runner.alwayson_scripts
|
||||
if is_always_on
|
||||
else script_runner.selectable_scripts
|
||||
)
|
||||
|
||||
return next(
|
||||
(s for s in available_scripts if s.title().lower() == script_name.lower()),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def load_image_from_url(url: str):
|
||||
try:
|
||||
response = requests.get(url)
|
||||
|
|
@ -39,43 +56,20 @@ def load_image_from_url(url: str):
|
|||
return None
|
||||
|
||||
|
||||
def load_image(image: str):
|
||||
if not isinstance(image, str):
|
||||
def encode_image_to_base64(image):
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image.astype("uint8"))
|
||||
elif isinstance(image, str):
|
||||
if image.startswith("http://") or image.startswith("https://"):
|
||||
image = load_image_from_url(image)
|
||||
|
||||
if not isinstance(image, Image.Image):
|
||||
return image
|
||||
|
||||
pil_image = None
|
||||
if os.path.exists(image):
|
||||
pil_image = Image.open(image)
|
||||
elif image.startswith(("http://", "https://")):
|
||||
pil_image = load_image_from_url(image)
|
||||
|
||||
return pil_image
|
||||
|
||||
|
||||
def load_image_to_nparray(image: str):
|
||||
pil_image = load_image(image)
|
||||
|
||||
return (
|
||||
np.array(pil_image).astype("uint8")
|
||||
if isinstance(pil_image, Image.Image)
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
def encode_pil_to_base64(image: Image.Image):
|
||||
with io.BytesIO() as output_bytes:
|
||||
image.save(output_bytes, format="PNG")
|
||||
bytes_data = output_bytes.getvalue()
|
||||
return base64.b64encode(bytes_data).decode("utf-8")
|
||||
|
||||
|
||||
def load_image_to_base64(image: str):
|
||||
pil_image = load_image(image)
|
||||
|
||||
if not isinstance(pil_image, Image.Image):
|
||||
return image
|
||||
|
||||
return encode_pil_to_base64(pil_image)
|
||||
return "data:image/png;base64," + base64.b64encode(bytes_data).decode("utf-8")
|
||||
|
||||
|
||||
def __serialize_image(image):
|
||||
|
|
@ -164,7 +158,6 @@ def serialize_controlnet_args(cnet_unit):
|
|||
|
||||
|
||||
def deserialize_controlnet_args(args: dict):
|
||||
# args.pop("is_cnet", None)
|
||||
for k, v in args.items():
|
||||
if k == "image" and v is not None:
|
||||
args[k] = {
|
||||
|
|
@ -177,6 +170,24 @@ def deserialize_controlnet_args(args: dict):
|
|||
return args
|
||||
|
||||
|
||||
def map_controlnet_args_to_api_task_args(args: dict):
|
||||
if type(args).__name__ == "UiControlNetUnit":
|
||||
args = args.__dict__
|
||||
|
||||
for k, v in args.items():
|
||||
if k == "image" and v is not None:
|
||||
args[k] = {
|
||||
"image": encode_image_to_base64(v["image"]),
|
||||
"mask": encode_image_to_base64(v["mask"])
|
||||
if v.get("mask", None) is not None
|
||||
else None,
|
||||
}
|
||||
if isinstance(v, Enum):
|
||||
args[k] = v.value
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def map_ui_task_args_list_to_named_args(
|
||||
args: list, is_img2img: bool, checkpoint: str = None
|
||||
):
|
||||
|
|
@ -195,7 +206,10 @@ def map_ui_task_args_list_to_named_args(
|
|||
|
||||
sampler_index = named_args.get("sampler_index", None)
|
||||
if sampler_index is not None:
|
||||
sampler_name = sd_samplers.samplers[named_args["sampler_index"]].name
|
||||
available_samplers = (
|
||||
sd_samplers.samplers_for_img2img if is_img2img else sd_samplers.samplers
|
||||
)
|
||||
sampler_name = available_samplers[named_args["sampler_index"]].name
|
||||
named_args["sampler_name"] = sampler_name
|
||||
log.debug(f"serialize sampler index: {str(sampler_index)} as {sampler_name}")
|
||||
|
||||
|
|
@ -230,6 +244,49 @@ def map_named_args_to_ui_task_args_list(
|
|||
return args
|
||||
|
||||
|
||||
def map_script_args_list_to_named(script: scripts.Script, args: list):
|
||||
script_name = script.title().lower()
|
||||
print("script", script_name, "is alwayson", script.alwayson)
|
||||
|
||||
if script_name == "controlnet":
|
||||
for i, cnet_args in enumerate(args):
|
||||
args[i] = map_controlnet_args_to_api_task_args(cnet_args)
|
||||
|
||||
return args
|
||||
|
||||
fn = script.process if script.alwayson else script.run
|
||||
inspection = inspect.getfullargspec(fn)
|
||||
arg_names = inspection.args[2:]
|
||||
named_script_args = dict(zip(arg_names, args[: len(arg_names)]))
|
||||
if inspection.varargs is not None:
|
||||
named_script_args[inspection.varargs] = args[len(arg_names) :]
|
||||
|
||||
return named_script_args
|
||||
|
||||
|
||||
def map_named_script_args_to_list(
|
||||
script: scripts.Script, named_args: Union[dict, list]
|
||||
):
|
||||
script_name = script.title().lower()
|
||||
|
||||
if isinstance(named_args, dict):
|
||||
fn = script.process if script.alwayson else script.run
|
||||
inspection = inspect.getfullargspec(fn)
|
||||
arg_names = inspection.args[2:]
|
||||
args = [named_args.get(name, None) for name in arg_names]
|
||||
if inspection.varargs is not None:
|
||||
args.extend(named_args.get(inspection.varargs, []))
|
||||
|
||||
return args
|
||||
|
||||
if isinstance(named_args, list):
|
||||
if script_name == "controlnet":
|
||||
for i, cnet_args in enumerate(named_args):
|
||||
named_args[i] = map_controlnet_args_to_api_task_args(cnet_args)
|
||||
|
||||
return named_args
|
||||
|
||||
|
||||
def map_ui_task_args_to_api_task_args(
|
||||
named_args: dict, script_args: list, is_img2img: bool
|
||||
):
|
||||
|
|
@ -255,45 +312,50 @@ def map_ui_task_args_to_api_task_args(
|
|||
|
||||
# the logic below is copied from modules/img2img.py
|
||||
if mode == 0:
|
||||
image = api_task_args.pop("init_img").convert("RGB")
|
||||
image = api_task_args.pop("init_img")
|
||||
image = image.convert("RGB") if image else None
|
||||
mask = None
|
||||
elif mode == 1:
|
||||
image = api_task_args.pop("sketch").convert("RGB")
|
||||
image = api_task_args.pop("sketch")
|
||||
image = image.convert("RGB") if image else None
|
||||
mask = None
|
||||
elif mode == 2:
|
||||
init_img_with_mask: dict = api_task_args.pop("init_img_with_mask")
|
||||
image = init_img_with_mask.get("image").convert("RGB")
|
||||
mask = init_img_with_mask.get("mask")
|
||||
alpha_mask = (
|
||||
ImageOps.invert(image.split()[-1])
|
||||
.convert("L")
|
||||
.point(lambda x: 255 if x > 0 else 0, mode="1")
|
||||
)
|
||||
mask = ImageChops.lighter(alpha_mask, mask.convert("L")).convert("L")
|
||||
init_img_with_mask: dict = api_task_args.pop("init_img_with_mask") or {}
|
||||
image = init_img_with_mask.get("image", None)
|
||||
image = image.convert("RGB") if image else None
|
||||
mask = init_img_with_mask.get("mask", None)
|
||||
if mask:
|
||||
alpha_mask = (
|
||||
ImageOps.invert(image.split()[-1])
|
||||
.convert("L")
|
||||
.point(lambda x: 255 if x > 0 else 0, mode="1")
|
||||
)
|
||||
mask = ImageChops.lighter(alpha_mask, mask.convert("L")).convert("L")
|
||||
elif mode == 3:
|
||||
image = api_task_args.pop("inpaint_color_sketch")
|
||||
orig = api_task_args.pop("inpaint_color_sketch_orig") or image
|
||||
mask_alpha = api_task_args.pop("mask_alpha", 0)
|
||||
mask_blur = api_task_args.get("mask_blur", 4)
|
||||
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
||||
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||
image = image.convert("RGB")
|
||||
if image is not None:
|
||||
mask_alpha = api_task_args.pop("mask_alpha", 0)
|
||||
mask_blur = api_task_args.get("mask_blur", 4)
|
||||
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
||||
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||
image = image.convert("RGB")
|
||||
elif mode == 4:
|
||||
image = api_task_args.pop("init_img_inpaint")
|
||||
mask = api_task_args.pop("init_mask_inpaint")
|
||||
else:
|
||||
raise Exception(f"Batch mode is not supported yet")
|
||||
|
||||
image = ImageOps.exif_transpose(image)
|
||||
api_task_args["init_images"] = [encode_pil_to_base64(image)]
|
||||
api_task_args["mask"] = encode_pil_to_base64(mask) if mask is not None else None
|
||||
image = ImageOps.exif_transpose(image) if image else None
|
||||
api_task_args["init_images"] = [encode_image_to_base64(image)] if image else []
|
||||
api_task_args["mask"] = encode_image_to_base64(mask) if mask else None
|
||||
|
||||
selected_scale_tab = api_task_args.pop("selected_scale_tab", 0)
|
||||
scale_by = api_task_args.pop("scale_by", 1)
|
||||
if selected_scale_tab == 1:
|
||||
scale_by = api_task_args.get("scale_by", 1)
|
||||
if selected_scale_tab == 1 and image:
|
||||
api_task_args["width"] = int(image.width * scale_by)
|
||||
api_task_args["height"] = int(image.height * scale_by)
|
||||
else:
|
||||
|
|
@ -311,23 +373,27 @@ def map_ui_task_args_to_api_task_args(
|
|||
api_task_args["script_name"] = None
|
||||
api_task_args["script_args"] = []
|
||||
else:
|
||||
script = script_runner.selectable_scripts[script_id - 1]
|
||||
api_task_args["script_name"] = script.title.lower()
|
||||
api_task_args["script_args"] = script_args[script.args_from : script.args_to]
|
||||
script: scripts.Script = script_runner.selectable_scripts[script_id - 1]
|
||||
api_task_args["script_name"] = script.title().lower()
|
||||
current_script_args = script_args[script.args_from : script.args_to]
|
||||
api_task_args["script_args"] = map_script_args_list_to_named(
|
||||
script, current_script_args
|
||||
)
|
||||
|
||||
# alwayson scripts
|
||||
alwayson_scripts = api_task_args.get("alwayson_scripts", None)
|
||||
if not alwayson_scripts or not isinstance(alwayson_scripts, dict):
|
||||
alwayson_scripts = {}
|
||||
api_task_args["alwayson_scripts"] = alwayson_scripts
|
||||
if not alwayson_scripts:
|
||||
api_task_args["alwayson_scripts"] = {}
|
||||
alwayson_scripts = api_task_args["alwayson_scripts"]
|
||||
|
||||
for script in script_runner.alwayson_scripts:
|
||||
alwayson_script_args = script_args[script.args_from : script.args_to]
|
||||
if script.title.lower() == "controlnet":
|
||||
for i, cnet_args in enumerate(alwayson_script_args):
|
||||
alwayson_script_args[i] = serialize_controlnet_args(cnet_args)
|
||||
|
||||
alwayson_scripts[script.title.lower()] = {"args": alwayson_script_args}
|
||||
script_name = script.title().lower()
|
||||
if script_name != "agent scheduler":
|
||||
named_script_args = map_script_args_list_to_named(
|
||||
script, alwayson_script_args
|
||||
)
|
||||
alwayson_scripts[script_name] = {"args": named_script_args}
|
||||
|
||||
return api_task_args
|
||||
|
||||
|
|
@ -336,8 +402,33 @@ def serialize_api_task_args(
|
|||
params: dict,
|
||||
is_img2img: bool,
|
||||
checkpoint: str = None,
|
||||
controlnet_args: list[dict] = None,
|
||||
):
|
||||
# handle named script args
|
||||
script_name = params.get("script_name", None)
|
||||
if script_name is not None:
|
||||
script = get_script_by_name(script_name, is_img2img)
|
||||
if script is None:
|
||||
raise Exception(f"Not found script {script_name}")
|
||||
|
||||
script_args = params.get("script_args", {})
|
||||
params["script_args"] = map_named_script_args_to_list(script, script_args)
|
||||
|
||||
# handle named alwayson script args
|
||||
alwayson_scripts = get_dict_attribute(params, "alwayson_scripts", {})
|
||||
valid_alwayson_scripts = {}
|
||||
script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
|
||||
for script in script_runner.alwayson_scripts:
|
||||
script_name = script.title().lower()
|
||||
if script_name == "agent scheduler":
|
||||
continue
|
||||
|
||||
script_args = get_dict_attribute(alwayson_scripts, f"{script_name}.args", None)
|
||||
if script_args:
|
||||
arg_list = map_named_script_args_to_list(script, script_args)
|
||||
valid_alwayson_scripts[script_name] = {"args": arg_list}
|
||||
|
||||
params["alwayson_scripts"] = valid_alwayson_scripts
|
||||
|
||||
args = (
|
||||
StableDiffusionImg2ImgProcessingAPI(**params)
|
||||
if is_img2img
|
||||
|
|
@ -350,48 +441,19 @@ def serialize_api_task_args(
|
|||
if checkpoint is not None:
|
||||
checkpoint_info: CheckpointInfo = get_closet_checkpoint_match(checkpoint)
|
||||
if not checkpoint_info:
|
||||
log.warn(
|
||||
f"[AgentScheduler] No checkpoint found for model hash {checkpoint}"
|
||||
)
|
||||
return
|
||||
raise Exception(f"No checkpoint found for model hash {checkpoint}")
|
||||
args.override_settings["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
# load images from url or file if needed
|
||||
if is_img2img:
|
||||
init_images = args.init_images
|
||||
if len(init_images) == 0:
|
||||
raise Exception("At least one init image is required")
|
||||
|
||||
for i, image in enumerate(init_images):
|
||||
init_images[i] = load_image_to_base64(image)
|
||||
init_images[i] = encode_image_to_base64(image)
|
||||
|
||||
args.mask = load_image_to_base64(args.mask)
|
||||
|
||||
# handle custom controlnet args
|
||||
if controlnet_args is not None:
|
||||
if args.alwayson_scripts is None:
|
||||
args.alwayson_scripts = {}
|
||||
|
||||
controlnets = []
|
||||
for cnet in controlnet_args:
|
||||
enabled = cnet.get("enabled", True)
|
||||
cnet_image = cnet.get("image", None)
|
||||
|
||||
if not enabled:
|
||||
continue
|
||||
if not isinstance(cnet_image, dict):
|
||||
log.error(f"[AgentScheduler] Controlnet image is required")
|
||||
continue
|
||||
|
||||
image = cnet_image.get("image", None)
|
||||
mask = cnet_image.get("mask", None)
|
||||
if image is None:
|
||||
log.error(f"[AgentScheduler] Controlnet image is required")
|
||||
continue
|
||||
|
||||
# load controlnet images from url or file if needed
|
||||
cnet_image["image"] = load_image_to_base64(image)
|
||||
cnet_image["mask"] = load_image_to_base64(mask)
|
||||
controlnets.append(cnet)
|
||||
|
||||
if len(controlnets) > 0:
|
||||
args.alwayson_scripts["controlnet"] = {"args": controlnets}
|
||||
args.mask = encode_image_to_base64(args.mask)
|
||||
args.batch_size = len(init_images)
|
||||
|
||||
return args.dict()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
|
|
@ -7,6 +8,7 @@ from pydantic import BaseModel
|
|||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, Union, Optional
|
||||
from fastapi import FastAPI
|
||||
from PIL import Image
|
||||
|
||||
from modules import progress, shared, script_callbacks
|
||||
from modules.call_queue import queue_lock, wrap_gradio_call
|
||||
|
|
@ -19,12 +21,19 @@ from modules.api.models import (
|
|||
)
|
||||
|
||||
from .db import TaskStatus, Task, task_manager
|
||||
from .helpers import log, detect_control_net, get_component_by_elem_id
|
||||
from .helpers import (
|
||||
log,
|
||||
detect_control_net,
|
||||
get_component_by_elem_id,
|
||||
get_dict_attribute,
|
||||
)
|
||||
from .task_helpers import (
|
||||
encode_image_to_base64,
|
||||
serialize_img2img_image_args,
|
||||
deserialize_img2img_image_args,
|
||||
serialize_controlnet_args,
|
||||
deserialize_controlnet_args,
|
||||
serialize_api_task_args,
|
||||
map_ui_task_args_list_to_named_args,
|
||||
map_named_args_to_ui_task_args_list,
|
||||
)
|
||||
|
|
@ -119,11 +128,18 @@ class TaskRunner:
|
|||
)
|
||||
|
||||
def __serialize_api_task_args(
|
||||
self, is_img2img: bool, script_args: list = [], **named_args
|
||||
self,
|
||||
is_img2img: bool,
|
||||
script_args: list = [],
|
||||
checkpoint: str = None,
|
||||
**api_args,
|
||||
):
|
||||
# serialization steps are done in task_helpers.register_api_task
|
||||
override_settings = named_args.get("override_settings", {})
|
||||
checkpoint = override_settings.get("sd_model_checkpoint", None)
|
||||
named_args = serialize_api_task_args(
|
||||
api_args, is_img2img, checkpoint=checkpoint
|
||||
)
|
||||
checkpoint = get_dict_attribute(
|
||||
named_args, "override_settings.sd_model_checkpoint", None
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
|
|
@ -149,13 +165,19 @@ class TaskRunner:
|
|||
script_args[i] = deserialize_controlnet_args(arg)
|
||||
|
||||
def __deserialize_api_task_args(self, is_img2img: bool, named_args: dict):
|
||||
# API task use base64 images as input, no need to deserialize
|
||||
pass
|
||||
# load images from disk
|
||||
if is_img2img:
|
||||
init_images = named_args.get("init_images")
|
||||
for i, img in enumerate(init_images):
|
||||
if isinstance(img, str) and os.path.isfile(img):
|
||||
print("loading image", img)
|
||||
image = Image.open(img)
|
||||
init_images[i] = encode_image_to_base64(image)
|
||||
|
||||
def parse_task_args(
|
||||
self, params: str, script_params: bytes = None, deserialization: bool = True
|
||||
):
|
||||
parsed: dict[str, Any] = json.loads(params)
|
||||
named_args.update({"save_images": True, "send_images": False})
|
||||
|
||||
def parse_task_args(self, task: Task, deserialization: bool = True):
|
||||
parsed: dict[str, Any] = json.loads(task.params)
|
||||
|
||||
is_ui = parsed.get("is_ui", True)
|
||||
is_img2img = parsed.get("is_img2img", None)
|
||||
|
|
@ -198,13 +220,18 @@ class TaskRunner:
|
|||
self.__total_pending_tasks += 1
|
||||
|
||||
def register_api_task(
|
||||
self, task_id: str, api_task_id: str, is_img2img: bool, args: dict
|
||||
self,
|
||||
task_id: str,
|
||||
api_task_id: str,
|
||||
is_img2img: bool,
|
||||
args: dict,
|
||||
checkpoint: str = None,
|
||||
):
|
||||
progress.add_task_to_queue(task_id)
|
||||
|
||||
args = args.copy()
|
||||
args.update({"save_images": True, "send_images": False})
|
||||
params = self.__serialize_api_task_args(is_img2img, **args)
|
||||
params = self.__serialize_api_task_args(
|
||||
is_img2img, checkpoint=checkpoint, **args
|
||||
)
|
||||
|
||||
task_type = "img2img" if is_img2img else "txt2img"
|
||||
task_manager.add_task(
|
||||
|
|
@ -216,7 +243,7 @@ class TaskRunner:
|
|||
)
|
||||
self.__total_pending_tasks += 1
|
||||
|
||||
def execute_task(self, task: Task, get_next_task: Callable):
|
||||
def execute_task(self, task: Task, get_next_task: Callable[[], Task]):
|
||||
while True:
|
||||
if self.dispose:
|
||||
break
|
||||
|
|
@ -226,10 +253,7 @@ class TaskRunner:
|
|||
is_img2img = task.type == "img2img"
|
||||
log.info(f"[AgentScheduler] Executing task {task_id}")
|
||||
|
||||
task_args = self.parse_task_args(
|
||||
task.params,
|
||||
task.script_params,
|
||||
)
|
||||
task_args = self.parse_task_args(task)
|
||||
task_meta = {
|
||||
"is_img2img": is_img2img,
|
||||
"is_ui": task_args.is_ui,
|
||||
|
|
@ -439,7 +463,9 @@ class TaskRunner:
|
|||
self.__run_callbacks("task_cleared")
|
||||
|
||||
def __on_image_saved(self, data: script_callbacks.ImageSaveParams):
|
||||
self.__saved_images_path.append((data.filename, data.pnginfo.get("parameters", "")))
|
||||
self.__saved_images_path.append(
|
||||
(data.filename, data.pnginfo.get("parameters", ""))
|
||||
)
|
||||
|
||||
def on_task_registered(self, callback: Callable):
|
||||
"""Callback when a task is registered
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -31,6 +31,9 @@ initialized = False
|
|||
checkpoint_current = "Current Checkpoint"
|
||||
checkpoint_runtime = "Runtime Checkpoint"
|
||||
|
||||
ui_placement_as_tab = "As a tab"
|
||||
ui_placement_append_to_main = "Append to main UI"
|
||||
|
||||
placement_under_generate = "Under Generate button"
|
||||
placement_between_prompt_and_generate = "Between Prompt and Generate button"
|
||||
|
||||
|
|
@ -48,7 +51,7 @@ class Script(scripts.Script):
|
|||
return "Agent Scheduler"
|
||||
|
||||
def show(self, is_img2img):
|
||||
return True
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def on_checkpoint_changed(self, checkpoint):
|
||||
self.checkpoint_override = checkpoint
|
||||
|
|
@ -409,15 +412,39 @@ def on_ui_settings():
|
|||
section=section,
|
||||
),
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"queue_ui_placement",
|
||||
shared.OptionInfo(
|
||||
ui_placement_as_tab,
|
||||
"Task queue UI placement",
|
||||
gr.Radio,
|
||||
lambda: {
|
||||
"choices": [
|
||||
ui_placement_as_tab,
|
||||
ui_placement_append_to_main,
|
||||
]
|
||||
},
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def on_app_started(block, app):
|
||||
def on_app_started(block: gr.Blocks, app):
|
||||
global task_runner
|
||||
task_runner = get_instance(block)
|
||||
task_runner.execute_pending_tasks_threading()
|
||||
regsiter_apis(app, task_runner)
|
||||
|
||||
if (
|
||||
getattr(shared.opts, "queue_ui_placement", "") == ui_placement_append_to_main
|
||||
and block
|
||||
):
|
||||
with block:
|
||||
with block.children[1]:
|
||||
on_ui_tab()
|
||||
|
||||
if getattr(shared.opts, "queue_ui_placement", "") != ui_placement_append_to_main:
|
||||
script_callbacks.on_ui_tabs(on_ui_tab)
|
||||
|
||||
script_callbacks.on_ui_tabs(on_ui_tab)
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
script_callbacks.on_app_started(on_app_started)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
dist
|
||||
.eslintrc.cjs
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
module.exports = {
|
||||
root: true,
|
||||
env: { browser: true, es2020: true },
|
||||
extends: [
|
||||
'eslint:recommended',
|
||||
|
|
@ -7,8 +8,34 @@ module.exports = {
|
|||
],
|
||||
parser: '@typescript-eslint/parser',
|
||||
parserOptions: { ecmaVersion: 'latest', sourceType: 'module' },
|
||||
plugins: ['react-refresh'],
|
||||
plugins: ['react-refresh', 'simple-import-sort'],
|
||||
rules: {
|
||||
'react-refresh/only-export-components': 'warn',
|
||||
'simple-import-sort/imports': [
|
||||
'error',
|
||||
{
|
||||
groups: [
|
||||
// Side effect imports.
|
||||
['^\\u0000'],
|
||||
// Node.js builtins.
|
||||
[`^(${require('module').builtinModules.join('|')})(/|$)`],
|
||||
// Packages. `react` related packages come first.
|
||||
['^react', '^\\w', '^@\\w'],
|
||||
// Type
|
||||
[`^(@@types)(/.*|$)`],
|
||||
// Internal packages.
|
||||
[
|
||||
`^(~)(/.*|$)`,
|
||||
],
|
||||
// Parent imports. Put `..` last.
|
||||
['^\\.\\.(?!/?$)', '^\\.\\./?$'],
|
||||
// Other relative imports. Put same-folder imports and `.` last.
|
||||
['^\\./(?=.*/)(?!/?$)', '^\\.(?!/?$)', '^\\./?$'],
|
||||
// Style imports.
|
||||
['^.+\\.s?css$'],
|
||||
],
|
||||
},
|
||||
],
|
||||
'@typescript-eslint/no-explicit-any': 'off'
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,6 +35,10 @@
|
|||
|
||||
/* ========================================================================= */
|
||||
|
||||
#tabs > #agent_scheduler_tabs {
|
||||
margin-top: var(--layout-gap);
|
||||
}
|
||||
|
||||
#agent_scheduler_pending_tasks_wrapper,
|
||||
#agent_scheduler_history_wrapper {
|
||||
border: none;
|
||||
|
|
|
|||
|
|
@ -1,34 +1,36 @@
|
|||
/* eslint-disable @typescript-eslint/no-non-null-assertion */
|
||||
|
||||
import { Grid, GridOptions } from 'ag-grid-community';
|
||||
import { Notyf } from 'notyf';
|
||||
|
||||
import bookmark from '../assets/icons/bookmark.svg?raw';
|
||||
import bookmarked from '../assets/icons/bookmark-filled.svg?raw';
|
||||
import cancelIcon from '../assets/icons/cancel.svg?raw';
|
||||
import deleteIcon from '../assets/icons/delete.svg?raw';
|
||||
import playIcon from '../assets/icons/play.svg?raw';
|
||||
import rotateIcon from '../assets/icons/rotate.svg?raw';
|
||||
import searchIcon from '../assets/icons/search.svg?raw';
|
||||
import { debounce } from '../utils/debounce';
|
||||
import { extractArgs } from '../utils/extract-args';
|
||||
import { formatDate } from '../utils/format-date';
|
||||
|
||||
import { createHistoryTasksStore } from './stores/history.store';
|
||||
import { createPendingTasksStore } from './stores/pending.store';
|
||||
import { createSharedStore } from './stores/shared.store';
|
||||
import { ProgressResponse, ResponseStatus, Task, TaskStatus } from './types';
|
||||
|
||||
import 'ag-grid-community/styles/ag-grid.css';
|
||||
import 'ag-grid-community/styles/ag-theme-alpine.css';
|
||||
import 'notyf/notyf.min.css';
|
||||
import './index.scss';
|
||||
|
||||
import { createPendingTasksStore } from './stores/pending.store';
|
||||
import { ProgressResponse, ResponseStatus, Task, TaskStatus } from './types';
|
||||
import { debounce } from '../utils/debounce';
|
||||
import { extractArgs } from '../utils/extract-args';
|
||||
import { createHistoryTasksStore } from './stores/history.store';
|
||||
import { createSharedStore } from './stores/shared.store';
|
||||
|
||||
import deleteIcon from '../assets/icons/delete.svg?raw';
|
||||
import cancelIcon from '../assets/icons/cancel.svg?raw';
|
||||
import searchIcon from '../assets/icons/search.svg?raw';
|
||||
import playIcon from '../assets/icons/play.svg?raw';
|
||||
import rotateIcon from '../assets/icons/rotate.svg?raw';
|
||||
import bookmark from '../assets/icons/bookmark.svg?raw';
|
||||
import bookmarked from '../assets/icons/bookmark-filled.svg?raw';
|
||||
|
||||
const notyf = new Notyf();
|
||||
|
||||
declare global {
|
||||
var country: string;
|
||||
function gradioApp(): HTMLElement;
|
||||
function randomId(): string;
|
||||
function get_tab_index(name: string): number;
|
||||
function create_submit_args(args: IArguments): any[];
|
||||
function create_submit_args(args: any[]): any[];
|
||||
function requestProgress(
|
||||
id: string,
|
||||
progressContainer: HTMLElement,
|
||||
|
|
@ -37,12 +39,16 @@ declare global {
|
|||
onProgress?: (res: ProgressResponse) => void,
|
||||
): void;
|
||||
function onUiLoaded(callback: () => void): void;
|
||||
function submit_enqueue(): any[];
|
||||
function submit_enqueue_img2img(): any[];
|
||||
function notify(response: ResponseStatus): void;
|
||||
function submit(...args: any[]): any[];
|
||||
function submit_img2img(...args: any[]): any[];
|
||||
function submit_enqueue(...args: any[]): any[];
|
||||
function submit_enqueue_img2img(...args: any[]): any[];
|
||||
function agent_scheduler_status_filter_changed(value: string): void;
|
||||
}
|
||||
|
||||
const sharedStore = createSharedStore({
|
||||
uiAsTab: true,
|
||||
selectedTab: 'pending',
|
||||
});
|
||||
|
||||
|
|
@ -98,7 +104,8 @@ const sharedGridOptions: GridOptions<Task> = {
|
|||
{
|
||||
field: 'params.prompt',
|
||||
headerName: 'Prompt',
|
||||
minWidth: 400,
|
||||
minWidth: 200,
|
||||
maxWidth: 400,
|
||||
autoHeight: true,
|
||||
wrapText: true,
|
||||
cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' },
|
||||
|
|
@ -106,7 +113,8 @@ const sharedGridOptions: GridOptions<Task> = {
|
|||
{
|
||||
field: 'params.negative_prompt',
|
||||
headerName: 'Negative Prompt',
|
||||
minWidth: 400,
|
||||
minWidth: 200,
|
||||
maxWidth: 400,
|
||||
autoHeight: true,
|
||||
wrapText: true,
|
||||
cellStyle: { 'line-height': '24px', 'padding-top': '8px', 'padding-bottom': '8px' },
|
||||
|
|
@ -143,7 +151,8 @@ const sharedGridOptions: GridOptions<Task> = {
|
|||
headerName: 'Size',
|
||||
minWidth: 110,
|
||||
maxWidth: 110,
|
||||
valueGetter: ({ data }) => (data ? `${data.params.width}x${data.params.height}` : ''),
|
||||
valueGetter: ({ data }) =>
|
||||
data?.params?.width ? `${data.params.width}x${data.params.height}` : '',
|
||||
},
|
||||
{
|
||||
field: 'params.batch',
|
||||
|
|
@ -151,11 +160,22 @@ const sharedGridOptions: GridOptions<Task> = {
|
|||
minWidth: 100,
|
||||
maxWidth: 100,
|
||||
valueGetter: ({ data }) =>
|
||||
data ? `${data.params.n_iter}x${data.params.batch_size}` : '1x1',
|
||||
data?.params?.n_iter ? `${data.params.n_iter}x${data.params.batch_size}` : '1x1',
|
||||
},
|
||||
],
|
||||
},
|
||||
{ field: 'created_at', headerName: 'Date', minWidth: 200 },
|
||||
{
|
||||
field: 'created_at',
|
||||
headerName: 'Queued At',
|
||||
minWidth: 170,
|
||||
valueFormatter: ({ value }) => value && formatDate(new Date(value)),
|
||||
},
|
||||
{
|
||||
field: 'updated_at',
|
||||
headerName: 'Updated At',
|
||||
minWidth: 170,
|
||||
valueFormatter: ({ value }) => value && formatDate(new Date(value)),
|
||||
},
|
||||
],
|
||||
|
||||
getRowId: ({ data }) => data.id,
|
||||
|
|
@ -209,6 +229,8 @@ function notify(response: ResponseStatus) {
|
|||
}
|
||||
}
|
||||
|
||||
window.notify = notify;
|
||||
|
||||
function showTaskProgress(task_id: string, callback: () => void) {
|
||||
const args = extractArgs(requestProgress);
|
||||
|
||||
|
|
@ -246,6 +268,16 @@ function showTaskProgress(task_id: string, callback: () => void) {
|
|||
}
|
||||
|
||||
function initTabChangeHandler() {
|
||||
sharedStore.subscribe((curr, prev) => {
|
||||
if (!curr.uiAsTab || curr.selectedTab !== prev.selectedTab) {
|
||||
if (curr.selectedTab === 'pending') {
|
||||
pendingStore.refresh();
|
||||
} else {
|
||||
historyStore.refresh();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// watch for tab activation
|
||||
const observer = new MutationObserver(function (mutationsList) {
|
||||
mutationsList.forEach((styleChange) => {
|
||||
|
|
@ -260,15 +292,19 @@ function initTabChangeHandler() {
|
|||
historyStore.refresh();
|
||||
}
|
||||
} else if (tab.id === 'agent_scheduler_pending_tasks_tab') {
|
||||
sharedStore.selectSelectedTab('pending');
|
||||
pendingStore.refresh();
|
||||
sharedStore.setSelectedTab('pending');
|
||||
} else if (tab.id === 'agent_scheduler_history_tab') {
|
||||
sharedStore.selectSelectedTab('history');
|
||||
historyStore.refresh();
|
||||
sharedStore.setSelectedTab('history');
|
||||
}
|
||||
});
|
||||
});
|
||||
observer.observe(document.getElementById('tab_agent_scheduler')!, { attributeFilter: ['style'] });
|
||||
if (document.getElementById('tab_agent_scheduler')) {
|
||||
observer.observe(document.getElementById('tab_agent_scheduler')!, {
|
||||
attributeFilter: ['style'],
|
||||
});
|
||||
} else {
|
||||
sharedStore.setState({ uiAsTab: false });
|
||||
}
|
||||
observer.observe(document.getElementById('agent_scheduler_pending_tasks_tab')!, {
|
||||
attributeFilter: ['style'],
|
||||
});
|
||||
|
|
@ -280,33 +316,38 @@ function initTabChangeHandler() {
|
|||
function initPendingTab() {
|
||||
const store = pendingStore;
|
||||
|
||||
window.submit_enqueue = function submit_enqueue() {
|
||||
var id = randomId();
|
||||
var res = create_submit_args(arguments);
|
||||
res[0] = id;
|
||||
window.submit_enqueue = function submit_enqueue(...args) {
|
||||
const res = window.submit(...args);
|
||||
|
||||
const btnEnqueue = document.querySelector('#txt2img_enqueue');
|
||||
if (btnEnqueue) {
|
||||
btnEnqueue.innerHTML = 'Queued';
|
||||
setTimeout(() => {
|
||||
btnEnqueue.innerHTML = 'Enqueue';
|
||||
if (!sharedStore.getState().uiAsTab) {
|
||||
if (sharedStore.getState().selectedTab === 'pending') {
|
||||
pendingStore.refresh();
|
||||
}
|
||||
}
|
||||
}, 1000);
|
||||
}
|
||||
|
||||
return res;
|
||||
};
|
||||
|
||||
window.submit_enqueue_img2img = function submit_enqueue_img2img() {
|
||||
var id = randomId();
|
||||
var res = create_submit_args(arguments);
|
||||
res[0] = id;
|
||||
res[1] = get_tab_index('mode_img2img');
|
||||
window.submit_enqueue_img2img = function submit_enqueue_img2img(...args) {
|
||||
const res = window.submit_img2img(...args);
|
||||
|
||||
const btnEnqueue = document.querySelector('#img2img_enqueue');
|
||||
if (btnEnqueue) {
|
||||
btnEnqueue.innerHTML = 'Queued';
|
||||
setTimeout(() => {
|
||||
btnEnqueue.innerHTML = 'Enqueue';
|
||||
if (!sharedStore.getState().uiAsTab) {
|
||||
if (sharedStore.getState().selectedTab === 'pending') {
|
||||
pendingStore.refresh();
|
||||
}
|
||||
}
|
||||
}, 1000);
|
||||
}
|
||||
|
||||
|
|
@ -406,7 +447,22 @@ function initPendingTab() {
|
|||
},
|
||||
},
|
||||
],
|
||||
onGridReady: ({ api }) => {
|
||||
onColumnMoved({ columnApi }) {
|
||||
const colState = columnApi.getColumnState();
|
||||
const colStateStr = JSON.stringify(colState);
|
||||
localStorage.setItem('agent_scheduler:queue_col_state', colStateStr);
|
||||
},
|
||||
onSortChanged({ columnApi }) {
|
||||
const colState = columnApi.getColumnState();
|
||||
const colStateStr = JSON.stringify(colState);
|
||||
localStorage.setItem('agent_scheduler:queue_col_state', colStateStr);
|
||||
},
|
||||
onColumnResized({ columnApi }) {
|
||||
const colState = columnApi.getColumnState();
|
||||
const colStateStr = JSON.stringify(colState);
|
||||
localStorage.setItem('agent_scheduler:queue_col_state', colStateStr);
|
||||
},
|
||||
onGridReady: ({ api, columnApi }) => {
|
||||
// init quick search input
|
||||
const searchContainer = initSearchInput('#agent_scheduler_action_search');
|
||||
const searchInput: HTMLInputElement = searchContainer.querySelector('input.ts-search-input')!;
|
||||
|
|
@ -427,8 +483,15 @@ function initPendingTab() {
|
|||
}
|
||||
}
|
||||
|
||||
api.sizeColumnsToFit();
|
||||
columnApi.autoSizeAllColumns();
|
||||
});
|
||||
|
||||
// restore col state
|
||||
const colStateStr = localStorage.getItem('agent_scheduler:queue_col_state');
|
||||
if (colStateStr) {
|
||||
const colState = JSON.parse(colStateStr);
|
||||
columnApi.applyColumnState({ state: colState, applyOrder: true });
|
||||
}
|
||||
},
|
||||
onRowDragEnd: ({ api, node, overNode }) => {
|
||||
const id = node.data?.id;
|
||||
|
|
@ -581,7 +644,22 @@ function initHistoryTab() {
|
|||
],
|
||||
rowSelection: 'single',
|
||||
suppressRowDeselection: true,
|
||||
onGridReady: ({ api }) => {
|
||||
onColumnMoved({ columnApi }) {
|
||||
const colState = columnApi.getColumnState();
|
||||
const colStateStr = JSON.stringify(colState);
|
||||
localStorage.setItem('agent_scheduler:history_col_state', colStateStr);
|
||||
},
|
||||
onSortChanged({ columnApi }) {
|
||||
const colState = columnApi.getColumnState();
|
||||
const colStateStr = JSON.stringify(colState);
|
||||
localStorage.setItem('agent_scheduler:history_col_state', colStateStr);
|
||||
},
|
||||
onColumnResized({ columnApi }) {
|
||||
const colState = columnApi.getColumnState();
|
||||
const colStateStr = JSON.stringify(colState);
|
||||
localStorage.setItem('agent_scheduler:history_col_state', colStateStr);
|
||||
},
|
||||
onGridReady: ({ api, columnApi }) => {
|
||||
// init quick search input
|
||||
const searchContainer = initSearchInput('#agent_scheduler_action_search_history');
|
||||
const searchInput: HTMLInputElement = searchContainer.querySelector('input.ts-search-input')!;
|
||||
|
|
@ -594,8 +672,15 @@ function initHistoryTab() {
|
|||
|
||||
store.subscribe((state) => {
|
||||
api.setRowData(state.tasks);
|
||||
api.sizeColumnsToFit();
|
||||
columnApi.autoSizeAllColumns();
|
||||
});
|
||||
|
||||
// restore col state
|
||||
const colStateStr = localStorage.getItem('agent_scheduler:history_col_state');
|
||||
if (colStateStr) {
|
||||
const colState = JSON.parse(colStateStr);
|
||||
columnApi.applyColumnState({ state: colState, applyOrder: true });
|
||||
}
|
||||
},
|
||||
onSelectionChanged: (e) => {
|
||||
const [selected] = e.api.getSelectedRows();
|
||||
|
|
@ -619,7 +704,7 @@ let agentSchedulerInitialized = false;
|
|||
|
||||
onUiLoaded(function initAgentScheduler() {
|
||||
// delay ui init until dom is available
|
||||
if (!document.getElementById('tab_agent_scheduler')) {
|
||||
if (!document.getElementById('agent_scheduler_tabs')) {
|
||||
setTimeout(initAgentScheduler, 500);
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,11 +3,12 @@ import { createStore } from 'zustand/vanilla';
|
|||
type SelectedTab = 'history' | 'pending';
|
||||
|
||||
type SharedState = {
|
||||
uiAsTab: boolean;
|
||||
selectedTab: SelectedTab;
|
||||
};
|
||||
|
||||
type SharedActions = {
|
||||
selectSelectedTab: (tab: SelectedTab) => void;
|
||||
setSelectedTab: (tab: SelectedTab) => void;
|
||||
};
|
||||
|
||||
export const createSharedStore = (initialState: SharedState) => {
|
||||
|
|
@ -15,7 +16,7 @@ export const createSharedStore = (initialState: SharedState) => {
|
|||
const { getState, setState, subscribe } = store;
|
||||
|
||||
const actions: SharedActions = {
|
||||
selectSelectedTab: (tab: SelectedTab) => {
|
||||
setSelectedTab: (tab: SelectedTab) => {
|
||||
setState({ selectedTab: tab });
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -35,20 +35,16 @@
|
|||
}
|
||||
}
|
||||
|
||||
.ag-cell.task-running .ag-cell-wrapper {
|
||||
.ag-cell.task-running {
|
||||
@apply !text-blue-500;
|
||||
animation: 1s blink ease infinite;
|
||||
}
|
||||
|
||||
/* .ag-cell.task-done .ag-cell-wrapper {
|
||||
@apply !text-green-500;
|
||||
} */
|
||||
|
||||
.ag-cell.task-failed .ag-cell-wrapper {
|
||||
.ag-cell.task-failed {
|
||||
@apply !text-red-500;
|
||||
}
|
||||
|
||||
.ag-cell.task-interrupted .ag-cell-wrapper {
|
||||
.ag-cell.task-interrupted {
|
||||
@apply !text-gray-400;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
export const debounce = (fn: Function, ms = 300) => {
|
||||
export const debounce = (fn: (...args: any[]) => any, ms = 300) => {
|
||||
let timeoutId: ReturnType<typeof setTimeout>;
|
||||
return function (this: any, ...args: any[]) {
|
||||
clearTimeout(timeoutId);
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
export const extractArgs = (func: Function) => {
|
||||
export const extractArgs = (func: (...args: any[]) => any) => {
|
||||
return (func + '')
|
||||
.replace(/[/][/].*$/gm, '') // strip single-line comments
|
||||
.replace(/\s+/g, '') // strip white space
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
export function formatDate(date: Date): string {
|
||||
const year = date.getFullYear();
|
||||
const month = (date.getMonth() + 1).toString().padStart(2, '0');
|
||||
const day = date.getDate().toString().padStart(2, '0');
|
||||
const hours = date.getHours().toString().padStart(2, '0');
|
||||
const minutes = date.getMinutes().toString().padStart(2, '0');
|
||||
const seconds = date.getSeconds().toString().padStart(2, '0');
|
||||
|
||||
return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`;
|
||||
}
|
||||
Loading…
Reference in New Issue