# Source: unprompted/lib_unprompted/helpers.py (312 lines, 8.0 KiB, Python)

# Contains helper methods that do not rely on Unprompted class data
# nor are they exclusively useful for said class
# Maps a human-readable resampling filter name to its integer filter code.
# NOTE(review): the codes presumably mirror Pillow's resampling constants -- confirm.
pil_resampling_dict = {
    "Nearest Neighbor": 0,
    "Box": 4,
    "Bilinear": 2,
    "Hamming": 5,
    "Bicubic": 3,
    "Lanczos": 1,
}
def strip_str(string, chop):
    """Strip every leading and trailing repetition of `chop` from `string`.

    A falsy `chop` (empty string or None) leaves `string` untouched.
    """
    if not chop:
        return string
    width = len(chop)
    # Trailing copies are peeled off first, then leading ones.
    while string.endswith(chop):
        string = string[:-width]
    while string.startswith(chop):
        string = string[width:]
    return string
def sigmoid(x):
    """Return the logistic sigmoid 1 / (1 + e^-x) of `x`.

    The two branches are algebraically identical; splitting on the sign keeps
    the argument of `math.exp` non-positive so it can never overflow.
    """
    import math
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    # For negative x, exp(-x) would overflow for large magnitudes, so use
    # e^x / (1 + e^x) instead (e^x merely underflows to 0.0).
    exp_x = math.exp(x)
    return exp_x / (1 + exp_x)
def is_equal(var_a, var_b):
    """Return True when both values have the same string form.

    Any value that `is_float` accepts is converted to `float` first, so
    numerically equal inputs such as "1" and 1.0 compare as equal.
    """
    left = float(var_a) if is_float(var_a) else var_a
    right = float(var_b) if is_float(var_b) else var_b
    return str(left) == str(right)
def is_not_equal(var_a, var_b):
    """Return the logical negation of `is_equal(var_a, var_b)`."""
    if is_equal(var_a, var_b):
        return False
    return True
def is_float(value):
    """Return True if `float(value)` succeeds, False otherwise.

    Only conversion failures are treated as "not a float": TypeError for
    unsupported types, ValueError for unparsable text and OverflowError for
    integers too large for a float. Anything else (for example
    KeyboardInterrupt) propagates instead of being silently swallowed.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
def is_int(value):
    """Return True if `int(value)` succeeds, False otherwise.

    Note that this accepts anything `int()` can convert, so a float such as
    3.7 passes while the string "3.7" does not. Only conversion failures
    (TypeError, ValueError, OverflowError) count as "not an int"; other
    exceptions such as KeyboardInterrupt propagate.
    """
    try:
        int(value)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
def ensure(var, datatype):
    """Return `var` as an instance of `datatype`.

    Values that already match are returned unchanged. When a list is
    requested the value is wrapped as a single element (not iterated);
    for every other type the value is passed to the `datatype` constructor.
    """
    if isinstance(var, datatype):
        return var
    return [var] if datatype == list else datatype(var)
def autocast(var):
    """Convert `var` to int or float when its formatting allows, else return it as-is.

    Rules:
      * "inf" / "-inf" are returned unchanged.
      * Values `float()` cannot parse fall back to `int()`, and are returned
        unchanged if that fails too.
      * Non-finite results ("nan", "infinity", "1e999", ...) are returned
        unchanged rather than crashing in `int()`, unless the original parses
        exactly as an integer (a digit string too large for a float).
      * Integer-valued numbers written without a "." become `int`; the original
        text is parsed directly where possible so large integers keep full
        precision instead of being rounded through a float.
      * Everything else becomes `float`.
    """
    import math

    conversion_errors = (TypeError, ValueError, OverflowError)
    original_var = var
    if original_var == "inf" or original_var == "-inf":
        return original_var

    try:
        number = float(var)
    except conversion_errors:
        # Not float-like; some values (e.g. ints too big for a float) still
        # convert with int().
        try:
            return int(var)
        except conversion_errors:
            return var

    if not math.isfinite(number):
        # int(nan) / int(inf) would raise, so only an exact integer parse of
        # the original is acceptable here.
        try:
            return int(original_var)
        except conversion_errors:
            return original_var

    if number == int(number) and "." not in str(original_var):
        try:
            # Exact for digit strings of any length.
            return int(original_var)
        except conversion_errors:
            # e.g. "1e3": integer-valued but not int()-parsable text.
            return int(number)
    return number
def pil_to_cv2(img):
    """Return `img` as a numpy array with its channels reordered from RGB to BGR.

    NOTE(review): assumes `img` is an RGB image (e.g. a PIL image) -- confirm with callers.
    """
    import cv2
    import numpy
    rgb_pixels = numpy.array(img)
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
def cv2_to_pil(img):
    """Return a PIL image built from `img` after reordering its channels from BGR to RGB.

    NOTE(review): assumes `img` is a BGR array as produced by OpenCV -- confirm with callers.
    """
    import cv2
    from PIL import Image
    rgb_pixels = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_pixels)
def tensor_to_pil(tensor):
    """Convert a tensor to a PIL image.

    A 4-dimensional tensor is treated as a batch and only its first entry is
    converted. Values are scaled by 255 and stored as uint8.

    NOTE(review): no (C, H, W) -> (H, W, C) transpose is applied, so the
    tensor is assumed to already be channel-last with values in [0, 1] --
    confirm with callers.
    """
    import numpy as np
    from PIL import Image

    # detach() lets tensors that track gradients be converted (numpy() raises
    # on them otherwise); cpu() is a no-op for tensors already on the CPU.
    tensor = tensor.detach().cpu()

    # Remove the batch dimension if it exists
    if tensor.ndim == 4:
        tensor = tensor[0]

    array = tensor.numpy()

    # Clip before the cast: out-of-range values would otherwise wrap around
    # when converted to uint8 instead of saturating.
    array = np.clip(array * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(array)
def pil_to_tensor(image):
    """Convert an image to a float32 tensor scaled by 1/255 with a leading batch dimension.

    `image` may be anything `numpy.array` accepts (a PIL image, an array, a
    nested list). The layout of the input is kept as-is: no (H, W, C) ->
    (C, H, W) transpose is applied, only a batch axis is prepended.
    """
    import numpy as np
    import torch

    # Scale 8-bit style values into [0, 1].
    pixels = np.array(image).astype(np.float32) / 255.0
    # unsqueeze(0) adds the batch dimension.
    return torch.from_numpy(pixels).unsqueeze(0)
def str_to_rgb(color_string):
    """Convert a color string to a tuple of integer channel values.

    Supported forms:
      * comma-separated decimals, e.g. "255,0,0"
      * hex with a leading "#", e.g. "#ff0000"; the 3/4-digit shorthand
        ("#f00") is expanded by doubling each digit

    Returns None for an empty or unrecognized string.
    """
    if not color_string:
        return None
    if color_string[0].isdigit():
        return tuple(map(int, color_string.split(',')))
    if color_string.startswith("#"):
        hex_digits = color_string[1:]
        if len(hex_digits) in (3, 4):
            hex_digits = "".join(digit * 2 for digit in hex_digits)
        # tuple(...) so both branches return the same type; a bare bytes
        # object does not compare equal to the equivalent tuple.
        return tuple(bytes.fromhex(hex_digits))
    return None
def str_to_pil(string):
    """Resolve `string` to a PIL image.

    Two input forms are handled:
      * the repr of a live PIL image ("<PIL.Image.Image ... at 0x...>"),
        from which the object is recovered through its memory address;
      * a file path or glob pattern, from which one matching file is picked
        at random and opened.

    Returns the image on success, "" when the pattern matches no files, and
    False on any other failure (the failure is logged).
    """
    from PIL import Image
    log = get_logger()

    if isinstance(string, str) and string.startswith("<PIL.Image.Image"):
        # Get the PIL object from the memory address
        # <PIL.Image.Image image mode=RGBA size=1024x1024 at 0x...>
        try:
            import ctypes
            import re
            # Extract the memory address from the string
            address = re.search(r"at (0x[0-9A-F]+)", string).group(1)
            # Convert the memory address to an integer
            address = int(address, 16)
            # WARNING(review): dereferencing an arbitrary address is unsafe.
            # If the object has been freed, or the string is not trustworthy,
            # this can crash the interpreter before the isinstance check runs.
            img = ctypes.cast(address, ctypes.py_object).value
            # Validate the object
            if not isinstance(img, Image.Image):
                log.error(f"Failed to extract PIL image from memory address: {address}, {string}, {type(img)}")
                return False
            log.debug(f"Successfully extracted PIL image from memory address: {img}")
            return img
        except Exception:
            log.exception(f"Failed to extract PIL image from memory address: {string}")
            return False

    try:
        import glob, random, os
        files = glob.glob(string)
        if not files:
            log.error(f"No files found at this location: {string}")
            return ""
        file = random.choice(files)
        log.debug(f"Loading file: {file}")
        if not os.path.exists(file):
            log.error(f"File does not exist: {file}")
            return False
        # Open the file that was actually chosen; `string` may be a glob
        # pattern, which is not itself an openable path.
        return Image.open(file)
    except Exception:
        log.exception(f"Failed to open image: {string}")
        return False
def get_logger(logger=None):
    """Return `logger` if one was supplied, otherwise the "Unprompted" stdlib logger.

    Should obtaining the stdlib logger fail, `print` is returned as a last
    resort.
    NOTE(review): `print` has no `.info()` / `.error()` methods, so callers
    using logger methods would fail on that fallback -- confirm it is ever
    reached.
    """
    if not logger:
        try:
            import logging
            logger = logging.getLogger("Unprompted")
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt / SystemExit
            # are not swallowed.
            logger = print
    return logger
def download_file(filename, url, logger=None, overwrite=False, headers=None):
    """Download `url` into `filename`, creating parent directories as needed.

    Args:
        filename: Destination path.
        url: Address to fetch.
        logger: Optional logger; resolved through `get_logger`.
        overwrite: When False, an existing `filename` is kept and no request
            is made.
        headers: Optional HTTP headers passed to `requests.get`.

    Returns:
        True if the file already existed (and `overwrite` is False) or was
        downloaded successfully; False on a non-200 response or a failure
        while writing the file.
    """
    import os, requests
    log = get_logger(logger)

    if not overwrite and os.path.exists(filename):
        return True

    # Make sure directory structure exists
    os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
    log.info(f"Downloading file into: {filename}...")

    # The context manager releases the streamed connection on every exit
    # path, including the early return for a bad status code.
    with requests.get(url, stream=True, headers=headers) as response:
        if response.status_code != 200:
            log.error(f"Error when trying to download `{url}` to `{filename}`. Status code received: {response.status_code}")
            return False
        try:
            with open(filename, 'wb') as fout:
                for block in response.iter_content(4096):
                    fout.write(block)
        except Exception:
            log.exception(f"Error when writing download to `{filename}`.")
            # Remove the partial file; otherwise the next call with
            # overwrite=False would see it and report success.
            try:
                os.remove(filename)
            except OSError:
                pass  # best-effort cleanup; the failure is already logged
            return False
    return True
def import_file(full_name, path):
    """Load the Python source file at `path` as a module named `full_name` and return it."""
    import importlib.util
    module_spec = importlib.util.spec_from_file_location(full_name, path)
    module = importlib.util.module_from_spec(module_spec)
    # Run the file's top-level code inside the new module object.
    module_spec.loader.exec_module(module)
    return module
def list_set(this_list, index, value, null_value=False):
    """Assign `value` at `index` in `this_list`, growing the list in place if needed.

    Any positions created by the growth are filled with `null_value`.
    """
    shortfall = index + 1 - len(this_list)
    if shortfall > 0:
        this_list.extend([null_value] * shortfall)
    this_list[index] = value
def str_with_ext(path, default_ext=".json"):
    """Return `path`, appending `default_ext` when it appears to be missing.

    The path is left untouched if it exists on disk or already contains
    `default_ext` anywhere in it.
    """
    import os
    needs_ext = not os.path.exists(path) and default_ext not in path
    return path + default_ext if needs_ext else path
def create_load_json(file_path, default_data=None, encoding="utf8"):
    """Load JSON from `file_path`, creating the file from `default_data` if it is missing.

    Args:
        file_path: Path of the JSON file.
        default_data: Content written (and returned) when the file does not
            exist. Defaults to a fresh empty dict on every call, so callers
            mutating the result never affect later calls.
        encoding: Text encoding used for reading and writing.

    Returns:
        The parsed file content, or `default_data` when the file had to be
        created.
    """
    import json

    if default_data is None:
        default_data = {}

    try:
        # If the file already exists, load its content
        with open(file_path, "r", encoding=encoding) as file:
            return json.load(file)
    except FileNotFoundError:
        # If the file doesn't exist, create it with default data
        with open(file_path, "w", encoding=encoding) as file:
            json.dump(default_data, file, indent=4)
        return default_data
def unsharp_mask(image, amount=1.0, kernel_size=(5, 5), sigma=1.0, threshold=0):
    """Return a sharpened version of the image, using an unsharp mask.

    Args:
        image: Input image; converted to a uint8 numpy array.
        amount: Sharpening strength; the result is
            (amount + 1) * image - amount * blurred.
        kernel_size: Gaussian kernel size passed to `cv2.GaussianBlur`.
        sigma: Gaussian sigma passed to `cv2.GaussianBlur`.
        threshold: When > 0, pixels whose absolute difference from the
            blurred image is below this value keep their original value.

    Returns:
        The sharpened image as a PIL image.
    """
    import numpy, cv2
    from PIL import Image

    image = numpy.array(image).astype(numpy.uint8)
    blurred = cv2.GaussianBlur(image, kernel_size, sigma)

    # Multiplying by a Python float promotes to floating point, so this
    # arithmetic cannot wrap; clamp to the 8-bit range before casting back.
    sharpened = float(amount + 1) * image - float(amount) * blurred
    sharpened = numpy.clip(sharpened, 0, 255).round().astype(numpy.uint8)

    if threshold > 0:
        # Subtract in a signed type: uint8 - uint8 wraps around for negative
        # differences, which would mark the wrong pixels as low-contrast.
        difference = image.astype(numpy.int16) - blurred.astype(numpy.int16)
        low_contrast_mask = numpy.absolute(difference) < threshold
        numpy.copyto(sharpened, image, where=low_contrast_mask)
    return Image.fromarray(sharpened)
# Helper class that exposes dictionary entries through attribute notation.
# Many libraries expect their options as an argparse-style object,
# which is awkward to produce inside of an A1111 extension.
class AttrDict(dict):
    """A dict whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Using the dict itself as the instance namespace makes `obj.key`
        # and `obj["key"]` refer to the same storage.
        self.__dict__ = self