import math
from collections import namedtuple
from enum import Enum
from types import MethodType

import cv2
import numpy as np
import torch
from tqdm import tqdm

from modules import devices, shared, prompt_parser, extra_networks
from modules import sd_samplers_common
from modules.shared import state
from modules.processing import opt_f

from tile_utils.typing import *
state: State
|
|
|
|
|
|
class ComparableEnum(Enum):
|
|
|
|
def __eq__(self, other: Any) -> bool:
|
|
if isinstance(other, str): return self.value == other
|
|
elif isinstance(other, ComparableEnum): return self.value == other.value
|
|
else: raise TypeError(f'unsupported type: {type(other)}')
|
|
|
|
class Method(ComparableEnum):
|
|
|
|
MULTI_DIFF = 'MultiDiffusion'
|
|
MIX_DIFF = 'Mixture of Diffusers'
|
|
|
|
class Method_2(ComparableEnum):
|
|
DEMO_FU = "DemoFusion"
|
|
|
|
class BlendMode(Enum): # i.e. LayerType
|
|
|
|
FOREGROUND = 'Foreground'
|
|
BACKGROUND = 'Background'
|
|
|
|
BBoxSettings = namedtuple('BBoxSettings', ['enable', 'x', 'y', 'w', 'h', 'prompt', 'neg_prompt', 'blend_mode', 'feather_ratio', 'seed'])
|
|
NoiseInverseCache = namedtuple('NoiseInversionCache', ['model_hash', 'x0', 'xt', 'noise_inversion_steps', 'retouch', 'prompts'])
|
|
DEFAULT_BBOX_SETTINGS = BBoxSettings(False, 0.4, 0.4, 0.2, 0.2, '', '', BlendMode.BACKGROUND.value, 0.2, -1)
|
|
NUM_BBOX_PARAMS = len(BBoxSettings._fields)
|
|
|
|
|
|
def build_bbox_settings(bbox_control_states:List[Any]) -> Dict[int, BBoxSettings]:
|
|
settings = {}
|
|
for index, i in enumerate(range(0, len(bbox_control_states), NUM_BBOX_PARAMS)):
|
|
setting = BBoxSettings(*bbox_control_states[i:i+NUM_BBOX_PARAMS])
|
|
# for float x, y, w, h, feather_ratio, keeps 4 digits
|
|
setting = setting._replace(
|
|
x=round(setting.x, 4),
|
|
y=round(setting.y, 4),
|
|
w=round(setting.w, 4),
|
|
h=round(setting.h, 4),
|
|
feather_ratio=round(setting.feather_ratio, 4),
|
|
seed=int(setting.seed),
|
|
)
|
|
# sanity check
|
|
if not setting.enable or setting.x > 1.0 or setting.y > 1.0 or setting.w <= 0.0 or setting.h <= 0.0: continue
|
|
settings[index] = setting
|
|
return settings
|
|
|
|
def gr_value(value=None, visible=None):
|
|
return {"value": value, "visible": visible, "__type__": "update"}
|
|
|
|
|
|
class BBox:
|
|
|
|
''' grid bbox '''
|
|
|
|
def __init__(self, x:int, y:int, w:int, h:int):
|
|
self.x = x
|
|
self.y = y
|
|
self.w = w
|
|
self.h = h
|
|
self.box = [x, y, x+w, y+h]
|
|
self.slicer = slice(None), slice(None), slice(y, y+h), slice(x, x+w)
|
|
|
|
def __getitem__(self, idx:int) -> int:
|
|
return self.box[idx]
|
|
|
|
class CustomBBox(BBox):
|
|
|
|
''' region control bbox '''
|
|
|
|
def __init__(self, x:int, y:int, w:int, h:int, prompt:str, neg_prompt:str, blend_mode:str, feather_radio:float, seed:int):
|
|
super().__init__(x, y, w, h)
|
|
self.prompt = prompt
|
|
self.neg_prompt = neg_prompt
|
|
self.blend_mode = BlendMode(blend_mode)
|
|
self.feather_ratio = max(min(feather_radio, 1.0), 0.0)
|
|
self.seed = seed
|
|
# initialize necessary fields
|
|
self.feather_mask = feather_mask(self.w, self.h, self.feather_ratio) if self.blend_mode == BlendMode.FOREGROUND else None
|
|
self.cond: MulticondLearnedConditioning = None
|
|
self.extra_network_data: DefaultDict[List[ExtraNetworkParams]] = None
|
|
self.uncond: List[List[ScheduledPromptConditioning]] = None
|
|
|
|
|
|
class Prompt:
|
|
|
|
''' prompts helper '''
|
|
|
|
@staticmethod
|
|
def apply_styles(prompts:List[str], styles=None) -> List[str]:
|
|
if not styles: return prompts
|
|
return [shared.prompt_styles.apply_styles_to_prompt(p, styles) for p in prompts]
|
|
|
|
@staticmethod
|
|
def append_prompt(prompts:List[str], prompt:str='') -> List[str]:
|
|
if not prompt: return prompts
|
|
return [f'{p}, {prompt}' for p in prompts]
|
|
|
|
class Condition:
|
|
|
|
''' CLIP cond helper '''
|
|
|
|
@staticmethod
|
|
def get_custom_cond(prompts:List[str], prompt, steps:int, styles=None) -> Tuple[Cond, ExtraNetworkData]:
|
|
prompt = Prompt.apply_styles([prompt], styles)[0]
|
|
_, extra_network_data = extra_networks.parse_prompts([prompt])
|
|
prompts = Prompt.append_prompt(prompts, prompt)
|
|
prompts = Prompt.apply_styles(prompts, styles)
|
|
cond = Condition.get_cond(prompts, steps)
|
|
return cond, extra_network_data
|
|
|
|
@staticmethod
|
|
def get_cond(prompts, steps:int):
|
|
prompts, _ = extra_networks.parse_prompts(prompts)
|
|
cond = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, steps)
|
|
return cond
|
|
|
|
@staticmethod
|
|
def get_uncond(neg_prompts:List[str], steps:int, styles=None) -> Uncond:
|
|
neg_prompts = Prompt.apply_styles(neg_prompts, styles)
|
|
uncond = prompt_parser.get_learned_conditioning(shared.sd_model, neg_prompts, steps)
|
|
return uncond
|
|
|
|
@staticmethod
|
|
def reconstruct_cond(cond:Cond, step:int) -> Tensor:
|
|
list_of_what, tensor = prompt_parser.reconstruct_multicond_batch(cond, step)
|
|
return tensor
|
|
|
|
def reconstruct_uncond(uncond:Uncond, step:int) -> Tensor:
|
|
tensor = prompt_parser.reconstruct_cond_batch(uncond, step)
|
|
return tensor
|
|
|
|
|
|
def splitable(w:int, h:int, tile_w:int, tile_h:int, overlap:int=16) -> bool:
|
|
w, h = w // opt_f, h // opt_f
|
|
min_tile_size = min(tile_w, tile_h)
|
|
if overlap >= min_tile_size:
|
|
overlap = min_tile_size - 4
|
|
cols = math.ceil((w - overlap) / (tile_w - overlap))
|
|
rows = math.ceil((h - overlap) / (tile_h - overlap))
|
|
return cols > 1 or rows > 1
|
|
|
|
def split_bboxes(w:int, h:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]:
|
|
cols = math.ceil((w - overlap) / (tile_w - overlap))
|
|
rows = math.ceil((h - overlap) / (tile_h - overlap))
|
|
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
|
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
|
|
|
bbox_list: List[BBox] = []
|
|
weight = torch.zeros((1, 1, h, w), device=devices.device, dtype=torch.float32)
|
|
for row in range(rows):
|
|
y = min(int(row * dy), h - tile_h)
|
|
for col in range(cols):
|
|
x = min(int(col * dx), w - tile_w)
|
|
|
|
bbox = BBox(x, y, tile_w, tile_h)
|
|
bbox_list.append(bbox)
|
|
weight[bbox.slicer] += init_weight
|
|
|
|
return bbox_list, weight
|
|
|
|
|
|
def gaussian_weights(tile_w:int, tile_h:int) -> Tensor:
|
|
'''
|
|
Copy from the original implementation of Mixture of Diffusers
|
|
https://github.com/albarji/mixture-of-diffusers/blob/master/mixdiff/tiling.py
|
|
This generates gaussian weights to smooth the noise of each tile.
|
|
This is critical for this method to work.
|
|
'''
|
|
from numpy import pi, exp, sqrt
|
|
|
|
f = lambda x, midpoint, var=0.01: exp(-(x-midpoint)*(x-midpoint) / (tile_w*tile_w) / (2*var)) / sqrt(2*pi*var)
|
|
x_probs = [f(x, (tile_w - 1) / 2) for x in range(tile_w)] # -1 because index goes from 0 to latent_width - 1
|
|
y_probs = [f(y, tile_h / 2) for y in range(tile_h)]
|
|
|
|
w = np.outer(y_probs, x_probs)
|
|
return torch.from_numpy(w).to(devices.device, dtype=torch.float32)
|
|
|
|
def feather_mask(w:int, h:int, ratio:float) -> Tensor:
|
|
'''Generate a feather mask for the bbox'''
|
|
|
|
mask = np.ones((h, w), dtype=np.float32)
|
|
feather_radius = int(min(w//2, h//2) * ratio)
|
|
# Generate the mask via gaussian weights
|
|
# adjust the weight near the edge. the closer to the edge, the lower the weight
|
|
# weight = ( dist / feather_radius) ** 2
|
|
for i in range(h//2):
|
|
for j in range(w//2):
|
|
dist = min(i, j)
|
|
if dist >= feather_radius: continue
|
|
weight = (dist / feather_radius) ** 2
|
|
mask[i, j] = weight
|
|
mask[i, w-j-1] = weight
|
|
mask[h-i-1, j] = weight
|
|
mask[h-i-1, w-j-1] = weight
|
|
|
|
return torch.from_numpy(mask).to(devices.device, dtype=torch.float32)
|
|
|
|
def get_retouch_mask(img_input: np.ndarray, kernel_size: int) -> np.ndarray:
|
|
'''
|
|
Return the area where the image is retouched.
|
|
Copy from Zhihu.com
|
|
'''
|
|
step = 1
|
|
kernel = (kernel_size, kernel_size)
|
|
|
|
img = img_input.astype(np.float32)/255.0
|
|
sz = img.shape[:2]
|
|
sz1 = (int(round(sz[1] * step)), int(round(sz[0] * step)))
|
|
sz2 = (int(round(kernel[0] * step)), int(round(kernel[0] * step)))
|
|
sI = cv2.resize(img, sz1, interpolation=cv2.INTER_LINEAR)
|
|
sp = cv2.resize(img, sz1, interpolation=cv2.INTER_LINEAR)
|
|
msI = cv2.blur(sI, sz2)
|
|
msp = cv2.blur(sp, sz2)
|
|
msII = cv2.blur(sI*sI, sz2)
|
|
msIp = cv2.blur(sI*sp, sz2)
|
|
vsI = msII - msI*msI
|
|
csIp = msIp - msI*msp
|
|
recA = csIp/(vsI+0.01)
|
|
recB = msp - recA*msI
|
|
mA = cv2.resize(recA, (sz[1],sz[0]), interpolation=cv2.INTER_LINEAR)
|
|
mB = cv2.resize(recB, (sz[1],sz[0]), interpolation=cv2.INTER_LINEAR)
|
|
|
|
gf = mA * img + mB
|
|
gf -= img
|
|
gf *= 255
|
|
gf = gf.astype(np.uint8)
|
|
gf = gf.clip(0, 255)
|
|
gf = gf.astype(np.float32)/255.0
|
|
return gf
|
|
|
|
|
|
def null_decorator(fn):
|
|
def wrapper(*args, **kwargs):
|
|
return fn(*args, **kwargs)
|
|
return wrapper
|
|
|
|
keep_signature = null_decorator
|
|
controlnet = null_decorator
|
|
stablesr = null_decorator
|
|
grid_bbox = null_decorator
|
|
custom_bbox = null_decorator
|
|
noise_inverse = null_decorator
|