parent
4c111ecaa2
commit
b2f9c73532
|
|
@ -13,7 +13,8 @@ jobs:
|
|||
with:
|
||||
repository: 'AUTOMATIC1111/stable-diffusion-webui'
|
||||
path: 'stable-diffusion-webui'
|
||||
|
||||
ref: '5ab7f213bec2f816f9c5644becb32eb72c8ffb89'
|
||||
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
|
|
|
|||
18
README.md
18
README.md
|
|
@ -94,6 +94,24 @@ Now you can control which aspect is more important (your prompt or your ControlN
|
|||
</tr>
|
||||
</table>
|
||||
|
||||
### Reference-Only Control
|
||||
|
||||
Now we have a `reference-only` preprocessor that does not require any control models. It can guide the diffusion directly using images as references.
|
||||
|
||||
(Prompt "a dog running on grassland, best quality, ...")
|
||||
|
||||

|
||||
|
||||
This method is similar to inpaint-based reference but it does not make your image disordered.
|
||||
|
||||
Many professional A1111 users know a trick to diffuse image with references by inpaint. For example, if you have a 512x512 image of a dog, and want to generate another 512x512 image with the same dog, some users will connect the 512x512 dog image and a 512x512 blank image into a 1024x512 image, send to inpaint, and mask out the blank 512x512 part to diffuse a dog with similar appearance. However, that method is usually not very satisfying since images are connected and many distortions will appear.
|
||||
|
||||
This `reference-only` ControlNet can directly link the attention layers of your SD to any independent images, so that your SD will read arbitary images for reference. You need at least ControlNet 1.1.153 to use it.
|
||||
|
||||
To use, just select `reference-only` as preprocessor and put an image. Your SD will just use the image as reference.
|
||||
|
||||
*Note that this method is as "non-opinioned" as possible. It only contains very basic connection codes, without any personal preferences, to connect the attention layers with your reference images. However, even if we tried best to not include any opinioned codes, we still need to write some subjective implementations to deal with weighting, cfg-scale, etc - tech report is on the way.*
|
||||
|
||||
# Technical Documents
|
||||
|
||||
See also the documents of ControlNet 1.1:
|
||||
|
|
|
|||
Binary file not shown.
|
After Width: | Height: | Size: 606 KiB |
|
|
@ -1,19 +1,12 @@
|
|||
import gc
|
||||
import inspect
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from copy import copy
|
||||
import base64
|
||||
from typing import Union, Dict, Optional, List
|
||||
import importlib
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
import modules.scripts as scripts
|
||||
from modules import shared, devices, script_callbacks, processing, masking, images
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
|
||||
from einops import rearrange
|
||||
from scripts import global_state, hook, external_code, processor, batch_hijack, controlnet_version
|
||||
|
|
@ -30,12 +23,16 @@ from scripts.hook import ControlParams, UnetHook, ControlModelType
|
|||
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img
|
||||
from modules.images import save_image
|
||||
from modules.ui_components import FormRow
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import base64
|
||||
|
||||
from pathlib import Path
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from scripts.lvminthin import lvmin_thin, nake_nms
|
||||
from torchvision.transforms import Resize, InterpolationMode, CenterCrop, Compose
|
||||
from scripts.processor import preprocessor_sliders_config, flag_preprocessor_resolution
|
||||
from scripts.processor import preprocessor_sliders_config, flag_preprocessor_resolution, model_free_preprocessors
|
||||
|
||||
gradio_compat = True
|
||||
try:
|
||||
|
|
@ -50,7 +47,6 @@ except ImportError:
|
|||
svgsupport = False
|
||||
try:
|
||||
import io
|
||||
import base64
|
||||
from svglib.svglib import svg2rlg
|
||||
from reportlab.graphics import renderPM
|
||||
svgsupport = True
|
||||
|
|
@ -163,7 +159,7 @@ def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
|
|||
|
||||
if isinstance(image['image'], str):
|
||||
if os.path.exists(image['image']):
|
||||
image['image'] = numpy.array(Image.open(image['image'])).astype('uint8')
|
||||
image['image'] = np.array(Image.open(image['image'])).astype('uint8')
|
||||
elif image['image']:
|
||||
image['image'] = external_code.to_base64_nparray(image['image'])
|
||||
else:
|
||||
|
|
@ -176,7 +172,7 @@ def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
|
|||
|
||||
if isinstance(image['mask'], str):
|
||||
if os.path.exists(image['mask']):
|
||||
image['mask'] = numpy.array(Image.open(image['mask'])).astype('uint8')
|
||||
image['mask'] = np.array(Image.open(image['mask'])).astype('uint8')
|
||||
elif image['mask']:
|
||||
image['mask'] = external_code.to_base64_nparray(image['mask'])
|
||||
else:
|
||||
|
|
@ -369,19 +365,17 @@ class Script(scripts.Script):
|
|||
guidance_end = gr.Slider(label="Ending Control Step", value=default_unit.guidance_end, minimum=0.0, maximum=1.0, interactive=True, elem_id=f'{elem_id_tabname}_{tabname}_controlnet_ending_control_step_slider')
|
||||
|
||||
def build_sliders(module, pp):
|
||||
grs = []
|
||||
module = self.get_module_basename(module)
|
||||
if module not in preprocessor_sliders_config:
|
||||
return [
|
||||
grs += [
|
||||
gr.update(label=flag_preprocessor_resolution, value=512, minimum=64, maximum=2048, step=1, visible=not pp, interactive=not pp),
|
||||
gr.update(visible=False, interactive=False),
|
||||
gr.update(visible=False, interactive=False),
|
||||
gr.update(visible=True)
|
||||
]
|
||||
else:
|
||||
slider_configs = preprocessor_sliders_config[module]
|
||||
grs = []
|
||||
|
||||
for slider_config in slider_configs:
|
||||
for slider_config in preprocessor_sliders_config[module]:
|
||||
if isinstance(slider_config, dict):
|
||||
visible = True
|
||||
if slider_config['name'] == flag_preprocessor_resolution:
|
||||
|
|
@ -396,12 +390,14 @@ class Script(scripts.Script):
|
|||
interactive=visible))
|
||||
else:
|
||||
grs.append(gr.update(visible=False, interactive=False))
|
||||
|
||||
while len(grs) < 3:
|
||||
grs.append(gr.update(visible=False, interactive=False))
|
||||
|
||||
grs.append(gr.update(visible=True))
|
||||
return grs
|
||||
if module in model_free_preprocessors:
|
||||
grs += [gr.update(visible=False, value='None'), gr.update(visible=False)]
|
||||
else:
|
||||
grs += [gr.update(visible=True), gr.update(visible=True)]
|
||||
return grs
|
||||
|
||||
# advanced options
|
||||
with gr.Column(visible=False) as advanced:
|
||||
|
|
@ -410,8 +406,8 @@ class Script(scripts.Script):
|
|||
threshold_b = gr.Slider(label="Threshold B", value=default_unit.threshold_b, minimum=64, maximum=1024, visible=False, interactive=False, elem_id=f'{elem_id_tabname}_{tabname}_controlnet_threshold_B_slider')
|
||||
|
||||
if gradio_compat:
|
||||
module.change(build_sliders, inputs=[module, pixel_perfect], outputs=[processor_res, threshold_a, threshold_b, advanced])
|
||||
pixel_perfect.change(build_sliders, inputs=[module, pixel_perfect], outputs=[processor_res, threshold_a, threshold_b, advanced])
|
||||
module.change(build_sliders, inputs=[module, pixel_perfect], outputs=[processor_res, threshold_a, threshold_b, advanced, model, refresh_models])
|
||||
pixel_perfect.change(build_sliders, inputs=[module, pixel_perfect], outputs=[processor_res, threshold_a, threshold_b, advanced, model, refresh_models])
|
||||
|
||||
# infotext_fields.extend((module, model, weight))
|
||||
|
||||
|
|
@ -694,6 +690,9 @@ class Script(scripts.Script):
|
|||
return model_net
|
||||
|
||||
def build_control_model(self, p, unet, model, lowvram):
|
||||
if model is None or model == 'None':
|
||||
raise RuntimeError(f"You have not selected any ControlNet Model.")
|
||||
|
||||
model_path = global_state.cn_models.get(model, None)
|
||||
if model_path is None:
|
||||
model = find_closest_lora_model_name(model)
|
||||
|
|
@ -973,7 +972,10 @@ class Script(scripts.Script):
|
|||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
unet = p.sd_model.model.diffusion_model
|
||||
|
||||
sd_ldm = p.sd_model
|
||||
unet = sd_ldm.model.diffusion_model
|
||||
|
||||
if self.latest_network is not None:
|
||||
# always restore (~0.05s)
|
||||
self.latest_network.restore(unet)
|
||||
|
|
@ -1014,8 +1016,11 @@ class Script(scripts.Script):
|
|||
if unit.low_vram:
|
||||
hook_lowvram = True
|
||||
|
||||
model_net = self.load_control_model(p, unet, unit.model, unit.low_vram)
|
||||
model_net.reset()
|
||||
if unit.module in model_free_preprocessors:
|
||||
model_net = None
|
||||
else:
|
||||
model_net = self.load_control_model(p, unet, unit.model, unit.low_vram)
|
||||
model_net.reset()
|
||||
|
||||
if batch_hijack.instance.is_batch and getattr(p, "image_control", None) is not None:
|
||||
input_image = HWC3(np.asarray(p.image_control))
|
||||
|
|
@ -1201,6 +1206,15 @@ class Script(scripts.Script):
|
|||
if getattr(model_net, "target", None) == "scripts.adapter.StyleAdapter":
|
||||
control_model_type = ControlModelType.T2I_StyleAdapter
|
||||
|
||||
if 'reference' in unit.module:
|
||||
control_model_type = ControlModelType.AttentionInjection
|
||||
|
||||
global_average_pooling = False
|
||||
|
||||
if model_net is not None:
|
||||
if model_net.config.model.params.get("global_average_pooling", False):
|
||||
global_average_pooling = True
|
||||
|
||||
forward_param = ControlParams(
|
||||
control_model=model_net,
|
||||
hint_cond=control,
|
||||
|
|
@ -1210,7 +1224,7 @@ class Script(scripts.Script):
|
|||
stop_guidance_percent=unit.guidance_end,
|
||||
advanced_weighting=None,
|
||||
control_model_type=control_model_type,
|
||||
global_average_pooling=model_net.config.model.params.get("global_average_pooling", False),
|
||||
global_average_pooling=global_average_pooling,
|
||||
hr_hint_cond=hr_control,
|
||||
batch_size=p.batch_size,
|
||||
instance_counter=0,
|
||||
|
|
@ -1224,8 +1238,7 @@ class Script(scripts.Script):
|
|||
del model_net
|
||||
|
||||
self.latest_network = UnetHook(lowvram=hook_lowvram)
|
||||
self.latest_network.hook(unet)
|
||||
self.latest_network.notify(forward_params, is_vanilla_samplers)
|
||||
self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params)
|
||||
self.detected_map = detected_maps
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
version_flag = 'v1.1.152'
|
||||
version_flag = 'v1.1.153'
|
||||
print(f'ControlNet {version_flag}')
|
||||
# A smart trick to know if user has updated as well as if user has restarted terminal.
|
||||
# Note that in "controlnet.py" we do NOT use "importlib.reload" to reload this "controlnet_version.py"
|
||||
|
|
|
|||
|
|
@ -35,11 +35,9 @@ cn_preprocessor_modules = {
|
|||
"pidinet_safe": pidinet_safe,
|
||||
"pidinet_sketch": pidinet_ts,
|
||||
"pidinet_scribble": scribble_pidinet,
|
||||
# "scribble_thr": scribble_thr, # Removed by Lvmin to avoid confusing
|
||||
"scribble_xdog": scribble_xdog,
|
||||
"scribble_hed": scribble_hed,
|
||||
"segmentation": uniformer,
|
||||
# "binary": binary, # Removed by Lvmin to avoid confusing
|
||||
"threshold": threshold,
|
||||
"depth_zoe": zoe_depth,
|
||||
"normal_bae": normal_bae,
|
||||
|
|
@ -51,9 +49,10 @@ cn_preprocessor_modules = {
|
|||
"lineart_standard": lineart_standard,
|
||||
"shuffle": shuffle,
|
||||
"tile_resample": tile_resample,
|
||||
"inpaint": inpaint,
|
||||
"invert": invert,
|
||||
"lineart_anime_denoise": lineart_anime_denoise
|
||||
"lineart_anime_denoise": lineart_anime_denoise,
|
||||
"reference_only": identity,
|
||||
"inpaint": identity,
|
||||
}
|
||||
|
||||
cn_preprocessor_unloadable = {
|
||||
|
|
|
|||
262
scripts/hook.py
262
scripts/hook.py
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import einops
|
||||
import torch.nn as nn
|
||||
|
||||
from enum import Enum
|
||||
|
|
@ -8,6 +9,7 @@ cond_cast_unet = getattr(devices, 'cond_cast_unet', lambda x: x)
|
|||
|
||||
from ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from ldm.modules.diffusionmodules.openaimodel import UNetModel
|
||||
from ldm.modules.attention import BasicTransformerBlock
|
||||
|
||||
|
||||
class ControlModelType(Enum):
|
||||
|
|
@ -21,12 +23,22 @@ class ControlModelType(Enum):
|
|||
T2I_CoAdapter = "T2I_CoAdapter, Chong Mou"
|
||||
MasaCtrl = "MasaCtrl, Mingdeng Cao"
|
||||
GLIGEN = "GLIGEN, Yuheng Li"
|
||||
AttentionInjection = "AttentionInjection, Anonymous"
|
||||
AttentionInjection = "AttentionInjection, Lvmin Zhang" # A simple attention injection written by Lvmin
|
||||
StableSR = "StableSR, Jianyi Wang"
|
||||
PromptDiffusion = "PromptDiffusion, Zhendong Wang"
|
||||
ControlLoRA = "ControlLoRA, Wu Hecong"
|
||||
|
||||
|
||||
# Written by Lvmin
|
||||
class AttentionAutoMachine(Enum):
|
||||
"""
|
||||
Lvmin's algorithm for Attention AutoMachine States.
|
||||
"""
|
||||
|
||||
Read = "Read"
|
||||
Write = "Write"
|
||||
|
||||
|
||||
class TorchHijackForUnet:
|
||||
"""
|
||||
This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
||||
|
|
@ -58,23 +70,23 @@ th = TorchHijackForUnet()
|
|||
|
||||
class ControlParams:
|
||||
def __init__(
|
||||
self,
|
||||
control_model,
|
||||
hint_cond,
|
||||
weight,
|
||||
guidance_stopped,
|
||||
start_guidance_percent,
|
||||
stop_guidance_percent,
|
||||
advanced_weighting,
|
||||
control_model_type,
|
||||
hr_hint_cond,
|
||||
global_average_pooling,
|
||||
batch_size,
|
||||
instance_counter,
|
||||
is_vanilla_samplers,
|
||||
cfg_scale,
|
||||
soft_injection,
|
||||
cfg_injection
|
||||
self,
|
||||
control_model,
|
||||
hint_cond,
|
||||
weight,
|
||||
guidance_stopped,
|
||||
start_guidance_percent,
|
||||
stop_guidance_percent,
|
||||
advanced_weighting,
|
||||
control_model_type,
|
||||
hr_hint_cond,
|
||||
global_average_pooling,
|
||||
batch_size,
|
||||
instance_counter,
|
||||
is_vanilla_samplers,
|
||||
cfg_scale,
|
||||
soft_injection,
|
||||
cfg_injection
|
||||
):
|
||||
self.control_model = control_model
|
||||
self._hint_cond = hint_cond
|
||||
|
|
@ -87,6 +99,7 @@ class ControlParams:
|
|||
self.global_average_pooling = global_average_pooling
|
||||
self.hr_hint_cond = hr_hint_cond
|
||||
self.used_hint_cond = None
|
||||
self.used_hint_cond_latent = None
|
||||
self.batch_size = batch_size
|
||||
self.instance_counter = instance_counter
|
||||
self.is_vanilla_samplers = is_vanilla_samplers
|
||||
|
|
@ -94,8 +107,10 @@ class ControlParams:
|
|||
self.soft_injection = soft_injection
|
||||
self.cfg_injection = cfg_injection
|
||||
|
||||
def generate_uc_mask(self, length, dtype, device):
|
||||
def generate_uc_mask(self, length, dtype=None, device=None, python_list=False):
|
||||
if self.is_vanilla_samplers and self.cfg_scale == 1:
|
||||
if python_list:
|
||||
return [1 for _ in range(length)]
|
||||
return torch.tensor([1 for _ in range(length)], dtype=dtype, device=device)
|
||||
|
||||
y = []
|
||||
|
|
@ -109,6 +124,9 @@ class ControlParams:
|
|||
|
||||
self.instance_counter += length
|
||||
|
||||
if python_list:
|
||||
return y
|
||||
|
||||
return torch.tensor(y, dtype=dtype, device=device)
|
||||
|
||||
@property
|
||||
|
|
@ -123,49 +141,75 @@ class ControlParams:
|
|||
def hint_cond(self, new_hint_cond):
|
||||
self._hint_cond = new_hint_cond
|
||||
self.used_hint_cond = None
|
||||
self.used_hint_cond_latent = None
|
||||
|
||||
|
||||
def aligned_adding(base, x, require_channel_alignment):
|
||||
if isinstance(x, float):
|
||||
if x == 0.0:
|
||||
return base
|
||||
return base + x
|
||||
|
||||
if require_channel_alignment:
|
||||
zeros = torch.zeros_like(base)
|
||||
zeros[:, :x.shape[1], ...] = x
|
||||
x = zeros
|
||||
|
||||
# resize to sample resolution
|
||||
base_h, base_w = base.shape[-2:]
|
||||
xh, xw = x.shape[-2:]
|
||||
if base_h != xh or base_w != xw:
|
||||
x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest")
|
||||
|
||||
return base + x
|
||||
|
||||
|
||||
# DFS Search for Torch.nn.Module, Written by Lvmin
|
||||
def torch_dfs(model: torch.nn.Module):
|
||||
result = [model]
|
||||
for child in model.children():
|
||||
result += torch_dfs(child)
|
||||
return result
|
||||
|
||||
|
||||
class UnetHook(nn.Module):
|
||||
def __init__(self, lowvram=False) -> None:
|
||||
super().__init__()
|
||||
self.lowvram = lowvram
|
||||
self.model = None
|
||||
self.sd_ldm = None
|
||||
self.control_params = None
|
||||
self.attention_auto_machine = AttentionAutoMachine.Read
|
||||
self.attention_auto_machine_uc_mask = None
|
||||
self.attention_auto_machine_weight = 1.0
|
||||
|
||||
def guidance_schedule_handler(self, x):
|
||||
for param in self.control_params:
|
||||
current_sampling_percent = (x.sampling_step / x.total_sampling_steps)
|
||||
param.guidance_stopped = current_sampling_percent < param.start_guidance_percent or current_sampling_percent > param.stop_guidance_percent
|
||||
|
||||
def hook(self, model):
|
||||
def hook(self, model, sd_ldm, control_params):
|
||||
self.model = model
|
||||
self.sd_ldm = sd_ldm
|
||||
self.control_params = control_params
|
||||
|
||||
outer = self
|
||||
|
||||
def cfg_based_adder(base, x, require_autocast):
|
||||
if isinstance(x, float):
|
||||
return base + x
|
||||
|
||||
if require_autocast:
|
||||
zeros = torch.zeros_like(base)
|
||||
zeros[:, :x.shape[1], ...] = x
|
||||
x = zeros
|
||||
|
||||
# resize to sample resolution
|
||||
base_h, base_w = base.shape[-2:]
|
||||
xh, xw = x.shape[-2:]
|
||||
if base_h != xh or base_w != xw:
|
||||
x = th.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest")
|
||||
|
||||
return base + x
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, **kwargs):
|
||||
total_controlnet_embedding = [0.0] * 13
|
||||
total_t2i_adapter_embedding = [0.0] * 4
|
||||
total_extra_cond = None
|
||||
require_inpaint_hijack = False
|
||||
is_in_high_res_fix = False
|
||||
|
||||
# High-res fix
|
||||
is_in_high_res_fix = False
|
||||
for param in outer.control_params:
|
||||
# select which hint_cond to use
|
||||
param.used_hint_cond = param.hint_cond
|
||||
|
||||
# Attention Injection do not need high-res fix
|
||||
if param.control_model_type in [ControlModelType.AttentionInjection]:
|
||||
continue
|
||||
|
||||
# has high-res fix
|
||||
if param.hr_hint_cond is not None and x.ndim == 4 and param.hint_cond.ndim == 3 and param.hr_hint_cond.ndim == 3:
|
||||
_, h_lr, w_lr = param.hint_cond.shape
|
||||
|
|
@ -180,6 +224,19 @@ class UnetHook(nn.Module):
|
|||
param.used_hint_cond = param.hr_hint_cond
|
||||
is_in_high_res_fix = True
|
||||
|
||||
# Convert control image to latent
|
||||
for param in outer.control_params:
|
||||
if param.used_hint_cond_latent is not None:
|
||||
continue
|
||||
if param.control_model_type not in [ControlModelType.AttentionInjection]:
|
||||
continue
|
||||
query_size = int(x.shape[0])
|
||||
latent_hint = param.used_hint_cond[None] * 2.0 - 1.0
|
||||
latent_hint = outer.sd_ldm.encode_first_stage(latent_hint)
|
||||
latent_hint = outer.sd_ldm.get_first_stage_encoding(latent_hint)
|
||||
latent_hint = torch.cat([latent_hint.clone() for _ in range(query_size)], dim=0)
|
||||
param.used_hint_cond_latent = latent_hint
|
||||
|
||||
# handle prompt token control
|
||||
for param in outer.control_params:
|
||||
if param.guidance_stopped:
|
||||
|
|
@ -195,14 +252,8 @@ class UnetHook(nn.Module):
|
|||
control = torch.cat([control.clone() for _ in range(query_size)], dim=0)
|
||||
control *= param.weight
|
||||
control *= uc_mask
|
||||
if total_extra_cond is None:
|
||||
total_extra_cond = control.clone()
|
||||
else:
|
||||
total_extra_cond = torch.cat([total_extra_cond, control.clone()], dim=1)
|
||||
|
||||
if total_extra_cond is not None:
|
||||
context = torch.cat([context, total_extra_cond], dim=1)
|
||||
|
||||
context = torch.cat([context, control.clone()], dim=1)
|
||||
|
||||
# handle ControlNet / T2I_Adapter
|
||||
for param in outer.control_params:
|
||||
if param.guidance_stopped:
|
||||
|
|
@ -225,7 +276,7 @@ class UnetHook(nn.Module):
|
|||
assert param.used_hint_cond is not None, f"Controlnet is enabled but no input image is given"
|
||||
control = param.control_model(x=x_in, hint=param.used_hint_cond, timesteps=timesteps, context=context)
|
||||
control_scales = ([param.weight] * 13)
|
||||
|
||||
|
||||
if outer.lowvram:
|
||||
param.control_model.to("cpu")
|
||||
|
||||
|
|
@ -245,11 +296,11 @@ class UnetHook(nn.Module):
|
|||
|
||||
if param.advanced_weighting is not None:
|
||||
control_scales = param.advanced_weighting
|
||||
|
||||
|
||||
control = [c * scale for c, scale in zip(control, control_scales)]
|
||||
if param.global_average_pooling:
|
||||
control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]
|
||||
|
||||
|
||||
for idx, item in enumerate(control):
|
||||
target = None
|
||||
if param.control_model_type == ControlModelType.ControlNet:
|
||||
|
|
@ -258,8 +309,29 @@ class UnetHook(nn.Module):
|
|||
target = total_t2i_adapter_embedding
|
||||
if target is not None:
|
||||
target[idx] = item + target[idx]
|
||||
|
||||
assert timesteps is not None, ValueError(f"insufficient timestep: {timesteps}")
|
||||
|
||||
# Handle attention-based control
|
||||
for param in outer.control_params:
|
||||
if param.guidance_stopped:
|
||||
continue
|
||||
|
||||
if param.used_hint_cond_latent is None:
|
||||
continue
|
||||
|
||||
if param.control_model_type not in [ControlModelType.AttentionInjection]:
|
||||
continue
|
||||
|
||||
query_size = int(x.shape[0])
|
||||
ref_xt = outer.sd_ldm.q_sample(param.used_hint_cond_latent, torch.round(timesteps).long())
|
||||
outer.attention_auto_machine_uc_mask = param.generate_uc_mask(query_size, python_list=True)
|
||||
if param.soft_injection:
|
||||
outer.attention_auto_machine_uc_mask = [1 for _ in outer.attention_auto_machine_uc_mask]
|
||||
outer.attention_auto_machine_weight = param.weight
|
||||
outer.attention_auto_machine = AttentionAutoMachine.Write
|
||||
outer.original_forward(x=ref_xt, timesteps=timesteps, context=context)
|
||||
outer.attention_auto_machine = AttentionAutoMachine.Read
|
||||
|
||||
# U-Net Encoder
|
||||
hs = []
|
||||
with th.no_grad():
|
||||
t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
|
||||
|
|
@ -267,50 +339,100 @@ class UnetHook(nn.Module):
|
|||
h = x.type(self.dtype)
|
||||
for i, module in enumerate(self.input_blocks):
|
||||
h = module(h, emb, context)
|
||||
|
||||
# t2i-adatper, same as openaimodel.py:744
|
||||
if ((i+1) % 3 == 0) and len(total_t2i_adapter_embedding) > 0:
|
||||
h = cfg_based_adder(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack)
|
||||
|
||||
if (i + 1) % 3 == 0:
|
||||
h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack)
|
||||
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
|
||||
h = cfg_based_adder(h, total_controlnet_embedding.pop(), require_inpaint_hijack)
|
||||
# U-Net Middle Block
|
||||
h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack)
|
||||
|
||||
# U-Net Decoder
|
||||
for i, module in enumerate(self.output_blocks):
|
||||
h = th.cat([h, cfg_based_adder(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1)
|
||||
h = th.cat([h, aligned_adding(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1)
|
||||
h = module(h, emb, context)
|
||||
|
||||
# U-Net Output
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
h = self.out(h)
|
||||
|
||||
def forward2(*args, **kwargs):
|
||||
return h
|
||||
|
||||
def forward_webui(*args, **kwargs):
|
||||
# webui will handle other compoments
|
||||
try:
|
||||
if shared.cmd_opts.lowvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
|
||||
|
||||
return forward(*args, **kwargs)
|
||||
finally:
|
||||
if self.lowvram:
|
||||
[param.control_model.to("cpu") for param in self.control_params]
|
||||
|
||||
for param in self.control_params:
|
||||
if param.control_model is not None:
|
||||
param.control_model.to("cpu")
|
||||
|
||||
def hacked_basic_transformer_inner_forward(self, x, context=None):
|
||||
x_norm1 = self.norm1(x)
|
||||
self_attn1 = 0
|
||||
if self.disable_self_attn:
|
||||
# Do not use self-attention
|
||||
self_attn1 = self.attn1(x_norm1, context=context)
|
||||
else:
|
||||
# Use self-attention
|
||||
self_attention_context = x_norm1
|
||||
if outer.attention_auto_machine == AttentionAutoMachine.Write:
|
||||
uc_mask = outer.attention_auto_machine_uc_mask
|
||||
control_weight = outer.attention_auto_machine_weight
|
||||
store = []
|
||||
for i, mask in enumerate(uc_mask):
|
||||
if mask > 0.5 and control_weight > self.attn_weight:
|
||||
store.append(self_attention_context[i])
|
||||
else:
|
||||
store.append(None)
|
||||
self.bank.append(store)
|
||||
self_attn1 = self.attn1(x_norm1, context=self_attention_context)
|
||||
if outer.attention_auto_machine == AttentionAutoMachine.Read:
|
||||
query_size = self_attention_context.shape[0]
|
||||
self_attention_context = [self_attention_context[i] for i in range(query_size)]
|
||||
for store in self.bank:
|
||||
for i, v in enumerate(store):
|
||||
if v is not None:
|
||||
self_attention_context[i] = torch.cat([self_attention_context[i], v], dim=0)
|
||||
x_norm1 = [x_norm1[i] for i in range(query_size)]
|
||||
self_attn1 = [self.attn1(a[None], context=b[None]) for a, b in zip(x_norm1, self_attention_context)]
|
||||
self_attn1 = torch.cat(self_attn1, dim=0)
|
||||
self.bank.clear()
|
||||
|
||||
x = self_attn1 + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
model._original_forward = model.forward
|
||||
model.forward = forward2.__get__(model, UNetModel)
|
||||
outer.original_forward = model.forward
|
||||
model.forward = forward_webui.__get__(model, UNetModel)
|
||||
|
||||
attn_modules = [module for module in torch_dfs(model) if isinstance(module, BasicTransformerBlock)]
|
||||
attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0])
|
||||
|
||||
for i, module in enumerate(attn_modules):
|
||||
module._original_inner_forward = module._forward
|
||||
module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
|
||||
module.bank = []
|
||||
module.attn_weight = float(i) / float(len(attn_modules))
|
||||
|
||||
scripts.script_callbacks.on_cfg_denoiser(self.guidance_schedule_handler)
|
||||
|
||||
def notify(self, params, is_vanilla_samplers): # lint: list[ControlParams]
|
||||
self.is_vanilla_samplers = is_vanilla_samplers
|
||||
self.control_params = params
|
||||
|
||||
def restore(self, model):
|
||||
scripts.script_callbacks.remove_callbacks_for_function(self.guidance_schedule_handler)
|
||||
if hasattr(self, "control_params"):
|
||||
del self.control_params
|
||||
|
||||
|
||||
if not hasattr(model, "_original_forward"):
|
||||
# no such handle, ignore
|
||||
return
|
||||
|
||||
|
||||
model.forward = model._original_forward
|
||||
del model._original_forward
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ def threshold(img, res=512, thr_a=127, **kwargs):
|
|||
return remove_pad(result), True
|
||||
|
||||
|
||||
def inpaint(img, res=512, **kwargs):
|
||||
def identity(img, **kwargs):
|
||||
return img, True
|
||||
|
||||
|
||||
|
|
@ -553,11 +553,12 @@ def shuffle(img, res=512, **kwargs):
|
|||
return result, True
|
||||
|
||||
|
||||
model_free_preprocessors = ["reference_only"]
|
||||
flag_preprocessor_resolution = "Preprocessor Resolution"
|
||||
preprocessor_sliders_config = {
|
||||
"none": [
|
||||
|
||||
],
|
||||
"none": [],
|
||||
"inpaint": [],
|
||||
"reference_only": [],
|
||||
"canny": [
|
||||
{
|
||||
"name": flag_preprocessor_resolution,
|
||||
|
|
|
|||
Loading…
Reference in New Issue