auto-sd-paint-ext/backend/utils.py

330 lines
9.7 KiB
Python

from __future__ import annotations
import inspect
import logging
import os
import secrets
from base64 import b64decode, b64encode
from io import BytesIO
from itertools import cycle
from math import ceil
import modules
import yaml
from modules import shared
from PIL import Image
from pydantic import BaseModel
from .config import CONFIG_PATH, ENCRYPT_FILE, LOGGER_NAME, MainConfig
# Module logger; its name comes from `LOGGER_NAME` in `.config`. Used by the
# helpers below for informational messages and warnings.
log = logging.getLogger(LOGGER_NAME)
def load_config():
    """Load the backend config from `CONFIG_PATH`.

    The config includes options that are not exposed through the API yet.
    A relative `CONFIG_PATH` resolves against the current working directory.
    If the file does not exist yet, it is first created from the defaults of
    `MainConfig` (see `config.py`).

    Returns:
        MainConfig: Parsed config.
    """
    if not os.path.isfile(CONFIG_PATH):
        with open(CONFIG_PATH, "w", encoding="utf-8") as f:
            yaml.safe_dump(MainConfig().dict(), f)
    with open(CONFIG_PATH, encoding="utf-8") as f:
        obj = yaml.safe_load(f)
    # An empty YAML document loads as None. Treat it as "no overrides" so the
    # model defaults apply, instead of failing validation on None.
    return MainConfig.parse_obj(obj or {})
def merge_default_config(config: BaseModel, default: BaseModel):
    """Fill unset or None fields of `config` in place with values from `default`.

    A field is replaced when it was never explicitly set on `config` (so it only
    holds its model default), or when its current value is None. Fields that
    were explicitly set to None appear in `__fields_set__`, so the value itself
    must also be checked for them to be replaced.

    Args:
        config (BaseModel): Config object to update in place.
        default (BaseModel): Object to take values from, matched by field name.
            Fields missing on `default` are filled with None.

    Returns:
        BaseModel: `config`, after modification.
    """
    for field in config.__fields__:
        # `__fields_set__` includes fields explicitly set to None, so test the
        # current value too (not the field name, which is never None).
        if field not in config.__fields_set__ or getattr(config, field) is None:
            setattr(config, field, getattr(default, field, None))
    return config
def prepare_backend(opt: BaseModel):
    """Copy request options into the webUI's global state before calling its internal API.

    Every setting is optional and is applied only when `opt` has the matching
    attribute. In order, this:
    - sets the face restoration model and CodeFormer weight (`face_restorer`)
    - selects the SD checkpoint and reloads model weights (`sd_model`)
    - sets the img2img upscaler (`upscaler_name`)
    - sets img2img color correction and the exact-steps flag (`color_correct`)
    - sets the NSFW filter flag (`filter_nsfw`)
    - sets the inpainting mask weight (`inpaint_mask_weight`)
    - creates the output folder if needed (`sample_path`)

    Args:
        opt (BaseModel): Option/request object.
    """
    # `shared.opts` holds the underlying webUI codebase's global settings.
    webui_opts = shared.opts

    if hasattr(opt, "face_restorer"):
        webui_opts.face_restoration_model = opt.face_restorer
        webui_opts.code_former_weight = opt.codeformer_weight

    if hasattr(opt, "sd_model"):
        webui_opts.sd_model_checkpoint = opt.sd_model
        # Presumably required so the newly selected checkpoint is actually loaded.
        modules.sd_models.reload_model_weights(shared.sd_model)

    if hasattr(opt, "upscaler_name"):
        webui_opts.upscaler_for_img2img = opt.upscaler_name

    if hasattr(opt, "color_correct"):
        webui_opts.img2img_color_correction = opt.color_correct
        # NOTE(review): `do_exact_steps` is assumed to exist whenever
        # `color_correct` does -- confirm against the request models.
        webui_opts.img2img_fix_steps = opt.do_exact_steps

    if hasattr(opt, "filter_nsfw"):
        webui_opts.filter_nsfw = opt.filter_nsfw

    if hasattr(opt, "inpaint_mask_weight"):
        webui_opts.inpainting_mask_weight = opt.inpaint_mask_weight

    # Make sure the output folder exists before anything is saved into it.
    if hasattr(opt, "sample_path"):
        os.makedirs(opt.sample_path, exist_ok=True)
def optional(*fields):
    """Class decorator that marks fields of a pydantic model as not required.

    Usage::

        @optional              # every field becomes optional
        @optional()            # same as above
        @optional("a", "b")    # only the named fields become optional

    Adapted from https://github.com/samuelcolvin/pydantic/issues/1223#issuecomment-775363074

    Args:
        *fields: Either the model class itself (bare ``@optional`` usage) or the
            names of the fields to make optional. No names means all fields.

    Returns:
        The decorated class when used bare, otherwise the decorator.

    Raises:
        KeyError: A named field is not present in the model's ``__fields__``.
    """

    def dec(_cls):
        # With no explicit names, fall back to every field of the model.
        for field in fields or _cls.__fields__:
            _cls.__fields__[field].required = False
        return _cls

    if fields and inspect.isclass(fields[0]) and issubclass(fields[0], BaseModel):
        # Bare `@optional`: Python passed the class itself as the only argument.
        return optional()(fields[0])
    return dec
def save_img(image: Image.Image, sample_path: str, filename: str):
    """Write `image` to `filename` inside the folder `sample_path`.

    The folder is not created here; it must already exist.

    Args:
        image (Image): Image to write; its `save` method receives the target path.
        sample_path (str): Destination folder.
        filename (str): File name to write, including extension.

    Returns:
        str: Absolute path of the written file.
    """
    destination = os.path.abspath(os.path.join(sample_path, filename))
    image.save(destination)
    return destination
def img_to_b64(image: Image.Image):
    """Encode an image as a base64 string of its PNG bytes.

    Args:
        image (Image): Image to encode.

    Returns:
        str: Base64 text (ASCII-safe) of the PNG-encoded image.
    """
    with BytesIO() as stream:
        image.save(stream, format="png")
        encoded = b64encode(stream.getvalue())
    return encoded.decode("utf-8")
def b64_to_img(enc: str):
    """Decode a base64 string back into a PIL image.

    Args:
        enc (str): Base64-encoded image file bytes.

    Returns:
        Image: The opened image.
    """
    raw_bytes = b64decode(enc)
    # The buffer is deliberately left open: PIL's `Image.open` is lazy and
    # reads pixel data from it later.
    return Image.open(BytesIO(raw_bytes))
def sddebz_highres_fix(
    base_size: int, max_size: int, orig_width: int, orig_height: int
):
    """Choose a render resolution that respects the model's native and maximum input sizes.

    Stable Diffusion behaves best near its native (base) input size of 512.
    Resolutions much larger than that can produce artifacts such as duplicated
    features, and resolutions smaller than it also degrade. The result is meant
    to be resized to the requested size afterwards; that resizing is not done
    here.

    The shorter side is set to `base_size` and the longer side follows the
    aspect ratio. If the longer side then exceeds `max_size`, it is clamped to
    `max_size` and the shorter side is rescaled from it. Setting `max_size`
    equal to `base_size` (e.g. 512) imitates the builtin highres fix. Unlike
    the builtin fix, this needs no `firstphase_width`/`firstphase_height`
    input.

    Originally written by @sddebz. Computed sides are rounded *up* to
    multiples of 64 instead of using `round`; this eases selected-region
    resizing in the plugin and avoids rounding down to 0. A clamped side is
    `max_size` itself.

    Args:
        base_size (int): Native/base input size of the model.
        max_size (int): Maximum input size to accept.
        orig_width (int): Requested width.
        orig_height (int): Requested height.

    Returns:
        Tuple[int, int]: (width, height) to render at.
    """

    def snap(scale, length, stride=64):
        """Scale `length` by `scale`, rounding up to a multiple of `stride`."""
        return stride * ceil(scale * length / stride)

    ratio = orig_width / orig_height
    inverse_ratio = 1 / ratio
    landscape = orig_width > orig_height

    # Express both orientations in terms of the long and short sides; squares
    # take the portrait path (width is the short side).
    if landscape:
        long_scale, short_scale = ratio, inverse_ratio
    else:
        long_scale, short_scale = inverse_ratio, ratio

    long_side, short_side = snap(long_scale, base_size), base_size
    if long_side > max_size:
        # Too big: pin the long side to the maximum and rescale the short side.
        long_side, short_side = max_size, snap(short_scale, max_size)

    width, height = (long_side, short_side) if landscape else (short_side, long_side)

    new_ratio = width / height
    log.info(
        f"img size: {orig_width}x{orig_height} -> {width}x{height}, "
        f"aspect ratio: {ratio:.2f} -> {new_ratio:.2f}, {100 * (new_ratio - ratio) / ratio :.2f}% change"
    )
    return width, height
def parse_prompt(val):
    """Normalize the supported prompt/negative-prompt representations to a string.

    Accepted forms (the non-string ones are meant for prompts read from the
    YAML config):

    - ``None``: empty prompt.
    - ``str``: returned unchanged.
    - ``list``: items joined with ``", "``.
    - ``dict``: ``{term: weight}``. Each entry becomes ``term`` when its weight
      is None, otherwise ``(term:weight)``, and entries are joined with spaces.

    Args:
        val (Any): Prompt to parse.

    Raises:
        SyntaxError: `val` is none of the supported types.

    Returns:
        str: Prompt string.
    """
    if val is None:
        return ""
    if isinstance(val, str):
        return val
    if isinstance(val, list):
        # YAML loads items such as `- 8` as numbers. Format them like the dict
        # branch does instead of letting str.join raise TypeError.
        return ", ".join(f"{item}" for item in val)
    if isinstance(val, dict):
        return " ".join(
            f"{item}" if weight is None else f"({item}:{weight})"
            for item, weight in val.items()
        )
    raise SyntaxError(f"prompt field in {CONFIG_PATH} is invalid")
def get_sampler_index(sampler_name: str):
    """Look up a sampler's position in `modules.sd_samplers.samplers`.

    A sampler matches when `sampler_name` equals its `name` or appears in its
    `aliases`. The first match wins.

    Args:
        sampler_name (str): Sampler name or alias, matched exactly.

    Raises:
        KeyError: No sampler matches.

    Returns:
        int: Index of the first matching sampler.
    """
    matching = (
        position
        for position, candidate in enumerate(modules.sd_samplers.samplers)
        if sampler_name == candidate.name or sampler_name in candidate.aliases
    )
    found = next(matching, None)
    if found is None:
        raise KeyError(f"sampler not found: {sampler_name}")
    return found
def get_upscaler_index(upscaler_name: str):
    """Get the index of an upscaler in `shared.sd_upscalers` by its name.

    Unlike sampler lookup, only the exact `name` is compared (no aliases).

    Args:
        upscaler_name (str): Exact name of the upscaler.

    Raises:
        KeyError: No upscaler has that name.

    Returns:
        int: Index of the first upscaler with that name.
    """
    for index, upscaler in enumerate(shared.sd_upscalers):
        if upscaler.name == upscaler_name:
            return index
    raise KeyError(f"upscaler not found: {upscaler_name}")
def prepare_mask(mask: Image.Image):
    """Extract the mask from the alpha channel of `mask`.

    NOTE(review): PIL is expected to raise if the image has no "A" band (e.g.
    mode "L" or "RGB"); callers presumably always send an image with alpha.

    Args:
        mask (Image): Mask image carrying the mask in its alpha channel.

    Returns:
        Image: Single-band image holding the alpha channel.
    """
    return mask.getchannel("A")
def bytewise_xor(msg: bytes, key: bytes):
    """XOR `msg` with `key` repeated to the message length.

    Used for both encrypting and decrypting request/response bodies, since
    applying the same key twice restores the original bytes.

    Args:
        msg (bytes): Data to transform.
        key (bytes): Repeating key. An empty key leaves the data unchanged.

    Returns:
        bytes: Transformed data, the same length as `msg`.
    """
    if not key:
        # cycle(b"") yields nothing, so zip() would stop at once and silently
        # return b"", which would discard the whole message.
        return bytes(msg)
    return bytes(v ^ k for v, k in zip(msg, cycle(key)))
def get_encrypt_key():
    """Return the encryption key stored in `ENCRYPT_FILE`, creating it if missing.

    The key is optional: it can be used to encrypt messages exchanged between
    the backend and Krita, and the file may be edited by the user. If the file
    does not exist, a random 32-character hex key is generated, written to the
    file, and returned.

    Returns:
        bytes | None: The UTF-8 encoded key with surrounding whitespace
        stripped, or None if the key file exists but could not be read.
    """
    try:
        with open(ENCRYPT_FILE, encoding="utf-8") as f:
            return f.read().strip().encode("utf-8")
    except FileNotFoundError:
        pass  # No key yet: fall through and create one below.
    except (OSError, ValueError) as err:
        # Unreadable or undecodable key file (ValueError covers
        # UnicodeDecodeError). Stay best-effort, but report why.
        log.warning(
            f"Could not read encryption key from {os.path.abspath(ENCRYPT_FILE)}: {err}"
        )
        return None

    log.warning(
        f"Encryption key file doesn't exist at {os.path.abspath(ENCRYPT_FILE)}."
    )
    log.warning("Creating random encryption key.")
    key = secrets.token_hex(16)
    with open(ENCRYPT_FILE, "w", encoding="utf-8") as f:
        f.write(key)
    log.warning(
        f"Key in {ENCRYPT_FILE} is completely optional. It can be used to encrypt messages between backend & Krita and is editable."
    )
    # Same value the old code obtained by re-reading the file it just wrote.
    return key.encode("utf-8")