auto-sd-paint-ext/frontends/krita/krita_diff/client.py

403 lines
16 KiB
Python

import json
import socket
from typing import Any, Dict, List
from urllib.error import URLError
from urllib.parse import urljoin, urlparse
from urllib.request import Request, urlopen
from krita import QObject, QThread, pyqtSignal
from .config import Config
from .defaults import (
ERR_BAD_URL,
ERR_NO_CONNECTION,
LONG_TIMEOUT,
OFFICIAL_ROUTE_PREFIX,
ROUTE_PREFIX,
SHORT_TIMEOUT,
STATE_DONE,
STATE_READY,
STATE_URLERROR,
THREADED,
)
from .utils import bytewise_xor, fix_prompt, get_ext_args, get_ext_key, img_to_b64
# NOTE: backend queues up responses, so no explicit need to block multiple requests
# except to prevent user from spamming themselves
# TODO: tab showing all queued up requests (local plugin instance only)
def get_url(cfg: Config, route: str = ..., prefix: str = ROUTE_PREFIX):
    """Build a backend URL from the configured base URL.

    Args:
        cfg (Config): Config accessor; its "base_url" string is read.
        route (str, optional): Route joined after the prefix. When left as
            `...`, only the prefixed base URL is returned.
        prefix (str, optional): Path joined onto the base URL first.
            Defaults to ROUTE_PREFIX.

    Returns:
        The joined URL, or None when the base URL's scheme is neither
        "http" nor "https".
    """
    base_url = cfg("base_url", str)
    if urlparse(base_url).scheme not in {"http", "https"}:
        return None

    prefixed = urljoin(base_url, prefix)
    return prefixed if route is ... else urljoin(prefixed, route)
# krita doesn't reexport QtNetwork
class AsyncRequest(QObject):
    """A single HTTP request whose outcome is reported through Qt signals.

    Signals:
        finished: emitted once `run` is over, on success and on failure alike.
        result: emitted with the decoded JSON response on success.
        error: emitted with the exception raised while performing the request.
    """

    timeout = None
    finished = pyqtSignal()
    result = pyqtSignal(object)
    error = pyqtSignal(Exception)

    def __init__(
        self,
        url: str,
        data: Any = None,
        timeout: int = ...,
        method: str = ...,
        headers: dict = ...,
        key: str = None,
    ):
        """Create an AsyncRequest object.

        By default, AsyncRequest has no timeout, will infer whether it is "POST"
        or "GET" based on the presence of `data` and uses JSON to transmit. It
        also assumes the response is JSON.

        Args:
            url (str): URL to request from.
            data (Any, optional): Payload to send. Defaults to None.
            timeout (int, optional): Timeout for request. Defaults to `...`.
            method (str, optional): Which HTTP method to use. Defaults to `...`.
            headers (dict, optional): Extra request headers. The dict is copied,
                so the caller's object is never modified. Defaults to `...`.
            key (Union[str, None], Optional): Key to use for encryption/decryption. Defaults to None.
        """
        super(AsyncRequest, self).__init__()
        self.url = url
        self.data = None if data is None else json.dumps(data).encode("utf-8")
        # copy: headers are added below and must not leak into the caller's dict
        self.headers = {} if headers is ... else dict(headers)

        self.key = None
        if isinstance(key, str) and key.strip() != "":
            self.key = key.strip().encode("utf-8")
        if self.key is not None:
            self.headers["X-Encrypted-Body"] = "XOR"

        if timeout is not ...:
            self.timeout = timeout

        if method is ...:
            self.method = "GET" if data is None else "POST"
        else:
            self.method = method

        if self.data is not None:
            if self.key is not None:
                self.data = bytewise_xor(self.data, self.key)
            self.headers["Content-Type"] = "application/json"
            # length of the bytes actually sent (i.e. after the optional XOR step)
            self.headers["Content-Length"] = str(len(self.data))

    def run(self):
        """Perform the request and emit `result` or `error`, then `finished`."""
        try:
            # Request() itself raises ValueError for a URL without a scheme, so
            # it must be inside the try: otherwise neither `error` nor
            # `finished` would be emitted for such a URL.
            req = Request(self.url, headers=self.headers, method=self.method)
            with urlopen(req, self.data, self.timeout) as res:
                data = res.read()
                enc_type = res.getheader("X-Encrypted-Body", None)

            # Explicit raises instead of `assert` so the checks survive `-O`;
            # the exception type is kept as AssertionError.
            if enc_type not in {"XOR", None}:
                raise AssertionError("Unknown server encryption!")
            if enc_type == "XOR":
                if not self.key:
                    raise AssertionError("Key needed to decrypt server response!")
                data = bytewise_xor(data, self.key)

            self.result.emit(json.loads(data))
        except Exception as e:
            self.error.emit(e)
        finally:
            self.finished.emit()

    @classmethod
    def request(cls, *args, **kwargs):
        """Create a request and a callable that starts it.

        Arguments are forwarded to the constructor.

        Returns:
            Tuple of (request object, zero-argument start function). When
            THREADED is set the start function starts a QThread that executes
            `run`; otherwise it calls `run` directly.
        """
        req = cls(*args, **kwargs)
        if not THREADED:
            return req, lambda: req.run()

        thread = QThread()
        # NOTE: need to keep reference to thread or it gets destroyed
        req.thread = thread
        req.moveToThread(thread)
        thread.started.connect(req.run)
        req.finished.connect(thread.quit)
        # NOTE: is this a memory leak?
        # deleteLater is deliberately not connected to `finished`: it occurred
        # while the thread was still running, resulting in a crash.
        return req, lambda: thread.start()
class Client(QObject):
    """Sends requests to the backend and mirrors its settings into the config.

    Signals:
        status: emitted with a status or error string.
        config_updated: emitted after `get_config` stored a valid response.
    """

    status = pyqtSignal(str)
    config_updated = pyqtSignal()

    def __init__(self, cfg: Config, ext_cfg: Config):
        """It is highly dependent on config's structure to the point it writes directly to it. :/"""
        super(Client, self).__init__()
        self.cfg = cfg
        self.ext_cfg = ext_cfg
        # requests still in flight, split by the timeout class they were sent with
        self.short_reqs = set()
        self.long_reqs = set()
        # NOTE: this is a hacky workaround for detecting if backend is reachable
        self.is_connected = False

    def handle_api_error(self, exc: Exception):
        """Handle exceptions that can occur while interacting with the backend.

        Marks the client as disconnected and emits a matching `status` message.

        Raises:
            AssertionError: for any exception type not listed below.
        """
        self.is_connected = False
        # socket.timeout is only an alias of TimeoutError since Python 3.10;
        # before that it is a separate OSError subclass, so check it explicitly.
        if isinstance(exc, socket.timeout):
            self.status.emit(f"{STATE_URLERROR}: response timed out")
        elif isinstance(exc, URLError):
            self.status.emit(f"{STATE_URLERROR}: {exc.reason}")
        elif isinstance(exc, TimeoutError):
            self.status.emit(f"{STATE_URLERROR}: response timed out")
        # JSONDecodeError subclasses ValueError, so it has to be tested first
        elif isinstance(exc, json.JSONDecodeError):
            self.status.emit(f"{STATE_URLERROR}: invalid JSON response")
        elif isinstance(exc, ValueError):
            self.status.emit(f"{STATE_URLERROR}: Invalid backend URL")
        elif isinstance(exc, ConnectionError):
            self.status.emit(f"{STATE_URLERROR}: connection error during request")
        else:
            # Explicit raise rather than `assert False` so unexpected errors are
            # not silently dropped when Python runs with -O.
            raise AssertionError(exc) from exc

    def post(
        self, route, body, cb, base_url=..., is_long=True, ignore_no_connection=False
    ):
        """Send `body` to `route` and pass the decoded response to `cb`.

        Args:
            route: Route joined onto the base URL.
            body: JSON-serializable payload; None results in a GET request.
            cb: Callback connected to the request's `result` signal.
            base_url: Base URL overriding the configured one. Defaults to `...`,
                meaning the URL is built from the config.
            is_long: Use LONG_TIMEOUT instead of SHORT_TIMEOUT and emit
                STATE_DONE once no long request is left.
            ignore_no_connection: Send even if the backend is not marked as
                connected.
        """
        if not ignore_no_connection and not self.is_connected:
            self.status.emit(ERR_NO_CONNECTION)
            return

        if base_url is ...:
            url = get_url(self.cfg, route)
        else:
            # get_url() yields None for an invalid base URL; urljoin() with a
            # falsy base would just return the bare route, which is no URL.
            url = urljoin(base_url, route) if base_url else None
        if not url:
            self.status.emit(ERR_BAD_URL)
            return

        # TODO: how to cancel this? destroy the thread after sending API interrupt request?
        req, start = AsyncRequest.request(
            url,
            body,
            LONG_TIMEOUT if is_long else SHORT_TIMEOUT,
            key=self.cfg("encryption_key"),
        )

        # hold a reference until the request is finished
        if is_long:
            self.long_reqs.add(req)
        else:
            self.short_reqs.add(req)

        def handler():
            self.long_reqs.discard(req)
            self.short_reqs.discard(req)
            # NOTE(review): this also runs after a failed request, so STATE_DONE
            # follows the error status - confirm that this is intended.
            if is_long and len(self.long_reqs) == 0:
                self.status.emit(STATE_DONE)

        req.result.connect(cb)
        req.error.connect(lambda e: self.handle_api_error(e))
        req.finished.connect(handler)
        start()

    def get(self, route, cb, base_url=..., is_long=False, ignore_no_connection=False):
        """Same as `post` but without a body, which makes it a GET request."""
        self.post(
            route,
            None,
            cb,
            base_url=base_url,
            is_long=is_long,
            ignore_no_connection=ignore_no_connection,
        )

    def common_params(self, has_selection):
        """Parameters nearly all the post routes share."""
        tiling = self.cfg("sd_tiling", bool) and not (
            self.cfg("only_full_img_tiling", bool) and has_selection
        )

        # its fine to stuff extra stuff here; pydantic will shave off irrelevant params
        params = dict(
            sd_model=self.cfg("sd_model", str),
            batch_count=self.cfg("sd_batch_count", int),
            batch_size=self.cfg("sd_batch_size", int),
            base_size=self.cfg("sd_base_size", int),
            max_size=self.cfg("sd_max_size", int),
            tiling=tiling,
            upscaler_name=self.cfg("upscaler_name", str),
            restore_faces=self.cfg("face_restorer_model", str) != "None",
            face_restorer=self.cfg("face_restorer_model", str),
            codeformer_weight=self.cfg("codeformer_weight", float),
            filter_nsfw=self.cfg("filter_nsfw", bool),
            do_exact_steps=self.cfg("do_exact_steps", bool),
            include_grid=self.cfg("include_grid", bool),
            save_samples=self.cfg("save_temp_images", bool),
        )
        return params

    def _get_seed(self, key):
        """Return the seed stored as a string under `key`, or -1 when blank.

        Raises:
            ValueError: if the stored text is not blank and not an integer.
        """
        # seeds are kept as str because Qt casts int as 32-bit int
        raw = self.cfg(key, str)
        return -1 if raw.strip() == "" else int(raw)

    @staticmethod
    def _is_valid_config_response(obj):
        """Tell whether `obj` has "sample_path" and all required non-empty entries."""
        required = (
            "upscalers",
            "samplers",
            "samplers_img2img",
            "face_restorers",
            "sd_models",
            "scripts_txt2img",
            "scripts_img2img",
        )
        try:
            if "sample_path" not in obj:
                return False
            return all(len(obj[name]) > 0 for name in required)
        except (KeyError, TypeError):
            # missing entry, or the response / an entry has an unexpected type
            return False

    def get_config(self):
        """Fetch the backend's config and store it in `cfg` and `ext_cfg`.

        On a valid response the client is marked as connected and STATE_READY
        and `config_updated` are emitted; otherwise an error status is emitted
        and nothing is stored.
        """

        def cb(obj):
            # explicit validation (not `assert`) so it is still done under -O
            if not self._is_valid_config_response(obj):
                self.status.emit(
                    f"{STATE_URLERROR}: incompatible response, are you running the right API?"
                )
                print("Invalid Response:\n", obj)
                return

            # replace only after verifying
            self.cfg.set("sample_path", obj["sample_path"])
            # NOTE: sorting these lists is risky; ivent 100% verified that I removed all reliance on indexes
            self.cfg.set("upscaler_list", obj["upscalers"])
            self.cfg.set("txt2img_sampler_list", obj["samplers"])
            self.cfg.set("img2img_sampler_list", obj["samplers_img2img"])
            self.cfg.set("inpaint_sampler_list", obj["samplers_img2img"])
            self.cfg.set("txt2img_script_list", list(obj["scripts_txt2img"].keys()))
            self.cfg.set("img2img_script_list", list(obj["scripts_img2img"].keys()))
            self.cfg.set("inpaint_script_list", list(obj["scripts_img2img"].keys()))
            self.cfg.set("face_restorer_model_list", obj["face_restorers"])
            self.cfg.set("sd_model_list", obj["sd_models"])

            # extension script cfg; inpaint uses the img2img scripts
            obj["scripts_inpaint"] = obj["scripts_img2img"]
            for ext_type in ("scripts_txt2img", "scripts_img2img", "scripts_inpaint"):
                metadata: Dict[str, List[dict]] = obj[ext_type]
                self.ext_cfg.set(f"{ext_type}_len", len(metadata))
                for ext_name, ext_meta in metadata.items():
                    meta_key = get_ext_key(ext_type, ext_name)
                    new_val = json.dumps(ext_meta)
                    # only rewrite the option values when the metadata changed
                    if new_val != self.ext_cfg(meta_key):
                        self.ext_cfg.set(meta_key, new_val)
                        for i, opt in enumerate(ext_meta):
                            self.ext_cfg.set(
                                get_ext_key(ext_type, ext_name, i), opt["val"]
                            )

            self.is_connected = True
            self.status.emit(STATE_READY)
            self.config_updated.emit()

        self.get("config", cb, ignore_no_connection=True)

    def post_txt2img(self, cb, width, height, has_selection):
        """Request txt2img for the given original size; `cb` gets the response."""
        params = dict(orig_width=width, orig_height=height)
        if not self.cfg("just_use_yaml", bool):
            seed = self._get_seed("txt2img_seed")
            ext_name = self.cfg("txt2img_script", str)
            ext_args = get_ext_args(self.ext_cfg, "scripts_txt2img", ext_name)
            params.update(self.common_params(has_selection))
            params.update(
                prompt=fix_prompt(self.cfg("txt2img_prompt", str)),
                negative_prompt=fix_prompt(self.cfg("txt2img_negative_prompt", str)),
                sampler_name=self.cfg("txt2img_sampler", str),
                steps=self.cfg("txt2img_steps", int),
                cfg_scale=self.cfg("txt2img_cfg_scale", float),
                seed=seed,
                highres_fix=self.cfg("txt2img_highres", bool),
                denoising_strength=self.cfg("txt2img_denoising_strength", float),
                script=ext_name,
                script_args=ext_args,
            )

        self.post("txt2img", params, cb)

    def post_img2img(self, cb, src_img, mask_img, has_selection):
        """Request img2img for `src_img`; `mask_img` is accepted but not used."""
        params = dict(is_inpaint=False, src_img=img_to_b64(src_img))
        if not self.cfg("just_use_yaml", bool):
            seed = self._get_seed("img2img_seed")
            ext_name = self.cfg("img2img_script", str)
            ext_args = get_ext_args(self.ext_cfg, "scripts_img2img", ext_name)
            params.update(self.common_params(has_selection))
            params.update(
                prompt=fix_prompt(self.cfg("img2img_prompt", str)),
                negative_prompt=fix_prompt(self.cfg("img2img_negative_prompt", str)),
                sampler_name=self.cfg("img2img_sampler", str),
                steps=self.cfg("img2img_steps", int),
                cfg_scale=self.cfg("img2img_cfg_scale", float),
                denoising_strength=self.cfg("img2img_denoising_strength", float),
                color_correct=self.cfg("img2img_color_correct", bool),
                script=ext_name,
                script_args=ext_args,
                seed=seed,
            )

        self.post("img2img", params, cb)

    def post_inpaint(self, cb, src_img, mask_img, has_selection):
        """Request inpainting of `src_img` using `mask_img`.

        Raises:
            AssertionError: if `mask_img` is missing.
        """
        # explicit raise so the check is not stripped under -O
        if not mask_img:
            raise AssertionError("Inpaint layer is needed for inpainting!")
        params = dict(
            is_inpaint=True, src_img=img_to_b64(src_img), mask_img=img_to_b64(mask_img)
        )
        if not self.cfg("just_use_yaml", bool):
            seed = self._get_seed("inpaint_seed")
            # the fill mode is sent as its position in the list of fill modes
            fill = self.cfg("inpaint_fill_list", "QStringList").index(
                self.cfg("inpaint_fill", str)
            )
            ext_name = self.cfg("inpaint_script", str)
            ext_args = get_ext_args(self.ext_cfg, "scripts_inpaint", ext_name)
            params.update(self.common_params(has_selection))
            params.update(
                prompt=fix_prompt(self.cfg("inpaint_prompt", str)),
                negative_prompt=fix_prompt(self.cfg("inpaint_negative_prompt", str)),
                sampler_name=self.cfg("inpaint_sampler", str),
                steps=self.cfg("inpaint_steps", int),
                cfg_scale=self.cfg("inpaint_cfg_scale", float),
                denoising_strength=self.cfg("inpaint_denoising_strength", float),
                color_correct=self.cfg("inpaint_color_correct", bool),
                script=ext_name,
                script_args=ext_args,
                seed=seed,
                invert_mask=self.cfg("inpaint_invert_mask", bool),
                # mask_blur=self.cfg("inpaint_mask_blur", int),
                inpainting_fill=fill,
                # inpaint_full_res=self.cfg("inpaint_full_res", bool),
                # inpaint_full_res_padding=self.cfg("inpaint_full_res_padding", int),
                inpaint_mask_weight=self.cfg("inpaint_mask_weight", float),
                include_grid=False,  # it is never useful for inpaint mode
            )

        # inpainting goes through the img2img route with is_inpaint=True
        self.post("img2img", params, cb)

    def post_upscale(self, cb, src_img):
        """Request upscaling of `src_img`; `cb` gets the response."""
        use_yaml_only = self.cfg("just_use_yaml", bool)
        params = {"src_img": img_to_b64(src_img)}
        if not use_yaml_only:
            params["upscaler_name"] = self.cfg("upscale_upscaler_name", str)
            params["downscale_first"] = self.cfg("upscale_downscale_first", bool)
        self.post("upscale", params, cb)

    def post_interrupt(self, cb):
        """Send an interrupt request to the official API."""
        # get official API url
        url = get_url(self.cfg, prefix=OFFICIAL_ROUTE_PREFIX)
        self.post("interrupt", {}, cb, base_url=url)

    def get_progress(self, cb):
        """Request progress information from the official API."""
        # get official API url
        url = get_url(self.cfg, prefix=OFFICIAL_ROUTE_PREFIX)
        self.get("progress", cb, base_url=url)