add wavelet colorfix

pull/12/head
pkuliyi2015 2023-05-21 15:18:56 +00:00
parent d2dc384d18
commit 7378cce08e
4 changed files with 124 additions and 10 deletions

View File

@ -76,6 +76,11 @@ Relevant Links
- When enabling it, the script ignores your denoising strength and gives you much more detailed images, but also changes the color & sharpness significantly
- When disabling it, the script starts by adding some noise to your image. The result will be not fully detailed, even if you set denoising strength = 1 (but maybe aesthetically good). See [Comparison](https://imgsli.com/MTgwMTMx).
- If you disable Pure Noise, we recommend denoising strength=1
- What is "Color Fix"?
- This is to mitigate the color shift problem from StableSR and the tiling process.
- AdaIN simply adjusts the color statistics between the original and the outcome images. This is the official algorithm but ineffective in many cases.
- Wavelet decomposes the original and the outcome images into low and high frequency, and then replace the outcome image's low-frequency part (colors) with the original image's. This is very powerful for uneven color shifting. The algorithm is from GIMP and Krita, which will take several seconds for each image.
- When enabling color fix, the original image will also show up in your preview window, but will NOT be saved automatically.
### 6. Important Notice

View File

@ -76,6 +76,11 @@ Licensed under S-Lab License 1.0
- 启用这个选项时,脚本会忽略你的重绘幅度设置。产出将会是更详细的图像,但也会显著改变颜色和锐度。
- 禁用这个选项时脚本会开始添加一些噪声到你的图像。即使你将去噪强度设为1结果也不会那么的细节但可能更和谐好看。参见 [对比图](https://imgsli.com/MTgwMTMx)。
- 如果禁用Pure Noise推荐重绘幅度设置为1
- 什么是"颜色修正"
- 这是为了缓解来自StableSR和Tile处理过程中的颜色偏移问题。
- AdaIN简单地匹配原图和结果图的颜色统计信息。这是StableSR官方算法但常常效果不佳。
- Wavelet将原图和结果图分解为低频和高频然后用原图的低频信息颜色替换掉结果图的低频信息。该算法对于不均匀的颜色偏移非常强力。算法来自GIMP和Krita对每张图像需要几秒钟的时间。
- 启用颜色修正时,原图也会出现在您的预览窗口中,但不会被自动保存。
### 6. 重要问题
@ -86,7 +91,7 @@ Licensed under S-Lab License 1.0
- 如果你安装了可选的 VQVAE整个模型权重将与融合权重为 0 的官方模型相同。
- 但是,你的结果将**不如**官方结果,因为:
- 采样器差异:
-官方仓库进行 100 或 200 步的 legacy DDPM 采样,并使用自定义的时间步调度器,采样时不使用负提示。
- 官方仓库进行 100 或 200 步的 legacy DDPM 采样,并使用自定义的时间步调度器,采样时不使用负提示。
- 然而WebUI 不提供这样的采样器,必须带有负提示进行采样。**这是主要的差异。**
- VQVAE 解码器差异:
- 官方 VQVAE 解码器将一些编码器特征作为输入。

View File

@ -46,13 +46,14 @@ from pathlib import Path
from torch import Tensor
from tqdm import tqdm
from modules import scripts, processing, sd_samplers, devices
from modules import scripts, processing, sd_samplers, devices, images
from modules.processing import StableDiffusionProcessingImg2Img, Processed
from modules.shared import opts
from ldm.modules.diffusionmodules.openaimodel import UNetModel
from srmodule.spade import SPADELayers
from srmodule.struct_cond import EncoderUNetModelWT, build_unetwt
from srmodule.colorfix import fix_color
from srmodule.colorfix import adain_color_fix, wavelet_color_fix
SD_WEBUI_PATH = Path.cwd()
ME_PATH = SD_WEBUI_PATH / 'extensions' / 'sd-webui-stablesr'
@ -150,12 +151,14 @@ class Script(scripts.Script):
with gr.Row():
scale_factor = gr.Slider(minimum=1, maximum=16, step=0.1, value=2, label='Scale Factor', elem_id=f'StableSR-scale')
with gr.Row():
color_fix = gr.Dropdown(['None', 'Wavelet', 'AdaIN'], label="Color Fix", value='Wavelet', elem_id=f'StableSR-color-fix')
save_original = gr.Checkbox(label='Save Original', value=False, elem_id=f'StableSR-save-original', visible=color_fix.value != 'None')
color_fix.change(fn=lambda selected: gr.Checkbox.update(visible=selected != 'None'))
pure_noise = gr.Checkbox(label='Pure Noise', value=True, elem_id=f'StableSR-pure-noise')
color_fix = gr.Checkbox(label='Color Fix', value=True, elem_id=f'StableSR-color-fix')
return [model, scale_factor, pure_noise, color_fix]
return [model, scale_factor, pure_noise, color_fix, save_original]
def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:bool):
def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:str, save_original:bool) -> Processed:
if model == 'None':
# do clean up
@ -169,6 +172,10 @@ class Script(scripts.Script):
if not os.path.exists(self.model_list[model]):
raise gr.Error(f"Model {model} is not on your disk! Please refresh the model list!")
if color_fix not in ['None', 'Wavelet', 'AdaIN']:
print(f'[StableSR] Invalid color fix method: {color_fix}')
color_fix = 'None'
# upscale the image, set the ouput size
init_img: Image = p.init_images[0]
target_width = int(init_img.width * scale_factor)
@ -222,11 +229,40 @@ class Script(scripts.Script):
# Hook the unet, and unhook after processing.
try:
self.stablesr_model.hook(unet)
if color_fix != 'None':
p.do_not_save_samples = True
result: Processed = processing.process_images(p)
if color_fix:
if color_fix != 'None':
fixed_images = []
# fix the color
color_fix_func = wavelet_color_fix if color_fix == 'Wavelet' else adain_color_fix
for i in range(len(result.images)):
result.images[i] = fix_color(result.images[i], init_img)
try:
fixed_images.append(color_fix_func(result.images[i], init_img))
except Exception as e:
print(f'[StableSR] Error fixing color with default method: {e}')
# save the fixed color images
for i in range(len(fixed_images)):
try:
images.save_image(fixed_images[i], p.outpath_samples, "", result.seed, result.prompt, opts.samples_format, info=result.infotexts, p=p)
except Exception as e:
print(f'[StableSR] Error saving color fixed image: {e}')
if save_original:
for i in range(len(result.images)):
try:
images.save_image(result.images[i], p.outpath_samples, "", result.seed, result.prompt, opts.samples_format, info=result.infotexts, p=p, suffix="-before-color-fix")
except Exception as e:
print(f'[StableSR] Error saving original image: {e}')
result.images = result.images + fixed_images
return result
finally:
self.stablesr_model.unhook(unet)

View File

@ -1,9 +1,11 @@
import torch
from PIL import Image
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import ToTensor, ToPILImage
def fix_color(target: Image, source: Image):
def adain_color_fix(target: Image, source: Image):
# Convert images to tensors
to_tensor = ToTensor()
target_tensor = to_tensor(target).unsqueeze(0)
@ -18,6 +20,21 @@ def fix_color(target: Image, source: Image):
return result_image
def wavelet_color_fix(target: Image, source: Image):
# Convert images to tensors
to_tensor = ToTensor()
target_tensor = to_tensor(target).unsqueeze(0)
source_tensor = to_tensor(source).unsqueeze(0)
# Apply wavelet reconstruction
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
# Convert tensor back to image
to_image = ToPILImage()
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
return result_image
def calc_mean_std(feat: Tensor, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
@ -45,4 +62,55 @@ def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
def wavelet_blur(image: Tensor, radius: int):
"""
Apply wavelet blur to the input tensor.
"""
# input shape: (1, 3, H, W)
# convolution kernel
# input shape: (1, 3, H, W)
# convolution kernel
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
# add channel dimensions to the kernel to make it a 4D tensor
kernel = kernel[None, None]
# repeat the kernel across all input channels
kernel = kernel.repeat(3, 1, 1, 1)
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
# apply convolution
output = F.conv2d(image, kernel, groups=3, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels=5):
"""
Apply wavelet decomposition to the input tensor.
This function only returns the low frequency & the high frequency.
"""
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq += (image - low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
"""
Apply wavelet decomposition, so that the content will have the same color as the style.
"""
# calculate the wavelet decomposition of the content feature
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq
# calculate the wavelet decomposition of the style feature
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq
# reconstruct the content feature with the style's high frequency
return content_high_freq + style_low_freq