456 lines
16 KiB
Python
456 lines
16 KiB
Python
import itertools
|
|
import os
|
|
import time
|
|
from typing import Union
|
|
|
|
from krita import (
|
|
Document,
|
|
Krita,
|
|
Node,
|
|
QImage,
|
|
QObject,
|
|
Qt,
|
|
QTimer,
|
|
Selection,
|
|
pyqtSignal,
|
|
)
|
|
|
|
from .client import Client
|
|
from .config import Config
|
|
from .defaults import (
|
|
ADD_MASK_TIMEOUT,
|
|
ERR_NO_DOCUMENT,
|
|
ETA_REFRESH_INTERVAL,
|
|
EXT_CFG_NAME,
|
|
STATE_INTERRUPT,
|
|
STATE_RESET_DEFAULT,
|
|
STATE_WAIT,
|
|
)
|
|
from .utils import (
|
|
b64_to_img,
|
|
find_optimal_selection_region,
|
|
get_desc_from_resp,
|
|
img_to_ba,
|
|
save_img,
|
|
)
|
|
|
|
|
|
# Does it actually have to be a QObject?
|
|
# The only possible use I see is for event emitting
|
|
class Script(QObject):
    """Shared plugin state: config stores, API client and cached Krita selection data.

    Exposes two Qt signals: `status_changed` (carries a status string) and
    `config_updated` (no payload).
    """

    cfg: Config
    """config singleton"""
    ext_cfg: Config
    """Separate config store used for webUI scripts (aka extensions)"""
    client: Client
    """API client singleton"""
    eta_timer: QTimer
    """Timer that periodically triggers the ETA/progress refresh"""
    status: str
    """Current status (shown in status bar)"""
    app: Krita
    """Krita's Application instance (KDE Application)"""
    doc: Document
    """Currently opened document if any"""
    node: Node
    """Currently selected layer in Krita"""
    selection: Selection
    """Selection region in Krita"""
    x: int
    """Left position of selection"""
    y: int
    """Top position of selection"""
    width: int
    """Width of selection"""
    height: int
    """Height of selection"""

    status_changed = pyqtSignal(str)
    config_updated = pyqtSignal()
|
|
|
|
def __init__(self):
    """Create the config stores, the API client and the ETA refresh timer."""
    super().__init__()

    # Persistent settings (should reload between Krita sessions)
    self.cfg = Config()
    # Used for webUI scripts aka extensions, not to be confused with their extensions
    self.ext_cfg = Config(name=EXT_CFG_NAME, model=None)

    # Forward the client's signals through this object's own signals.
    api_client = Client(self.cfg, self.ext_cfg)
    self.client = api_client
    api_client.status.connect(self.status_changed.emit)
    api_client.config_updated.connect(self.config_updated.emit)

    def refresh_eta():
        # Looked up on each tick, exactly like the original lambda.
        self.action_update_eta()

    timer = QTimer()
    self.eta_timer = timer
    timer.setInterval(ETA_REFRESH_INTERVAL)
    timer.timeout.connect(refresh_eta)
|
|
|
|
def restore_defaults(self, if_empty=False):
    """Restore the default config and clear the extension config store.

    The negation of `if_empty` is handed to `self.cfg.restore_defaults()`.
    STATE_RESET_DEFAULT is only broadcast when `if_empty` is false.
    """
    overwrite = not if_empty
    self.cfg.restore_defaults(overwrite)
    # NOTE(review): remove("") presumably clears every stored key — confirm
    # against the settings object behind `ext_cfg.config`.
    self.ext_cfg.config.remove("")

    if overwrite:
        self.status_changed.emit(STATE_RESET_DEFAULT)
|
|
|
|
def update_selection(self):
    """Update references to key Krita objects as well as selection information.

    Refreshes `self.app` and `self.doc`. When there is no active document,
    ERR_NO_DOCUMENT is emitted and nothing else is updated. Otherwise the
    active node, the selection and its bounding box are cached. A missing or
    empty selection (width or height below 1) is stored as `None`, with the
    whole canvas used as the region.

    Raises:
        AssertionError: if the document's color depth is not "U8" or its
            color model is not "RGBA".
    """
    self.app = Krita.instance()
    self.doc = self.app.activeDocument()

    # self.doc doesn't exist at app startup
    if not self.doc:
        self.status_changed.emit(ERR_NO_DOCUMENT)
        return

    self.node = self.doc.activeNode()

    selection = self.doc.selection()
    if selection is not None and (selection.width() < 1 or selection.height() < 1):
        # Treat an empty selection the same as no selection at all.
        selection = None
    self.selection = selection

    if selection is None:
        self.x = 0
        self.y = 0
        self.width = self.doc.width()
        self.height = self.doc.height()
    else:
        self.x = selection.x()
        self.y = selection.y()
        self.width = selection.width()
        self.height = selection.height()

    # Raised explicitly instead of using `assert` so the checks still run
    # when Python is started with -O; the exception type is unchanged.
    color_depth = self.doc.colorDepth()
    if color_depth != "U8":
        raise AssertionError(
            f'Only "8-bit integer/channel" supported, Document Color Depth: {color_depth}'
        )
    color_model = self.doc.colorModel()
    if color_model != "RGBA":
        raise AssertionError(
            f'Only "RGB/Alpha" supported, Document Color Model: {color_model}'
        )
|
|
|
|
def adjust_selection(self):
    """Adjust selection region to account for scaling and striding to prevent image stretch.

    Does nothing when there is no selection or when the "fix_aspect_ratio"
    config option is off. Otherwise the cached region (x, y, width, height)
    is replaced by the result of `find_optimal_selection_region()`.
    """
    if self.selection is None:
        return
    if not self.cfg("fix_aspect_ratio", bool):
        return

    region = find_optimal_selection_region(
        self.cfg("sd_base_size", int),
        self.cfg("sd_max_size", int),
        self.x,
        self.y,
        self.width,
        self.height,
        self.doc.width(),
        self.doc.height(),
    )
    self.x, self.y, self.width, self.height = region
|
|
|
|
def get_selection_image(self) -> QImage:
    """QImage of the document's pixels inside the cached selection region."""
    left, top = self.x, self.y
    region_w, region_h = self.width, self.height

    raw_pixels = self.doc.pixelData(left, top, region_w, region_h)
    wrapped = QImage(raw_pixels, region_w, region_h, QImage.Format_RGBA8888)
    # The raw bytes are wrapped as RGBA8888, then the red/blue channels are swapped.
    return wrapped.rgbSwapped()
|
|
|
|
def get_mask_image(self) -> Union[QImage, None]:
    """QImage of mask layer for inpainting.

    Returns:
        The active node's pixels inside the cached selection region, or
        `None` when there is no active node or its type is neither
        "paintlayer" nor "filelayer".
    """
    node = self.node
    # Guard against a missing active node before asking for its type.
    if node is None or node.type() not in {"paintlayer", "filelayer"}:
        return None

    raw_pixels = node.pixelData(self.x, self.y, self.width, self.height)
    return QImage(
        raw_pixels,
        self.width,
        self.height,
        QImage.Format_RGBA8888,
    ).rgbSwapped()
|
|
|
|
def img_inserter(self, x, y, width, height, group: str = None):
    """Return frozen image inserter to insert images as new layer.

    The region and the "has a selection" flag are captured when this method
    is called, because the selection may change before the callback runs.

    Returns:
        `insert` alone when no group layer was created, otherwise the tuple
        `(insert, group_layer)`. `insert(layer_name, enc)` decodes `enc`
        with `b64_to_img`, writes it into a new layer and returns that layer.
    """
    # Selection may change before callback, so freeze selection region
    selection_present = self.selection is not None
    glayer = None
    if group:
        glayer = self.doc.createGroupLayer(group)

    def create_layer(name: str):
        """Create new layer in document or group"""
        new_layer = self.doc.createNode(name, "paintLayer")
        root = self.doc.rootNode()
        if not glayer:
            root.addChildNode(new_layer, None)
            return new_layer
        glayer.addChildNode(new_layer, None)
        root.addChildNode(glayer, None)
        return new_layer

    # TODO: Insert images inside a group layer for better organization
    # Group layer name can contain model name, prompt, etc
    def insert(layer_name, enc):
        # The region is shared (and possibly enlarged) across calls.
        nonlocal x, y, width, height
        print(f"inserting layer {layer_name}")
        print(f"data size: {len(enc)}")

        # QImage.Format_RGB32 (4) is default format after decoding image
        # QImage.Format_RGBA8888 (17) is format used in Krita tutorial
        # both are compatible, & converting from 4 to 17 required a RGB swap
        # Likewise for 5 & 18 (their RGBA counterparts)
        image = b64_to_img(enc)
        print(
            f"image created: {image}, {image.width()}x{image.height()}, depth: {image.depth()}, format: {image.format()}"
        )

        # NOTE: Scaling is usually done by backend (although I am reconsidering this)
        # The scaling here is for SD Upscale or Upscale on a selection region rather than whole image
        # Image won't be scaled down ONLY if there is no selection; i.e. selecting whole image will scale down,
        # not selecting anything won't scale down, leading to the canvas being resized afterwards
        if selection_present and (image.width() != width or image.height() != height):
            print(f"Rescaling image to selection: {width}x{height}")
            image = image.scaled(
                width, height, transformMode=Qt.SmoothTransformation
            )

        # Resize (not scale!) canvas if image is larger (i.e. outpainting or Upscale was used)
        too_wide = image.width() > self.doc.width()
        too_tall = image.height() > self.doc.height()
        if too_wide or too_tall:
            # NOTE:
            # - user's selection will be partially ignored if image is larger than canvas
            # - it is complex to scale/resize the image such that image fits in the newly scaled selection
            # - the canvas will still be resized even if the image fits after transparency masking
            print("Image is larger than canvas! Resizing...")
            new_width, new_height = self.doc.width(), self.doc.height()
            if too_wide:
                x, width, new_width = 0, image.width(), image.width()
            if too_tall:
                y, height, new_height = 0, image.height(), image.height()
            self.doc.resizeImage(0, 0, new_width, new_height)

        ba = img_to_ba(image)
        layer = create_layer(layer_name)
        # layer.setColorSpace() doesn't permanently convert layer depth etc...

        # Don't fail silently for setPixelData(); fails if bit depth or number of channels mismatch
        size = ba.size()
        expected = layer.pixelData(x, y, width, height).size()
        assert expected == size, f"Raw data size: {size}, Expected size: {expected}"

        print(f"inserting at x: {x}, y: {y}, w: {width}, h: {height}")
        layer.setPixelData(ba, x, y, width, height)
        return layer

    return (insert, glayer) if glayer else insert
|
|
|
|
def apply_txt2img(self):
    """Request txt2img for the cached region and insert the results as layers.

    Starts the ETA timer and hands a callback to `client.post_txt2img`.
    The callback inserts one layer per output into a group layer, optionally
    hides all but the last layer, and then triggers the transparency masks.
    """
    # freeze selection region
    insert, glayer = self.img_inserter(
        self.x, self.y, self.width, self.height, group="a"
    )
    mask_trigger = self.transparency_mask_inserter()

    def cb(response):
        if len(self.client.long_reqs) == 1:  # last request
            self.eta_timer.stop()
        assert response is not None, "Backend Error, check terminal"

        outputs = response["outputs"]
        glayer_name, layer_names = get_desc_from_resp(response, "txt2img")

        layers = []
        for number, (output, name) in enumerate(zip(outputs, layer_names), start=1):
            # Fall back to a numbered name when the response gives none.
            layers.append(insert(name or f"txt2img {number}", output))

        if self.cfg("hide_layers", bool):
            for earlier in layers[:-1]:
                earlier.setVisible(False)
        glayer.setName(glayer_name)
        self.doc.refreshProjection()
        mask_trigger(layers)

    self.eta_timer.start(ETA_REFRESH_INTERVAL)
    has_selection = self.selection is not None
    self.client.post_txt2img(cb, self.width, self.height, has_selection)
|
|
|
|
def apply_img2img(self, is_inpaint):
    """Request img2img or inpainting for the cached region and insert the results.

    Args:
        is_inpaint: when true, `client.post_inpaint` is used and the active
            node is hidden before the selection image is captured (provided
            a mask image is available); otherwise `client.post_img2img`.

    When the "save_temp_images" option is on, the selection image (and the
    mask, for inpainting) are written under "sample_path".
    """
    insert, glayer = self.img_inserter(
        self.x, self.y, self.width, self.height, group="a"
    )
    mask_trigger = self.transparency_mask_inserter()
    mask_image = self.get_mask_image()

    save_temp = self.cfg("save_temp_images", bool)
    sample_dir = self.cfg("sample_path", str)
    # One timestamp for both files, so an image and its mask always share
    # the same prefix even if the clock ticks over between the two names.
    stamp = int(time.time())
    path = os.path.join(sample_dir, f"{stamp}.png")
    mask_path = os.path.join(sample_dir, f"{stamp}_mask.png")

    if is_inpaint and mask_image is not None:
        if save_temp:
            save_img(mask_image, mask_path)
        # auto-hide mask layer before getting selection image
        self.node.setVisible(False)
        self.doc.refreshProjection()

    sel_image = self.get_selection_image()
    if save_temp:
        save_img(sel_image, path)

    layer_name_prefix = "inpaint" if is_inpaint else "img2img"

    def cb(response):
        if len(self.client.long_reqs) == 1:  # last request
            self.eta_timer.stop()
        assert response is not None, "Backend Error, check terminal"

        outputs = response["outputs"]
        glayer_name, layer_names = get_desc_from_resp(response, layer_name_prefix)
        layers = [
            insert(name if name else f"{layer_name_prefix} {i + 1}", output)
            for i, (output, name) in enumerate(zip(outputs, layer_names))
        ]
        if self.cfg("hide_layers", bool):
            for layer in layers[:-1]:
                layer.setVisible(False)
        glayer.setName(glayer_name)
        self.doc.refreshProjection()
        # don't need transparency mask for inpaint mode
        if not is_inpaint:
            mask_trigger(layers)

    method = self.client.post_inpaint if is_inpaint else self.client.post_img2img
    self.eta_timer.start()
    method(
        cb,
        sel_image,
        mask_image,  # is unused by backend in img2img mode
        self.selection is not None,
    )
|
|
|
|
def apply_simple_upscale(self):
    """Send the selection image to `client.post_upscale` and insert the result.

    The returned image is inserted as a layer named "upscale". When the
    "save_temp_images" option is on, the selection image is first written
    under "sample_path".
    """
    insert = self.img_inserter(self.x, self.y, self.width, self.height)
    sel_image = self.get_selection_image()

    sample_dir = self.cfg("sample_path", str)
    path = os.path.join(sample_dir, f"{int(time.time())}.png")
    if self.cfg("save_temp_images", bool):
        save_img(sel_image, path)

    def on_upscaled(response):
        assert response is not None, "Backend Error, check terminal"
        insert("upscale", response["output"])
        self.doc.refreshProjection()

    self.client.post_upscale(on_upscaled, sel_image)
|
|
|
|
def transparency_mask_inserter(self):
    """Mask out extra regions due to adjust_selection().

    Captures a duplicate of the current selection and returns a trigger
    function. Calling the trigger with a list of layers schedules, one layer
    at a time, Krita's "add_new_transparency_mask" action with the captured
    selection applied. Depending on the "create_mask_layer" option the mask
    is either kept (layer collapsed) or flattened into the layer.
    """
    orig_selection = None
    if self.selection:
        orig_selection = self.selection.duplicate()
    create_mask = self.cfg("create_mask_layer", bool)

    add_mask_action = self.app.action("add_new_transparency_mask")
    merge_mask_action = self.app.action("flatten_layer")

    # This function is recursive to workaround race conditions when calling Krita's actions
    def add_mask(layers: list, cur_selection):
        if not layers:
            self.doc.setSelection(cur_selection)  # reset to current selection
            return

        layer = layers.pop()
        was_visible = layer.visible()
        prev_name = layer.name()

        def process_remaining():
            add_mask(layers, cur_selection)

        def restore():
            # assume newly flattened layer is active
            flattened = self.doc.activeNode()
            flattened.setVisible(was_visible)
            flattened.setName(prev_name)

            process_remaining()

        layer.setVisible(True)
        self.doc.setActiveNode(layer)
        self.doc.setSelection(orig_selection)
        add_mask_action.trigger()

        if not create_mask:
            # flatten transparency mask into layer
            merge_mask_action.trigger()
            QTimer.singleShot(ADD_MASK_TIMEOUT, restore)
            return

        # collapse transparency mask by default
        layer.setCollapsed(True)
        layer.setVisible(was_visible)
        QTimer.singleShot(ADD_MASK_TIMEOUT, process_remaining)

    def trigger_mask_adding(layers: list):
        pending = list(reversed(layers))  # causes final active layer to be the top one

        def handle_mask():
            cur_selection = None
            if self.selection:
                cur_selection = self.selection.duplicate()
            add_mask(pending, cur_selection)

        QTimer.singleShot(ADD_MASK_TIMEOUT, handle_mask)

    return trigger_mask_adding
|
|
|
|
# Actions
|
|
def action_txt2img(self):
    """Run txt2img: emit STATE_WAIT, refresh the selection state, then dispatch.

    Stops after the refresh when no document is available.
    """
    self.status_changed.emit(STATE_WAIT)
    self.update_selection()
    if self.doc:
        self.adjust_selection()
        self.apply_txt2img()
|
|
|
|
def action_img2img(self):
    """Run img2img: emit STATE_WAIT, refresh the selection state, then dispatch.

    Stops after the refresh when no document is available.
    """
    self.status_changed.emit(STATE_WAIT)
    self.update_selection()
    if self.doc:
        self.adjust_selection()
        self.apply_img2img(False)
|
|
|
|
def action_sd_upscale(self):
    """Disabled action; calling it always fails.

    Raises:
        AssertionError: always, with the message "disabled".
    """
    # Raised explicitly rather than via `assert False`, so the guard is not
    # stripped when Python runs with -O.
    raise AssertionError("disabled")
|
|
|
|
def action_inpaint(self):
    """Run inpainting: emit STATE_WAIT, refresh the selection state, then dispatch.

    Stops after the refresh when no document is available.
    """
    self.status_changed.emit(STATE_WAIT)
    self.update_selection()
    if self.doc:
        self.adjust_selection()
        self.apply_img2img(True)
|
|
|
|
def action_simple_upscale(self):
    """Run the simple upscale: emit STATE_WAIT, refresh the selection state, then dispatch.

    Unlike the other generation actions, the selection is not adjusted first.
    Stops after the refresh when no document is available.
    """
    self.status_changed.emit(STATE_WAIT)
    self.update_selection()
    if self.doc:
        self.apply_simple_upscale()
|
|
|
|
def action_update_config(self):
    """Update certain config/state from the backend.

    Delegates to `client.get_config()`; nothing is returned.
    """
    api_client = self.client
    api_client.get_config()
|
|
|
|
def action_interrupt(self):
    """Ask the backend to interrupt; STATE_INTERRUPT is emitted from the callback."""

    def on_interrupted(resp=None):
        # The response payload is not inspected.
        self.status_changed.emit(STATE_INTERRUPT)

    self.client.post_interrupt(on_interrupted)
|
|
|
|
def action_update_eta(self):
    """Poll the backend for sampling progress and publish it via `status_changed`.

    The emitted text has the form "Step <cur>/<total> (<n> in queue)".
    A missing response (None/empty) or one without a "state" entry is
    ignored, so a failed poll does not raise inside the callback.
    """

    def cb(resp=None):
        # NOTE: progress & eta_relative is bugged upstream when there is multiple jobs
        # so we use a substitute that seems to work
        if not resp or "state" not in resp:
            return

        state = resp["state"]
        cur_step = state["sampling_step"]
        total_steps = state["sampling_steps"]
        # Doesn't take into account batch count. One request is excluded
        # from the count (presumably the one currently running); clamped so
        # an empty request list never shows a negative queue size.
        num_jobs = max(len(self.client.long_reqs) - 1, 0)

        self.status_changed.emit(
            f"Step {cur_step}/{total_steps} ({num_jobs} in queue)"
        )

    self.client.get_progress(cb)
|
|
|
|
|
|
# Module-level Script instance, created when this module is imported.
script = Script()
|