add wavelet colorfix
parent
d2dc384d18
commit
7378cce08e
|
|
@ -76,6 +76,11 @@ Relevant Links
|
|||
- When enabling it, the script ignores your denoising strength and gives you much more detailed images, but also changes the color & sharpness significantly
|
||||
- When disabling it, the script starts by adding some noise to your image. The result will be not fully detailed, even if you set denoising strength = 1 (but maybe aesthetically good). See [Comparison](https://imgsli.com/MTgwMTMx).
|
||||
- If you disable Pure Noise, we recommend denoising strength=1
|
||||
- What is "Color Fix"?
|
||||
- This is to mitigate the color shift problem from StableSR and the tiling process.
|
||||
- AdaIN simply adjusts the color statistics between the original and the outcome images. This is the official algorithm but ineffective in many cases.
|
||||
- Wavelet decomposes the original and the outcome images into low and high frequency, and then replace the outcome image's low-frequency part (colors) with the original image's. This is very powerful for uneven color shifting. The algorithm is from GIMP and Krita, which will take several seconds for each image.
|
||||
- When enabling color fix, the original image will also show up in your preview window, but will NOT be saved automatically.
|
||||
|
||||
### 6. Important Notice
|
||||
|
||||
|
|
|
|||
|
|
@ -76,6 +76,11 @@ Licensed under S-Lab License 1.0
|
|||
- 启用这个选项时,脚本会忽略你的重绘幅度设置。产出将会是更详细的图像,但也会显著改变颜色和锐度。
|
||||
- 禁用这个选项时,脚本会开始添加一些噪声到你的图像。即使你将去噪强度设为1,结果也不会那么的细节(但可能更和谐好看)。参见 [对比图](https://imgsli.com/MTgwMTMx)。
|
||||
- 如果禁用Pure Noise,推荐重绘幅度设置为1
|
||||
- 什么是"颜色修正"?
|
||||
- 这是为了缓解来自StableSR和Tile处理过程中的颜色偏移问题。
|
||||
- AdaIN简单地匹配原图和结果图的颜色统计信息。这是StableSR官方算法,但常常效果不佳。
|
||||
- Wavelet将原图和结果图分解为低频和高频,然后用原图的低频信息(颜色)替换掉结果图的低频信息。该算法对于不均匀的颜色偏移非常强力。算法来自GIMP和Krita,对每张图像需要几秒钟的时间。
|
||||
- 启用颜色修正时,原图也会出现在您的预览窗口中,但不会被自动保存。
|
||||
|
||||
### 6. 重要问题
|
||||
|
||||
|
|
@ -86,7 +91,7 @@ Licensed under S-Lab License 1.0
|
|||
- 如果你安装了可选的 VQVAE,整个模型权重将与融合权重为 0 的官方模型相同。
|
||||
- 但是,你的结果将**不如**官方结果,因为:
|
||||
- 采样器差异:
|
||||
-官方仓库进行 100 或 200 步的 legacy DDPM 采样,并使用自定义的时间步调度器,采样时不使用负提示。
|
||||
- 官方仓库进行 100 或 200 步的 legacy DDPM 采样,并使用自定义的时间步调度器,采样时不使用负提示。
|
||||
- 然而,WebUI 不提供这样的采样器,必须带有负提示进行采样。**这是主要的差异。**
|
||||
- VQVAE 解码器差异:
|
||||
- 官方 VQVAE 解码器将一些编码器特征作为输入。
|
||||
|
|
|
|||
|
|
@ -46,13 +46,14 @@ from pathlib import Path
|
|||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import scripts, processing, sd_samplers, devices
|
||||
from modules import scripts, processing, sd_samplers, devices, images
|
||||
from modules.processing import StableDiffusionProcessingImg2Img, Processed
|
||||
from modules.shared import opts
|
||||
from ldm.modules.diffusionmodules.openaimodel import UNetModel
|
||||
|
||||
from srmodule.spade import SPADELayers
|
||||
from srmodule.struct_cond import EncoderUNetModelWT, build_unetwt
|
||||
from srmodule.colorfix import fix_color
|
||||
from srmodule.colorfix import adain_color_fix, wavelet_color_fix
|
||||
|
||||
SD_WEBUI_PATH = Path.cwd()
|
||||
ME_PATH = SD_WEBUI_PATH / 'extensions' / 'sd-webui-stablesr'
|
||||
|
|
@ -150,12 +151,14 @@ class Script(scripts.Script):
|
|||
with gr.Row():
|
||||
scale_factor = gr.Slider(minimum=1, maximum=16, step=0.1, value=2, label='Scale Factor', elem_id=f'StableSR-scale')
|
||||
with gr.Row():
|
||||
color_fix = gr.Dropdown(['None', 'Wavelet', 'AdaIN'], label="Color Fix", value='Wavelet', elem_id=f'StableSR-color-fix')
|
||||
save_original = gr.Checkbox(label='Save Original', value=False, elem_id=f'StableSR-save-original', visible=color_fix.value != 'None')
|
||||
color_fix.change(fn=lambda selected: gr.Checkbox.update(visible=selected != 'None'))
|
||||
pure_noise = gr.Checkbox(label='Pure Noise', value=True, elem_id=f'StableSR-pure-noise')
|
||||
color_fix = gr.Checkbox(label='Color Fix', value=True, elem_id=f'StableSR-color-fix')
|
||||
|
||||
return [model, scale_factor, pure_noise, color_fix]
|
||||
return [model, scale_factor, pure_noise, color_fix, save_original]
|
||||
|
||||
def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:bool):
|
||||
def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:str, save_original:bool) -> Processed:
|
||||
|
||||
if model == 'None':
|
||||
# do clean up
|
||||
|
|
@ -169,6 +172,10 @@ class Script(scripts.Script):
|
|||
if not os.path.exists(self.model_list[model]):
|
||||
raise gr.Error(f"Model {model} is not on your disk! Please refresh the model list!")
|
||||
|
||||
if color_fix not in ['None', 'Wavelet', 'AdaIN']:
|
||||
print(f'[StableSR] Invalid color fix method: {color_fix}')
|
||||
color_fix = 'None'
|
||||
|
||||
# upscale the image, set the ouput size
|
||||
init_img: Image = p.init_images[0]
|
||||
target_width = int(init_img.width * scale_factor)
|
||||
|
|
@ -222,11 +229,40 @@ class Script(scripts.Script):
|
|||
# Hook the unet, and unhook after processing.
|
||||
try:
|
||||
self.stablesr_model.hook(unet)
|
||||
|
||||
if color_fix != 'None':
|
||||
p.do_not_save_samples = True
|
||||
|
||||
result: Processed = processing.process_images(p)
|
||||
if color_fix:
|
||||
|
||||
if color_fix != 'None':
|
||||
|
||||
fixed_images = []
|
||||
# fix the color
|
||||
color_fix_func = wavelet_color_fix if color_fix == 'Wavelet' else adain_color_fix
|
||||
for i in range(len(result.images)):
|
||||
result.images[i] = fix_color(result.images[i], init_img)
|
||||
try:
|
||||
fixed_images.append(color_fix_func(result.images[i], init_img))
|
||||
except Exception as e:
|
||||
print(f'[StableSR] Error fixing color with default method: {e}')
|
||||
|
||||
# save the fixed color images
|
||||
for i in range(len(fixed_images)):
|
||||
try:
|
||||
images.save_image(fixed_images[i], p.outpath_samples, "", result.seed, result.prompt, opts.samples_format, info=result.infotexts, p=p)
|
||||
except Exception as e:
|
||||
print(f'[StableSR] Error saving color fixed image: {e}')
|
||||
|
||||
if save_original:
|
||||
for i in range(len(result.images)):
|
||||
try:
|
||||
images.save_image(result.images[i], p.outpath_samples, "", result.seed, result.prompt, opts.samples_format, info=result.infotexts, p=p, suffix="-before-color-fix")
|
||||
except Exception as e:
|
||||
print(f'[StableSR] Error saving original image: {e}')
|
||||
result.images = result.images + fixed_images
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
self.stablesr_model.unhook(unet)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torchvision.transforms import ToTensor, ToPILImage
|
||||
|
||||
def fix_color(target: Image, source: Image):
|
||||
def adain_color_fix(target: Image, source: Image):
|
||||
# Convert images to tensors
|
||||
to_tensor = ToTensor()
|
||||
target_tensor = to_tensor(target).unsqueeze(0)
|
||||
|
|
@ -18,6 +20,21 @@ def fix_color(target: Image, source: Image):
|
|||
|
||||
return result_image
|
||||
|
||||
def wavelet_color_fix(target: Image, source: Image):
|
||||
# Convert images to tensors
|
||||
to_tensor = ToTensor()
|
||||
target_tensor = to_tensor(target).unsqueeze(0)
|
||||
source_tensor = to_tensor(source).unsqueeze(0)
|
||||
|
||||
# Apply wavelet reconstruction
|
||||
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
|
||||
|
||||
# Convert tensor back to image
|
||||
to_image = ToPILImage()
|
||||
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
||||
|
||||
return result_image
|
||||
|
||||
def calc_mean_std(feat: Tensor, eps=1e-5):
|
||||
"""Calculate mean and std for adaptive_instance_normalization.
|
||||
Args:
|
||||
|
|
@ -45,4 +62,55 @@ def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
|
|||
style_mean, style_std = calc_mean_std(style_feat)
|
||||
content_mean, content_std = calc_mean_std(content_feat)
|
||||
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
||||
|
||||
def wavelet_blur(image: Tensor, radius: int):
|
||||
"""
|
||||
Apply wavelet blur to the input tensor.
|
||||
"""
|
||||
# input shape: (1, 3, H, W)
|
||||
# convolution kernel
|
||||
# input shape: (1, 3, H, W)
|
||||
# convolution kernel
|
||||
kernel_vals = [
|
||||
[0.0625, 0.125, 0.0625],
|
||||
[0.125, 0.25, 0.125],
|
||||
[0.0625, 0.125, 0.0625],
|
||||
]
|
||||
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
||||
# add channel dimensions to the kernel to make it a 4D tensor
|
||||
kernel = kernel[None, None]
|
||||
# repeat the kernel across all input channels
|
||||
kernel = kernel.repeat(3, 1, 1, 1)
|
||||
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
||||
# apply convolution
|
||||
output = F.conv2d(image, kernel, groups=3, dilation=radius)
|
||||
return output
|
||||
|
||||
def wavelet_decomposition(image: Tensor, levels=5):
|
||||
"""
|
||||
Apply wavelet decomposition to the input tensor.
|
||||
This function only returns the low frequency & the high frequency.
|
||||
"""
|
||||
high_freq = torch.zeros_like(image)
|
||||
for i in range(levels):
|
||||
radius = 2 ** i
|
||||
low_freq = wavelet_blur(image, radius)
|
||||
high_freq += (image - low_freq)
|
||||
image = low_freq
|
||||
|
||||
return high_freq, low_freq
|
||||
|
||||
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
|
||||
"""
|
||||
Apply wavelet decomposition, so that the content will have the same color as the style.
|
||||
"""
|
||||
# calculate the wavelet decomposition of the content feature
|
||||
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
||||
del content_low_freq
|
||||
# calculate the wavelet decomposition of the style feature
|
||||
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
||||
del style_high_freq
|
||||
# reconstruct the content feature with the style's high frequency
|
||||
return content_high_freq + style_low_freq
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue