auto-sd-paint-ext/frontends/krita/krita_diff/script.py

456 lines
16 KiB
Python

import itertools
import os
import time
from typing import Union
from krita import (
Document,
Krita,
Node,
QImage,
QObject,
Qt,
QTimer,
Selection,
pyqtSignal,
)
from .client import Client
from .config import Config
from .defaults import (
ADD_MASK_TIMEOUT,
ERR_NO_DOCUMENT,
ETA_REFRESH_INTERVAL,
EXT_CFG_NAME,
STATE_INTERRUPT,
STATE_RESET_DEFAULT,
STATE_WAIT,
)
from .utils import (
b64_to_img,
find_optimal_selection_region,
get_desc_from_resp,
img_to_ba,
save_img,
)
# Does it actually have to be a QObject?
# The only possible use I see is for event emitting
class Script(QObject):
cfg: Config
"""config singleton"""
client: Client
"""API client singleton"""
status: str
"""Current status (shown in status bar)"""
app: Krita
"""Krita's Application instance (KDE Application)"""
doc: Document
"""Currently opened document if any"""
node: Node
"""Currently selected layer in Krita"""
selection: Selection
"""Selection region in Krita"""
x: int
"""Left position of selection"""
y: int
"""Top position of selection"""
width: int
"""Width of selection"""
height: int
"""Height of selection"""
status_changed = pyqtSignal(str)
config_updated = pyqtSignal()
def __init__(self):
super(Script, self).__init__()
# Persistent settings (should reload between Krita sessions)
self.cfg = Config()
# used for webUI scripts aka extensions not to be confused with their extensions
self.ext_cfg = Config(name=EXT_CFG_NAME, model=None)
self.client = Client(self.cfg, self.ext_cfg)
self.client.status.connect(self.status_changed.emit)
self.client.config_updated.connect(self.config_updated.emit)
self.eta_timer = QTimer()
self.eta_timer.setInterval(ETA_REFRESH_INTERVAL)
self.eta_timer.timeout.connect(lambda: self.action_update_eta())
def restore_defaults(self, if_empty=False):
"""Restore to default config."""
self.cfg.restore_defaults(not if_empty)
self.ext_cfg.config.remove("")
if not if_empty:
self.status_changed.emit(STATE_RESET_DEFAULT)
def update_selection(self):
"""Update references to key Krita objects as well as selection information."""
self.app = Krita.instance()
self.doc = self.app.activeDocument()
# self.doc doesnt exist at app startup
if not self.doc:
self.status_changed.emit(ERR_NO_DOCUMENT)
return
self.node = self.doc.activeNode()
self.selection = self.doc.selection()
is_not_selected = (
self.selection is None
or self.selection.width() < 1
or self.selection.height() < 1
)
if is_not_selected:
self.x = 0
self.y = 0
self.width = self.doc.width()
self.height = self.doc.height()
self.selection = None # for the two other cases of invalid selection
else:
self.x = self.selection.x()
self.y = self.selection.y()
self.width = self.selection.width()
self.height = self.selection.height()
assert (
self.doc.colorDepth() == "U8"
), f'Only "8-bit integer/channel" supported, Document Color Depth: {self.doc.colorDepth()}'
assert (
self.doc.colorModel() == "RGBA"
), f'Only "RGB/Alpha" supported, Document Color Model: {self.doc.colorModel()}'
def adjust_selection(self):
"""Adjust selection region to account for scaling and striding to prevent image stretch."""
if self.selection is not None and self.cfg("fix_aspect_ratio", bool):
x, y, width, height = find_optimal_selection_region(
self.cfg("sd_base_size", int),
self.cfg("sd_max_size", int),
self.x,
self.y,
self.width,
self.height,
self.doc.width(),
self.doc.height(),
)
self.x = x
self.y = y
self.width = width
self.height = height
def get_selection_image(self) -> QImage:
"""QImage of selection"""
return QImage(
self.doc.pixelData(self.x, self.y, self.width, self.height),
self.width,
self.height,
QImage.Format_RGBA8888,
).rgbSwapped()
def get_mask_image(self) -> Union[QImage, None]:
"""QImage of mask layer for inpainting"""
if self.node.type() not in {"paintlayer", "filelayer"}:
return None
return QImage(
self.node.pixelData(self.x, self.y, self.width, self.height),
self.width,
self.height,
QImage.Format_RGBA8888,
).rgbSwapped()
def img_inserter(self, x, y, width, height, group: str = None):
"""Return frozen image inserter to insert images as new layer."""
# Selection may change before callback, so freeze selection region
has_selection = self.selection is not None
glayer = self.doc.createGroupLayer(group) if group else None
def create_layer(name: str):
"""Create new layer in document or group"""
layer = self.doc.createNode(name, "paintLayer")
parent = self.doc.rootNode()
if glayer:
glayer.addChildNode(layer, None)
parent.addChildNode(glayer, None)
else:
parent.addChildNode(layer, None)
return layer
# TODO: Insert images inside a group layer for better organization
# Group layer name can contain model name, prompt, etc
def insert(layer_name, enc):
nonlocal x, y, width, height, has_selection
print(f"inserting layer {layer_name}")
print(f"data size: {len(enc)}")
# QImage.Format_RGB32 (4) is default format after decoding image
# QImage.Format_RGBA8888 (17) is format used in Krita tutorial
# both are compatible, & converting from 4 to 17 required a RGB swap
# Likewise for 5 & 18 (their RGBA counterparts)
image = b64_to_img(enc)
print(
f"image created: {image}, {image.width()}x{image.height()}, depth: {image.depth()}, format: {image.format()}"
)
# NOTE: Scaling is usually done by backend (although I am reconsidering this)
# The scaling here is for SD Upscale or Upscale on a selection region rather than whole image
# Image won't be scaled down ONLY if there is no selection; i.e. selecting whole image will scale down,
# not selecting anything won't scale down, leading to the canvas being resized afterwards
if has_selection and (image.width() != width or image.height() != height):
print(f"Rescaling image to selection: {width}x{height}")
image = image.scaled(
width, height, transformMode=Qt.SmoothTransformation
)
# Resize (not scale!) canvas if image is larger (i.e. outpainting or Upscale was used)
if image.width() > self.doc.width() or image.height() > self.doc.height():
# NOTE:
# - user's selection will be partially ignored if image is larger than canvas
# - it is complex to scale/resize the image such that image fits in the newly scaled selection
# - the canvas will still be resized even if the image fits after transparency masking
print("Image is larger than canvas! Resizing...")
new_width, new_height = self.doc.width(), self.doc.height()
if image.width() > self.doc.width():
x, width, new_width = 0, image.width(), image.width()
if image.height() > self.doc.height():
y, height, new_height = 0, image.height(), image.height()
self.doc.resizeImage(0, 0, new_width, new_height)
ba = img_to_ba(image)
layer = create_layer(layer_name)
# layer.setColorSpace() doesn't pernamently convert layer depth etc...
# Don't fail silently for setPixelData(); fails if bit depth or number of channels mismatch
size = ba.size()
expected = layer.pixelData(x, y, width, height).size()
assert expected == size, f"Raw data size: {size}, Expected size: {expected}"
print(f"inserting at x: {x}, y: {y}, w: {width}, h: {height}")
layer.setPixelData(ba, x, y, width, height)
return layer
if glayer:
return insert, glayer
return insert
def apply_txt2img(self):
# freeze selection region
insert, glayer = self.img_inserter(
self.x, self.y, self.width, self.height, group="a"
)
mask_trigger = self.transparency_mask_inserter()
def cb(response):
if len(self.client.long_reqs) == 1: # last request
self.eta_timer.stop()
assert response is not None, "Backend Error, check terminal"
outputs = response["outputs"]
glayer_name, layer_names = get_desc_from_resp(response, "txt2img")
layers = [
insert(name if name else f"txt2img {i + 1}", output)
for output, name, i in zip(outputs, layer_names, itertools.count())
]
if self.cfg("hide_layers", bool):
for layer in layers[:-1]:
layer.setVisible(False)
glayer.setName(glayer_name)
self.doc.refreshProjection()
mask_trigger(layers)
self.eta_timer.start(ETA_REFRESH_INTERVAL)
self.client.post_txt2img(
cb, self.width, self.height, self.selection is not None
)
def apply_img2img(self, is_inpaint):
insert, glayer = self.img_inserter(
self.x, self.y, self.width, self.height, group="a"
)
mask_trigger = self.transparency_mask_inserter()
mask_image = self.get_mask_image()
path = os.path.join(self.cfg("sample_path", str), f"{int(time.time())}.png")
mask_path = os.path.join(
self.cfg("sample_path", str), f"{int(time.time())}_mask.png"
)
if is_inpaint and mask_image is not None:
if self.cfg("save_temp_images", bool):
save_img(mask_image, mask_path)
# auto-hide mask layer before getting selection image
self.node.setVisible(False)
self.doc.refreshProjection()
sel_image = self.get_selection_image()
if self.cfg("save_temp_images", bool):
save_img(sel_image, path)
def cb(response):
if len(self.client.long_reqs) == 1: # last request
self.eta_timer.stop()
assert response is not None, "Backend Error, check terminal"
outputs = response["outputs"]
layer_name_prefix = "inpaint" if is_inpaint else "img2img"
glayer_name, layer_names = get_desc_from_resp(response, layer_name_prefix)
layers = [
insert(name if name else f"{layer_name_prefix} {i + 1}", output)
for output, name, i in zip(outputs, layer_names, itertools.count())
]
if self.cfg("hide_layers", bool):
for layer in layers[:-1]:
layer.setVisible(False)
glayer.setName(glayer_name)
self.doc.refreshProjection()
# dont need transparency mask for inpaint mode
if not is_inpaint:
mask_trigger(layers)
method = self.client.post_inpaint if is_inpaint else self.client.post_img2img
self.eta_timer.start()
method(
cb,
sel_image,
mask_image, # is unused by backend in img2img mode
self.selection is not None,
)
def apply_simple_upscale(self):
insert = self.img_inserter(self.x, self.y, self.width, self.height)
sel_image = self.get_selection_image()
path = os.path.join(self.cfg("sample_path", str), f"{int(time.time())}.png")
if self.cfg("save_temp_images", bool):
save_img(sel_image, path)
def cb(response):
assert response is not None, "Backend Error, check terminal"
output = response["output"]
insert(f"upscale", output)
self.doc.refreshProjection()
self.client.post_upscale(cb, sel_image)
def transparency_mask_inserter(self):
"""Mask out extra regions due to adjust_selection()."""
orig_selection = self.selection.duplicate() if self.selection else None
create_mask = self.cfg("create_mask_layer", bool)
add_mask_action = self.app.action("add_new_transparency_mask")
merge_mask_action = self.app.action("flatten_layer")
# This function is recursive to workaround race conditions when calling Krita's actions
def add_mask(layers: list, cur_selection):
if len(layers) < 1:
self.doc.setSelection(cur_selection) # reset to current selection
return
layer = layers.pop()
orig_visible = layer.visible()
orig_name = layer.name()
def restore():
# assume newly flattened layer is active
result = self.doc.activeNode()
result.setVisible(orig_visible)
result.setName(orig_name)
add_mask(layers, cur_selection)
layer.setVisible(True)
self.doc.setActiveNode(layer)
self.doc.setSelection(orig_selection)
add_mask_action.trigger()
if create_mask:
# collapse transparency mask by default
layer.setCollapsed(True)
layer.setVisible(orig_visible)
QTimer.singleShot(
ADD_MASK_TIMEOUT, lambda: add_mask(layers, cur_selection)
)
else:
# flatten transparency mask into layer
merge_mask_action.trigger()
QTimer.singleShot(ADD_MASK_TIMEOUT, lambda: restore())
def trigger_mask_adding(layers: list):
layers = layers[::-1] # causes final active layer to be the top one
def handle_mask():
cur_selection = self.selection.duplicate() if self.selection else None
add_mask(layers, cur_selection)
QTimer.singleShot(ADD_MASK_TIMEOUT, lambda: handle_mask())
return trigger_mask_adding
# Actions
def action_txt2img(self):
self.status_changed.emit(STATE_WAIT)
self.update_selection()
if not self.doc:
return
self.adjust_selection()
self.apply_txt2img()
def action_img2img(self):
self.status_changed.emit(STATE_WAIT)
self.update_selection()
if not self.doc:
return
self.adjust_selection()
self.apply_img2img(False)
def action_sd_upscale(self):
assert False, "disabled"
self.status_changed.emit(STATE_WAIT)
self.update_selection()
self.apply_img2img(mode=2)
def action_inpaint(self):
self.status_changed.emit(STATE_WAIT)
self.update_selection()
if not self.doc:
return
self.adjust_selection()
self.apply_img2img(True)
def action_simple_upscale(self):
self.status_changed.emit(STATE_WAIT)
self.update_selection()
if not self.doc:
return
self.apply_simple_upscale()
def action_update_config(self):
"""Update certain config/state from the backend."""
self.client.get_config()
def action_interrupt(self):
def cb(resp=None):
self.status_changed.emit(STATE_INTERRUPT)
self.client.post_interrupt(cb)
def action_update_eta(self):
def cb(resp=None):
# print(resp)
# NOTE: progress & eta_relative is bugged upstream when there is multiple jobs
# so we use a substitute that seems to work
state = resp["state"]
cur_step = state["sampling_step"]
total_steps = state["sampling_steps"]
# doesnt take into account batch count
num_jobs = len(self.client.long_reqs) - 1
self.status_changed.emit(
f"Step {cur_step}/{total_steps} ({num_jobs} in queue)"
)
self.client.get_progress(cb)
script = Script()