403 lines
16 KiB
Python
403 lines
16 KiB
Python
import json
|
|
import socket
|
|
from typing import Any, Dict, List
|
|
from urllib.error import URLError
|
|
from urllib.parse import urljoin, urlparse
|
|
from urllib.request import Request, urlopen
|
|
|
|
from krita import QObject, QThread, pyqtSignal
|
|
|
|
from .config import Config
|
|
from .defaults import (
|
|
ERR_BAD_URL,
|
|
ERR_NO_CONNECTION,
|
|
LONG_TIMEOUT,
|
|
OFFICIAL_ROUTE_PREFIX,
|
|
ROUTE_PREFIX,
|
|
SHORT_TIMEOUT,
|
|
STATE_DONE,
|
|
STATE_READY,
|
|
STATE_URLERROR,
|
|
THREADED,
|
|
)
|
|
from .utils import bytewise_xor, fix_prompt, get_ext_args, get_ext_key, img_to_b64
|
|
|
|
# NOTE: backend queues up responses, so no explicit need to block multiple requests
|
|
# except to prevent user from spamming themselves
|
|
|
|
# TODO: tab showing all queued up requests (local plugin instance only)
|
|
|
|
|
|
def get_url(cfg: Config, route: str = ..., prefix: str = ROUTE_PREFIX):
    """Build a backend URL from the configured ``base_url``.

    Args:
        cfg (Config): Plugin config; its "base_url" entry is the base URL.
        route (str, optional): Route joined onto ``base_url`` + ``prefix``.
            When omitted (``...``), the prefixed base URL is returned.
        prefix (str, optional): Route prefix joined onto the base URL.
            Defaults to ``ROUTE_PREFIX``.

    Returns:
        The joined URL string, or None when ``base_url`` does not use the
        http or https scheme.
    """
    base = cfg("base_url", str)
    if urlparse(base).scheme not in {"http", "https"}:
        return None
    # urljoin resolves relative to the base: a prefix/route starting with "/"
    # replaces the base's path instead of being appended to it.
    url = urljoin(base, prefix)
    if route is not ...:
        url = urljoin(url, route)
    return url
|
|
|
|
|
|
# Krita doesn't re-export QtNetwork, so requests are made with urllib (on a QThread when THREADED).
|
|
class AsyncRequest(QObject):
    """One JSON HTTP request whose outcome is reported through Qt signals.

    Create instances with :meth:`request`, which also returns a callable that
    starts the request.
    """

    # Class-level default passed to urlopen; None means the request never times out.
    timeout = None
    # Emitted once when run() ends, whether the request succeeded or failed.
    finished = pyqtSignal()
    # Emitted with the decoded JSON response on success.
    result = pyqtSignal(object)
    # Emitted with the exception raised while building, sending or decoding.
    error = pyqtSignal(Exception)

    def __init__(
        self,
        url: str,
        data: Any = None,
        timeout: int = ...,
        method: str = ...,
        headers: dict = ...,
        key: str = None,
    ):
        """Create an AsyncRequest object.

        By default, AsyncRequest has no timeout, will infer whether it is "POST"
        or "GET" based on the presence of `data` and uses JSON to transmit. It
        also assumes the response is JSON.

        Args:
            url (str): URL to request from.
            data (Any, optional): JSON-serialisable payload to send. Defaults to None.
            timeout (int, optional): Timeout for request in seconds. Defaults to
                `...` (use the class default, i.e. no timeout).
            method (str, optional): Which HTTP method to use. Defaults to `...`
                ("GET" without data, "POST" with data).
            headers (dict, optional): Extra request headers. The dict is copied,
                never modified. Defaults to `...` (no extra headers).
            key (Union[str, None], optional): Key used to XOR-encrypt the request
                body and decrypt the response; blank keys are ignored.
                Defaults to None.
        """
        super(AsyncRequest, self).__init__()
        self.url = url
        self.data = None if data is None else json.dumps(data).encode("utf-8")
        # Copy so the headers added below never leak into the caller's dict.
        self.headers = {} if headers is ... or headers is None else dict(headers)

        self.key = None
        if isinstance(key, str) and key.strip() != "":
            self.key = key.strip().encode("utf-8")

        if self.key is not None:
            self.headers["X-Encrypted-Body"] = "XOR"
        if timeout is not ...:
            self.timeout = timeout
        if method is ...:
            self.method = "GET" if data is None else "POST"
        else:
            self.method = method
        if self.data is not None:
            if self.key is not None:
                self.data = bytewise_xor(self.data, self.key)
            self.headers["Content-Type"] = "application/json"
            # computed after encryption so it matches the bytes actually sent
            self.headers["Content-Length"] = str(len(self.data))

    def run(self):
        """Perform the request: emit ``result`` or ``error``, then ``finished``.

        Building the Request is inside the ``try`` so that failures such as a
        ValueError for a malformed URL are also reported through ``error``, and
        ``finished`` is always emitted (letting the owning thread quit).
        """
        try:
            req = Request(self.url, headers=self.headers, method=self.method)
            with urlopen(req, self.data, self.timeout) as res:
                data = res.read()
                enc_type = res.getheader("X-Encrypted-Body", None)
            assert enc_type in {"XOR", None}, "Unknown server encryption!"
            if enc_type == "XOR":
                assert self.key, "Key needed to decrypt server response!"
                data = bytewise_xor(data, self.key)
            self.result.emit(json.loads(data))
        except Exception as e:
            self.error.emit(e)
        finally:
            self.finished.emit()

    @classmethod
    def request(cls, *args, **kwargs):
        """Create a request and return ``(request, start)``.

        Arguments are forwarded to the constructor. Calling ``start()`` runs the
        request on a new QThread when ``THREADED`` is set, otherwise synchronously
        in the calling thread, so connect to the request's signals first.
        """
        req = cls(*args, **kwargs)
        if THREADED:
            thread = QThread()
            # NOTE: need to keep reference to thread or it gets destroyed
            req.thread = thread
            req.moveToThread(thread)
            thread.started.connect(req.run)
            req.finished.connect(thread.quit)
            # NOTE: is this a memory leak?
            # For some reason, deleteLater occurs while thread is still running, resulting in crash
            # req.finished.connect(req.deleteLater)
            # thread.finished.connect(thread.deleteLater)
            return req, lambda: thread.start()
        else:
            return req, lambda: req.run()
|
|
|
|
|
|
class Client(QObject):
    """Sends requests to the backend API and mirrors its options into the config."""

    # Emitted with a status message built from the STATE_*/ERR_* constants.
    status = pyqtSignal(str)
    # Emitted after get_config() has written the backend's options into the config.
    config_updated = pyqtSignal()

    def __init__(self, cfg: Config, ext_cfg: Config):
        """It is highly dependent on config's structure to the point it writes directly to it. :/

        Args:
            cfg (Config): Main plugin config (read for request params, written by get_config).
            ext_cfg (Config): Config holding extension-script metadata and option values.
        """
        super(Client, self).__init__()
        self.cfg = cfg
        self.ext_cfg = ext_cfg
        # In-flight requests; holding them also keeps their QThreads referenced
        # until `finished` removes them.
        self.short_reqs = set()
        self.long_reqs = set()
        # NOTE: this is a hacky workaround for detecting if backend is reachable
        self.is_connected = False

    def handle_api_error(self, exc: Exception):
        """Handle exceptions that can occur while interacting with the backend.

        Marks the backend as unreachable and emits a matching status message.
        Exceptions not handled below are re-raised unchanged (with their
        original traceback).
        """
        self.is_connected = False
        try:
            # socket.timeout is an OSError subclass but, before Python 3.10
            # (where it became an alias of TimeoutError), it is not caught by
            # `except TimeoutError`; normalise it so that branch applies.
            if isinstance(exc, socket.timeout):
                raise TimeoutError
            raise exc
        except URLError as e:
            self.status.emit(f"{STATE_URLERROR}: {e.reason}")
        except TimeoutError:
            self.status.emit(f"{STATE_URLERROR}: response timed out")
        except json.JSONDecodeError:
            # must come before ValueError, which JSONDecodeError subclasses
            self.status.emit(f"{STATE_URLERROR}: invalid JSON response")
        except ValueError:
            self.status.emit(f"{STATE_URLERROR}: Invalid backend URL")
        except ConnectionError:
            self.status.emit(f"{STATE_URLERROR}: connection error during request")
        # Anything else propagates; the previous `assert False, e` vanished
        # under `python -O` and silently swallowed unexpected errors.

    def post(
        self, route, body, cb, base_url=..., is_long=True, ignore_no_connection=False
    ):
        """Send ``body`` as JSON to ``route`` and pass the decoded response to ``cb``.

        Args:
            route (str): Route joined onto the base URL.
            body (Any): JSON-serialisable payload; None sends a GET instead.
            cb (Callable[[Any], Any]): Called with the decoded JSON response.
            base_url (str, optional): URL to join ``route`` onto. Defaults to the
                configured backend URL (see get_url). None is treated as a bad URL.
            is_long (bool, optional): Use LONG_TIMEOUT and emit STATE_DONE once all
                long requests have finished; otherwise use SHORT_TIMEOUT.
                Defaults to True.
            ignore_no_connection (bool, optional): Send even if the backend has not
                been reached yet. Defaults to False.
        """
        if not ignore_no_connection and not self.is_connected:
            self.status.emit(ERR_NO_CONNECTION)
            return
        if base_url is ...:
            url = get_url(self.cfg, route)
        elif base_url:
            url = urljoin(base_url, route)
        else:
            # get_url() returns None for a bad base URL; urljoin(None, route)
            # would return the bare relative route and slip past the check below.
            url = None
        if not url:
            self.status.emit(ERR_BAD_URL)
            return
        # TODO: how to cancel this? destroy the thread after sending API interrupt request?
        req, start = AsyncRequest.request(
            url,
            body,
            LONG_TIMEOUT if is_long else SHORT_TIMEOUT,
            key=self.cfg("encryption_key"),
        )

        if is_long:
            self.long_reqs.add(req)
        else:
            self.short_reqs.add(req)

        def handler():
            self.long_reqs.discard(req)
            self.short_reqs.discard(req)
            # NOTE(review): on failure, handle_api_error's status is emitted just
            # before `finished`, so this STATE_DONE may overwrite it — confirm intended.
            if is_long and len(self.long_reqs) == 0:
                self.status.emit(STATE_DONE)

        req.result.connect(cb)
        req.error.connect(lambda e: self.handle_api_error(e))
        req.finished.connect(handler)
        start()

    def get(self, route, cb, base_url=..., is_long=False, ignore_no_connection=False):
        """GET ``route`` and pass the decoded response to ``cb`` (see :meth:`post`)."""
        self.post(
            route,
            None,
            cb,
            base_url=base_url,
            is_long=is_long,
            ignore_no_connection=ignore_no_connection,
        )

    def common_params(self, has_selection):
        """Parameters nearly all the post routes share.

        Tiling is disabled when "only_full_img_tiling" is set and there is a selection.
        """
        tiling = self.cfg("sd_tiling", bool) and not (
            self.cfg("only_full_img_tiling", bool) and has_selection
        )

        # its fine to stuff extra stuff here; pydantic will shave off irrelevant params
        params = dict(
            sd_model=self.cfg("sd_model", str),
            batch_count=self.cfg("sd_batch_count", int),
            batch_size=self.cfg("sd_batch_size", int),
            base_size=self.cfg("sd_base_size", int),
            max_size=self.cfg("sd_max_size", int),
            tiling=tiling,
            upscaler_name=self.cfg("upscaler_name", str),
            restore_faces=self.cfg("face_restorer_model", str) != "None",
            face_restorer=self.cfg("face_restorer_model", str),
            codeformer_weight=self.cfg("codeformer_weight", float),
            filter_nsfw=self.cfg("filter_nsfw", bool),
            do_exact_steps=self.cfg("do_exact_steps", bool),
            include_grid=self.cfg("include_grid", bool),
            save_samples=self.cfg("save_temp_images", bool),
        )
        return params

    @staticmethod
    def _is_valid_config(obj) -> bool:
        """Return True if ``obj`` has every field get_config() reads, non-empty."""
        required_lists = (
            "upscalers",
            "samplers",
            "samplers_img2img",
            "face_restorers",
            "sd_models",
            "scripts_txt2img",
            "scripts_img2img",
        )
        try:
            return "sample_path" in obj and all(
                len(obj[name]) > 0 for name in required_lists
            )
        except (KeyError, TypeError):
            # missing field, or obj/field is not a container (e.g. None)
            return False

    def get_config(self):
        """Fetch the backend's options and store them in ``cfg`` and ``ext_cfg``.

        The request is sent even when not yet connected. On a valid response the
        client is marked connected, STATE_READY is emitted, then ``config_updated``.
        """

        def cb(obj):
            if not self._is_valid_config(obj):
                self.status.emit(
                    f"{STATE_URLERROR}: incompatible response, are you running the right API?"
                )
                print("Invalid Response:\n", obj)
                return

            # replace only after verifying
            self.cfg.set("sample_path", obj["sample_path"])
            # NOTE: sorting these lists is risky; haven't 100% verified that I removed all reliance on indexes
            self.cfg.set("upscaler_list", obj["upscalers"])
            self.cfg.set("txt2img_sampler_list", obj["samplers"])
            self.cfg.set("img2img_sampler_list", obj["samplers_img2img"])
            self.cfg.set("inpaint_sampler_list", obj["samplers_img2img"])
            self.cfg.set("txt2img_script_list", list(obj["scripts_txt2img"].keys()))
            self.cfg.set("img2img_script_list", list(obj["scripts_img2img"].keys()))
            self.cfg.set("inpaint_script_list", list(obj["scripts_img2img"].keys()))
            self.cfg.set("face_restorer_model_list", obj["face_restorers"])
            self.cfg.set("sd_model_list", obj["sd_models"])

            # extension script cfg; inpaint reuses the img2img scripts
            obj["scripts_inpaint"] = obj["scripts_img2img"]
            for ext_type in ("scripts_txt2img", "scripts_img2img", "scripts_inpaint"):
                metadata: Dict[str, List[dict]] = obj[ext_type]
                self.ext_cfg.set(f"{ext_type}_len", len(metadata))
                for ext_name, ext_meta in metadata.items():
                    ext_key = get_ext_key(ext_type, ext_name)
                    new_val = json.dumps(ext_meta)
                    # Option values are only reset when the script's metadata
                    # changed; otherwise previously stored values are left alone.
                    if new_val != self.ext_cfg(ext_key):
                        self.ext_cfg.set(ext_key, new_val)
                        for i, opt in enumerate(ext_meta):
                            key = get_ext_key(ext_type, ext_name, i)
                            self.ext_cfg.set(key, opt["val"])

            self.is_connected = True
            self.status.emit(STATE_READY)
            self.config_updated.emit()

        self.get("config", cb, ignore_no_connection=True)

    def _get_seed(self, mode: str) -> int:
        """Return the seed configured for ``mode`` (e.g. "txt2img"), or -1 if blank."""
        # read as str because Qt casts int as 32-bit int
        seed = self.cfg(f"{mode}_seed", str)
        return int(seed) if seed.strip() != "" else -1

    def post_txt2img(self, cb, width, height, has_selection):
        """Request txt2img generation; ``cb`` receives the decoded response.

        ``width``/``height`` are sent as orig_width/orig_height. Unless
        "just_use_yaml" is set, the txt2img settings from the config are sent too.
        """
        params = dict(orig_width=width, orig_height=height)
        if not self.cfg("just_use_yaml", bool):
            seed = self._get_seed("txt2img")
            ext_name = self.cfg("txt2img_script", str)
            ext_args = get_ext_args(self.ext_cfg, "scripts_txt2img", ext_name)
            params.update(self.common_params(has_selection))
            params.update(
                prompt=fix_prompt(self.cfg("txt2img_prompt", str)),
                negative_prompt=fix_prompt(self.cfg("txt2img_negative_prompt", str)),
                sampler_name=self.cfg("txt2img_sampler", str),
                steps=self.cfg("txt2img_steps", int),
                cfg_scale=self.cfg("txt2img_cfg_scale", float),
                seed=seed,
                highres_fix=self.cfg("txt2img_highres", bool),
                denoising_strength=self.cfg("txt2img_denoising_strength", float),
                script=ext_name,
                script_args=ext_args,
            )

        self.post("txt2img", params, cb)

    def post_img2img(self, cb, src_img, mask_img, has_selection):
        """Request img2img on ``src_img``; ``cb`` receives the decoded response.

        ``mask_img`` is accepted but not used here. Unless "just_use_yaml" is
        set, the img2img settings from the config are sent too.
        """
        params = dict(is_inpaint=False, src_img=img_to_b64(src_img))
        if not self.cfg("just_use_yaml", bool):
            seed = self._get_seed("img2img")
            ext_name = self.cfg("img2img_script", str)
            ext_args = get_ext_args(self.ext_cfg, "scripts_img2img", ext_name)
            params.update(self.common_params(has_selection))
            params.update(
                prompt=fix_prompt(self.cfg("img2img_prompt", str)),
                negative_prompt=fix_prompt(self.cfg("img2img_negative_prompt", str)),
                sampler_name=self.cfg("img2img_sampler", str),
                steps=self.cfg("img2img_steps", int),
                cfg_scale=self.cfg("img2img_cfg_scale", float),
                denoising_strength=self.cfg("img2img_denoising_strength", float),
                color_correct=self.cfg("img2img_color_correct", bool),
                script=ext_name,
                script_args=ext_args,
                seed=seed,
            )

        self.post("img2img", params, cb)

    def post_inpaint(self, cb, src_img, mask_img, has_selection):
        """Request inpainting of ``src_img`` under ``mask_img`` via the img2img route.

        Unless "just_use_yaml" is set, the inpaint settings from the config are
        sent too.

        Raises:
            AssertionError: if ``mask_img`` is falsy.
        """
        assert mask_img, "Inpaint layer is needed for inpainting!"
        params = dict(
            is_inpaint=True, src_img=img_to_b64(src_img), mask_img=img_to_b64(mask_img)
        )
        if not self.cfg("just_use_yaml", bool):
            seed = self._get_seed("inpaint")
            # the backend expects the fill mode as an index into the fill list
            fill = self.cfg("inpaint_fill_list", "QStringList").index(
                self.cfg("inpaint_fill", str)
            )
            ext_name = self.cfg("inpaint_script", str)
            ext_args = get_ext_args(self.ext_cfg, "scripts_inpaint", ext_name)
            params.update(self.common_params(has_selection))
            params.update(
                prompt=fix_prompt(self.cfg("inpaint_prompt", str)),
                negative_prompt=fix_prompt(self.cfg("inpaint_negative_prompt", str)),
                sampler_name=self.cfg("inpaint_sampler", str),
                steps=self.cfg("inpaint_steps", int),
                cfg_scale=self.cfg("inpaint_cfg_scale", float),
                denoising_strength=self.cfg("inpaint_denoising_strength", float),
                color_correct=self.cfg("inpaint_color_correct", bool),
                script=ext_name,
                script_args=ext_args,
                seed=seed,
                invert_mask=self.cfg("inpaint_invert_mask", bool),
                # mask_blur=self.cfg("inpaint_mask_blur", int),
                inpainting_fill=fill,
                # inpaint_full_res=self.cfg("inpaint_full_res", bool),
                # inpaint_full_res_padding=self.cfg("inpaint_full_res_padding", int),
                inpaint_mask_weight=self.cfg("inpaint_mask_weight", float),
                include_grid=False,  # it is never useful for inpaint mode
            )

        self.post("img2img", params, cb)

    def post_upscale(self, cb, src_img):
        """Request upscaling of ``src_img``; ``cb`` receives the decoded response."""
        params = {"src_img": img_to_b64(src_img)}
        if not self.cfg("just_use_yaml", bool):
            params["upscaler_name"] = self.cfg("upscale_upscaler_name", str)
            params["downscale_first"] = self.cfg("upscale_downscale_first", bool)
        self.post("upscale", params, cb)

    def post_interrupt(self, cb):
        """Ask the backend to interrupt the current job via the official API route."""
        # get official API url (None if base_url is invalid; post() reports that)
        url = get_url(self.cfg, prefix=OFFICIAL_ROUTE_PREFIX)
        self.post("interrupt", {}, cb, base_url=url)

    def get_progress(self, cb):
        """Fetch job progress from the official API route; ``cb`` receives it."""
        # get official API url (None if base_url is invalid; post() reports that)
        url = get_url(self.cfg, prefix=OFFICIAL_ROUTE_PREFIX)
        self.get("progress", cb, base_url=url)