main
wcde 2023-10-16 10:18:48 +03:00
parent e47ce01df6
commit 9700c14b60
2 changed files with 71 additions and 11 deletions

View File

@ -2,6 +2,10 @@
## Webui Extension for customizing highres fix and improve details (currently separated from original highres fix)
#### Update 16.10.23:
- added ControlNet support: choose preprocessor/model in CN settings, but don't enable unit
- added Lora support: put Lora in extension prompt to enable Lora only for upscaling, put Lora in negative prompt to disable active Lora
#### Update 02.07.23:
- code rewritten again
- simplified settings

View File

@ -3,14 +3,15 @@ from os.path import exists
from tqdm import trange
from modules import scripts, shared, processing, sd_samplers, script_callbacks, rng
from modules import devices, prompt_parser, sd_models
from modules import devices, prompt_parser, sd_models, extra_networks
import modules.images as images
import k_diffusion
import gradio as gr
import numpy as np
from PIL import Image
from PIL import Image, ImageEnhance
import torch
import importlib
def safe_import(import_name, pkg_name = None):
@ -32,6 +33,7 @@ safe_import('pathlib')
from omegaconf import DictConfig, OmegaConf
from pathlib import Path
import kornia
from skimage import exposure
config_path = Path(__file__).parent.resolve() / '../config.yaml'
@ -52,8 +54,12 @@ class CustomHiresFix(scripts.Script):
self.uncond = None
self.step = None
self.tv = None
self.width = None
self.width = None
self.height = None
self.use_cn = False
self.external_code = None
self.cn_image = None
self.cn_units = []
def title(self):
return "Custom Hires Fix"
@ -115,6 +121,11 @@ class CustomHiresFix(scripts.Script):
clip_skip = gr.Slider(minimum=0, maximum=5, step=1,
label="Clip skip for upscale (0 - not change)",
value=self.config.get('clip_skip', 0))
with gr.Row():
start_control_at = gr.Slider(minimum=0.0, maximum=0.7, step=0.01,
label="CN start for enabled units",
value=self.config.get('start_control_at', 0.0))
cn_ref = gr.Checkbox(label='Use last image for reference', value=self.config.get('cn_ref', False))
with gr.Row():
sampler = gr.Dropdown(['Restart', 'DPM++ 2M SDE', 'DPM++ 3M SDE', 'Restart + DPM++ 3M SDE'],
label='Sampler',
@ -127,20 +138,31 @@ class CustomHiresFix(scripts.Script):
width.change(fn=lambda x: gr.update(value=0), inputs=width, outputs=height)
height.change(fn=lambda x: gr.update(value=0), inputs=height, outputs=width)
ui = [enable, width, height, steps, first_upscaler, second_upscaler, first_latent, second_latent,
prompt, negative_prompt, strength, filter, filter_offset, denoise_offset, clip_skip, sampler]
ui = [enable, width, height, steps, first_upscaler, second_upscaler, first_latent, second_latent, prompt,
negative_prompt, strength, filter, filter_offset, denoise_offset, clip_skip, sampler, cn_ref, start_control_at]
for elem in ui:
setattr(elem, "do_not_save_to_config", True)
return ui
def process(self, p, *args, **kwargs):
self.p = p
self.cn_units = []
try:
self.external_code = importlib.import_module('extensions.sd-webui-controlnet.scripts.external_code', 'external_code')
cn_units = self.external_code.get_all_units_in_processing(p)
for unit in cn_units:
self.cn_units += [unit]
self.use_cn = len(self.cn_units) > 0
except ImportError:
self.use_cn = False
def postprocess_image(self, p, pp: scripts.PostprocessImageArgs,
enable, width, height, steps, first_upscaler, second_upscaler, first_latent, second_latent,
prompt, negative_prompt, strength, filter, filter_offset, denoise_offset, clip_skip, sampler
enable, width, height, steps, first_upscaler, second_upscaler, first_latent, second_latent, prompt,
negative_prompt, strength, filter, filter_offset, denoise_offset, clip_skip, sampler, cn_ref, start_control_at
):
if not enable:
return
self.step = 0
self.p = p
self.pp = pp
self.config.width = width
self.config.height = height
@ -157,15 +179,17 @@ class CustomHiresFix(scripts.Script):
self.config.denoise_offset = denoise_offset
self.config.clip_skip = clip_skip
self.config.sampler = sampler
self.config.cn_ref = cn_ref
self.config.start_control_at = start_control_at
self.orig_clip_skip = shared.opts.CLIP_stop_at_last_layers
self.orig_cfg = p.cfg_scale
if clip_skip > 0:
shared.opts.CLIP_stop_at_last_layers = clip_skip
if 'Restart' in self.config.sampler:
self.sampler = sd_samplers.create_sampler('Restart', shared.sd_model)
self.sampler = sd_samplers.create_sampler('Restart', p.sd_model)
else:
self.sampler = sd_samplers.create_sampler(sampler, shared.sd_model)
self.sampler = sd_samplers.create_sampler(sampler, p.sd_model)
def denoise_callback(params: script_callbacks.CFGDenoiserParams):
if params.sampling_step > 0:
@ -181,6 +205,13 @@ class CustomHiresFix(scripts.Script):
script_callbacks.on_cfg_denoiser(denoise_callback)
self.callback_set = True
_, loras_act = extra_networks.parse_prompt(prompt)
extra_networks.activate(p, loras_act)
_, loras_deact = extra_networks.parse_prompt(negative_prompt)
extra_networks.deactivate(p, loras_deact)
self.cn_image = pp.image
with devices.autocast():
shared.state.nextjob()
x = self.gen(pp.image)
@ -189,8 +220,24 @@ class CustomHiresFix(scripts.Script):
shared.opts.CLIP_stop_at_last_layers = self.orig_clip_skip
sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
pp.image = x
extra_networks.deactivate(p, loras_act)
OmegaConf.save(self.config, config_path)
def enable_cn(self, image: np.ndarray):
for unit in self.cn_units:
if unit.model != 'None':
unit.guidance_start = self.config.start_control_at if unit.enabled else unit.guidance_start
unit.processor_res = min(image.shape[0], image.shape[0])
unit.enabled = True
if unit.image is None:
unit.image = image
self.p.width = image.shape[1]
self.p.height = image.shape[0]
self.external_code.update_cn_script_in_processing(self.p, self.cn_units)
for script in self.p.scripts.alwayson_scripts:
if script.title().lower() == 'controlnet':
script.controlnet_hack(self.p)
def process_prompt(self):
prompt = self.p.prompt.strip().split('AND', 1)[0]
if self.config.prompt != '':
@ -218,7 +265,11 @@ class CustomHiresFix(scripts.Script):
self.height = self.config.height if self.config.height > 0 else int(self.config.width / ratio)
self.width = int((self.width - x.width) // 2 + x.width)
self.height = int((self.height - x.height) // 2 + x.height)
sd_models.apply_token_merging(self.p.sd_model, self.p.get_token_merging_ratio(for_hr=True) / 2)
if self.use_cn:
self.enable_cn(np.array(self.cn_image.resize((self.width, self.height))))
with devices.autocast(), torch.inference_mode():
self.process_prompt()
@ -235,7 +286,7 @@ class CustomHiresFix(scripts.Script):
if self.config.first_latent < 1:
x = images.resize_image(0, x, self.width, self.height,
upscaler_name=self.config.first_upscaler)
upscaler_name=self.config.first_upscaler)
image = np.array(x).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
decoded_sample = torch.from_numpy(image)
@ -291,6 +342,11 @@ class CustomHiresFix(scripts.Script):
self.width = self.config.width if self.config.width > 0 else int(self.config.height * ratio)
self.height = self.config.height if self.config.height > 0 else int(self.config.width / ratio)
sd_models.apply_token_merging(self.p.sd_model, self.p.get_token_merging_ratio(for_hr=True))
if self.use_cn:
self.cn_image = x if self.config.cn_ref else self.cn_image
self.enable_cn(np.array(self.cn_image.resize((self.width, self.height))))
with devices.autocast(), torch.inference_mode():
self.process_prompt()