Support PuLID (#2838)

* Add preprocessors

* Fix resolution param

* Fix various issues

* Add PuLID attn

* Remove unused import

* Resize img before passing to facexlib

* Safe unload
Chenlei Hu 2024-05-04 12:25:00 -04:00 committed by GitHub
parent 36a310f599
commit 784b6d01a7
18 changed files with 571 additions and 80 deletions


@ -11,7 +11,7 @@ from modules import scripts, processing, shared
from modules.safe import unsafe_torch_load
from scripts import global_state
from scripts.logging import logger
from scripts.enums import HiResFixOption
from scripts.enums import HiResFixOption, PuLIDMode
from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter
from modules.api import api
@ -207,6 +207,10 @@ class ControlNetUnit:
# The effective region mask that unit's effect should be restricted to.
effective_region_mask: Optional[np.ndarray] = None
# The weight mode for PuLID.
# https://github.com/ToTheBeginning/PuLID
pulid_mode: PuLIDMode = PuLIDMode.FIDELITY
# The tensor input for ipadapter. When this field is set in the API,
# the base64 string will be interpreted by torch.load to reconstruct the
# ipadapter preprocessor output.
@ -243,6 +247,7 @@ class ControlNetUnit:
# provide much information when restoring the unit.
"inpaint_crop_input_image",
"effective_region_mask",
"pulid_mode",
]
@property


@ -6,3 +6,4 @@ addict
yapf
albumentations==1.4.3
matplotlib
facexlib


@ -179,7 +179,7 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
low_vram=low_vram,
)
if preprocessor.returns_image:
images.append(encode_to_base64(result.display_image))
images.append(encode_to_base64(result.display_images[0]))
else:
tensors.append(encode_tensor_to_base64(result.value))


@ -17,12 +17,14 @@ from einops import rearrange
import scripts.preprocessor as preprocessor_init # noqa
from annotator.util import HWC3
from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils
from internal_controlnet.external_code import ControlMode
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
from scripts.controlnet_lllite import clear_all_lllite
from scripts.ipadapter.plugable_ipadapter import ImageEmbed, clear_all_ip_adapter
from scripts.ipadapter.pulid_attn import PULID_SETTING_FIDELITY, PULID_SETTING_STYLE
from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent
from scripts.hook import ControlParams, UnetHook, HackedImageRNG
from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption
from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption, PuLIDMode
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from scripts.controlnet_ui.photopea import Photopea
from scripts.logging import logger
@ -279,6 +281,7 @@ def get_control(
)
detected_map = result.value
is_image = preprocessor.returns_image
# TODO: Refactor img control detection logic.
if high_res_fix:
if is_image:
hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
@ -293,7 +296,8 @@ def get_control(
store_detected_map(detected_map, unit.module)
else:
control = detected_map
store_detected_map(input_image, unit.module)
for image in result.display_images:
store_detected_map(image, unit.module)
if control_model_type == ControlModelType.T2I_StyleAdapter:
control = control['last_hidden_state']
@ -1092,8 +1096,8 @@ class Script(scripts.Script, metaclass=(
global_average_pooling=global_average_pooling,
hr_hint_cond=hr_control,
hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix else HiResFixOption.BOTH,
soft_injection=control_mode != external_code.ControlMode.BALANCED,
cfg_injection=control_mode == external_code.ControlMode.CONTROL,
soft_injection=control_mode != ControlMode.BALANCED,
cfg_injection=control_mode == ControlMode.CONTROL,
effective_region_mask=(
get_pytorch_control(unit.effective_region_mask)[:, 0:1, :, :]
if unit.effective_region_mask is not None
@ -1190,7 +1194,7 @@ class Script(scripts.Script, metaclass=(
is_low_vram = any(unit.low_vram for unit in self.enabled_units)
for i, param in enumerate(forward_params):
for i, (param, unit) in enumerate(zip(forward_params, self.enabled_units)):
if param.control_model_type == ControlModelType.IPAdapter:
if param.advanced_weighting is not None:
logger.info(f"IP-Adapter using advanced weighting {param.advanced_weighting}")
@ -1205,6 +1209,13 @@ class Script(scripts.Script, metaclass=(
weight = param.weight
h, w, hr_y, hr_x = Script.get_target_dimensions(p)
pulid_mode = PuLIDMode(unit.pulid_mode)
if pulid_mode == PuLIDMode.STYLE:
pulid_attn_setting = PULID_SETTING_STYLE
else:
assert pulid_mode == PuLIDMode.FIDELITY
pulid_attn_setting = PULID_SETTING_FIDELITY
param.control_model.hook(
model=unet,
preprocessor_outputs=param.hint_cond,
@ -1215,6 +1226,7 @@ class Script(scripts.Script, metaclass=(
latent_width=w // 8,
latent_height=h // 8,
effective_region_mask=param.effective_region_mask,
pulid_attn_setting=pulid_attn_setting,
)
if param.control_model_type == ControlModelType.Controlllite:
param.control_model.hook(


@ -18,7 +18,7 @@ from scripts.controlnet_ui.openpose_editor import OpenposeEditor
from scripts.controlnet_ui.preset import ControlNetPresetUI
from scripts.controlnet_ui.photopea import Photopea
from scripts.controlnet_ui.advanced_weight_control import AdvancedWeightControl
from scripts.enums import InputMode
from scripts.enums import InputMode, PuLIDMode
from modules import shared
from modules.ui_components import FormRow, FormHTML, ToolButton
@ -287,6 +287,7 @@ class ControlNetUiGroup(object):
self.batch_image_dir_state = None
self.output_dir_state = None
self.advanced_weighting = gr.State(None)
self.pulid_mode = None
# API-only fields
self.ipadapter_input = gr.State(None)
@ -626,6 +627,15 @@ class ControlNetUiGroup(object):
visible=False,
)
self.pulid_mode = gr.Radio(
choices=[e.value for e in PuLIDMode],
value=self.default_unit.pulid_mode.value,
label="PuLID Mode",
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_pulid_mode_radio",
elem_classes="controlnet_pulid_mode_radio",
visible=False,
)
self.loopback = gr.Checkbox(
label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation",
value=self.default_unit.loopback,
@ -673,6 +683,7 @@ class ControlNetUiGroup(object):
self.save_detected_map,
self.advanced_weighting,
self.effective_region_mask,
self.pulid_mode,
)
unit = gr.State(self.default_unit)
@ -947,7 +958,7 @@ class ControlNetUiGroup(object):
return (
# Update to `generated_image`
gr.update(value=result.display_image, visible=True, interactive=False),
gr.update(value=result.display_images[0], visible=True, interactive=False),
# preprocessor_preview
gr.update(value=True),
# openpose editor
@ -1118,6 +1129,14 @@ class ControlNetUiGroup(object):
show_progress=False,
)
def register_shift_pulid_mode(self):
self.model.change(
fn=lambda model: gr.update(visible="pulid" in model.lower()),
inputs=[self.model],
outputs=[self.pulid_mode],
show_progress=False,
)
def register_sync_batch_dir(self):
def determine_batch_dir(batch_dir, fallback_dir, fallback_fallback_dir):
if batch_dir:
@ -1220,6 +1239,7 @@ class ControlNetUiGroup(object):
self.register_build_sliders()
self.register_shift_preview()
self.register_shift_upload_mask()
self.register_shift_pulid_mode()
self.register_create_canvas()
self.register_clear_preview()
self.register_multi_images_upload()


@ -247,3 +247,8 @@ class InputMode(Enum):
# Input is a directory. 1 generation. Each generation takes N input image
# from the directory.
MERGE = "merge"
class PuLIDMode(Enum):
FIDELITY = "Fidelity"
STYLE = "Extremely style"


@ -269,3 +269,65 @@ class Resampler(nn.Module):
latents = self.proj_out(latents)
return self.norm_out(latents)
class PuLIDEncoder(nn.Module):
def __init__(self, width=1280, context_dim=2048, num_token=5):
super().__init__()
self.num_token = num_token
self.context_dim = context_dim
h1 = min((context_dim * num_token) // 4, 1024)
h2 = min((context_dim * num_token) // 2, 1024)
self.body = nn.Sequential(
nn.Linear(width, h1),
nn.LayerNorm(h1),
nn.LeakyReLU(),
nn.Linear(h1, h2),
nn.LayerNorm(h2),
nn.LeakyReLU(),
nn.Linear(h2, context_dim * num_token),
)
for i in range(5):
setattr(
self,
f"mapping_{i}",
nn.Sequential(
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, context_dim),
),
)
setattr(
self,
f"mapping_patch_{i}",
nn.Sequential(
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, context_dim),
),
)
def forward(self, x, y):
# x shape [N, C]
x = self.body(x)
x = x.reshape(-1, self.num_token, self.context_dim)
hidden_states = ()
for i, emb in enumerate(y):
hidden_state = getattr(self, f"mapping_{i}")(emb[:, :1]) + getattr(
self, f"mapping_patch_{i}"
)(emb[:, 1:]).mean(dim=1, keepdim=True)
hidden_states += (hidden_state,)
hidden_states = torch.cat(hidden_states, dim=1)
return torch.cat([x, hidden_states], dim=1)
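Editor's note: a quick shape walk-through of the encoder above, as a sketch with dummy tensors. The five hidden states and the 577-token sequence length are assumptions based on EVA02-CLIP-L-14-336 (one CLS token plus 576 patches); the import path follows this repo's layout.

import torch
from scripts.ipadapter.image_proj_models import PuLIDEncoder

enc = PuLIDEncoder()                                # width=1280, context_dim=2048, num_token=5
x = torch.randn(2, 1280)                            # fused ID embedding (see id_cond below)
y = [torch.randn(2, 577, 1024) for _ in range(5)]   # assumed ViT hidden states: CLS + patches
out = enc(x, y)
# 5 tokens from the body MLP plus 5 from the mapping/mapping_patch pairs
assert out.shape == (2, 10, 2048)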


@ -12,6 +12,7 @@ from .image_proj_models import (
MLPProjModel,
MLPProjModelFaceId,
ProjModelFaceIdPlus,
PuLIDEncoder,
)
@ -71,6 +72,7 @@ class IPAdapterModel(torch.nn.Module):
is_faceid: bool,
is_portrait: bool,
is_instantid: bool,
is_pulid: bool,
is_v2: bool,
):
super().__init__()
@ -85,9 +87,12 @@ class IPAdapterModel(torch.nn.Module):
self.is_v2 = is_v2
self.is_faceid = is_faceid
self.is_instantid = is_instantid
self.is_pulid = is_pulid
self.clip_extra_context_tokens = 16 if (self.is_plus or is_portrait) else 4
if is_instantid:
if self.is_pulid:
self.image_proj_model = PuLIDEncoder()
elif self.is_instantid:
self.image_proj_model = self.init_proj_instantid()
elif is_faceid:
self.image_proj_model = self.init_proj_faceid()
@ -235,6 +240,34 @@ class IPAdapterModel(torch.nn.Module):
self.image_proj_model(torch.zeros_like(prompt_image_emb)),
)
def _get_image_embeds_pulid(self, pulid_proj_input) -> ImageEmbed:
"""Get image embeds for pulid."""
id_cond = torch.cat(
[
pulid_proj_input.id_ante_embedding.to(
device=self.device, dtype=torch.float32
),
pulid_proj_input.id_cond_vit.to(
device=self.device, dtype=torch.float32
),
],
dim=-1,
)
id_vit_hidden = [
t.to(device=self.device, dtype=torch.float32)
for t in pulid_proj_input.id_vit_hidden
]
return ImageEmbed(
self.image_proj_model(
id_cond,
id_vit_hidden,
),
self.image_proj_model(
torch.zeros_like(id_cond),
[torch.zeros_like(t) for t in id_vit_hidden],
),
)
@staticmethod
def load(state_dict: dict, model_name: str) -> IPAdapterModel:
"""
@ -245,6 +278,7 @@ class IPAdapterModel(torch.nn.Module):
is_v2 = "v2" in model_name
is_faceid = "faceid" in model_name
is_instantid = "instant_id" in model_name
is_pulid = "pulid" in model_name.lower()
is_portrait = "portrait" in model_name
is_full = "proj.3.weight" in state_dict["image_proj"]
is_plus = (
@ -256,8 +290,8 @@ class IPAdapterModel(torch.nn.Module):
sdxl = cross_attention_dim == 2048
sdxl_plus = sdxl and is_plus
if is_instantid:
# InstantID does not use clip embedding.
if is_instantid or is_pulid:
# InstantID/PuLID do not use clip embeddings.
clip_embeddings_dim = None
elif is_faceid:
if is_plus:
@ -291,10 +325,13 @@ class IPAdapterModel(torch.nn.Module):
is_portrait=is_portrait,
is_instantid=is_instantid,
is_v2=is_v2,
is_pulid=is_pulid,
)
def get_image_emb(self, preprocessor_output) -> ImageEmbed:
if self.is_instantid:
if self.is_pulid:
return self._get_image_embeds_pulid(preprocessor_output)
elif self.is_instantid:
return self._get_image_embeds_instantid(preprocessor_output)
elif self.is_faceid and self.is_plus:
# Note: FaceID plus uses both face_embed and clip_embed.


@ -1,8 +1,9 @@
import itertools
import torch
import math
from typing import Union, Dict, Optional
from typing import Union, Dict, Optional, Callable
from .pulid_attn import PuLIDAttnSetting
from .ipadapter_model import ImageEmbed, IPAdapterModel
from ..enums import StableDiffusionVersion, TransformerID
@ -93,7 +94,7 @@ def clear_all_ip_adapter():
class PlugableIPAdapter(torch.nn.Module):
def __init__(self, ipadapter: IPAdapterModel):
super().__init__()
self.ipadapter = ipadapter
self.ipadapter: IPAdapterModel = ipadapter
self.disable_memory_management = True
self.dtype = None
self.weight: Union[float, Dict[int, float]] = 1.0
@ -103,6 +104,7 @@ class PlugableIPAdapter(torch.nn.Module):
self.latent_width: int = 0
self.latent_height: int = 0
self.effective_region_mask = None
self.pulid_attn_setting: Optional[PuLIDAttnSetting] = None
def reset(self):
self.cache = {}
@ -118,6 +120,7 @@ class PlugableIPAdapter(torch.nn.Module):
latent_width: int,
latent_height: int,
effective_region_mask: Optional[torch.Tensor],
pulid_attn_setting: Optional[PuLIDAttnSetting] = None,
dtype=torch.float32,
):
global current_model
@ -128,6 +131,7 @@ class PlugableIPAdapter(torch.nn.Module):
self.latent_width = latent_width
self.latent_height = latent_height
self.effective_region_mask = effective_region_mask
self.pulid_attn_setting = pulid_attn_setting
self.cache = {}
@ -186,7 +190,9 @@ class PlugableIPAdapter(torch.nn.Module):
# sequence_length = (latent_width * factor) * (latent_height * factor)
# sequence_length = (latent_width * latent_height) * factor ^ 2
factor = math.sqrt(sequence_length / (self.latent_width * self.latent_height))
assert factor > 0, f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}"
assert (
factor > 0
), f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}"
mask_h = int(self.latent_height * factor)
mask_w = int(self.latent_width * factor)
@ -199,6 +205,71 @@ class PlugableIPAdapter(torch.nn.Module):
mask = mask.view(mask.shape[0], -1, 1).repeat(1, 1, out.shape[2])
return out * mask
def attn_eval(
self,
hidden_states: torch.Tensor,
query: torch.Tensor,
cond_uncond_image_emb: torch.Tensor,
attn_heads: int,
head_dim: int,
emb_to_k: Callable[[torch.Tensor], torch.Tensor],
emb_to_v: Callable[[torch.Tensor], torch.Tensor],
):
if self.ipadapter.is_pulid:
assert self.pulid_attn_setting is not None
return self.pulid_attn_setting.eval(
hidden_states,
query,
cond_uncond_image_emb,
attn_heads,
head_dim,
emb_to_k,
emb_to_v,
)
else:
return self._attn_eval_ipadapter(
hidden_states,
query,
cond_uncond_image_emb,
attn_heads,
head_dim,
emb_to_k,
emb_to_v,
)
def _attn_eval_ipadapter(
self,
hidden_states: torch.Tensor,
query: torch.Tensor,
cond_uncond_image_emb: torch.Tensor,
attn_heads: int,
head_dim: int,
emb_to_k: Callable[[torch.Tensor], torch.Tensor],
emb_to_v: Callable[[torch.Tensor], torch.Tensor],
):
assert hidden_states.ndim == 3
batch_size, sequence_length, inner_dim = hidden_states.shape
ip_k = emb_to_k(cond_uncond_image_emb)
ip_v = emb_to_v(cond_uncond_image_emb)
ip_k, ip_v = map(
lambda t: t.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2),
(ip_k, ip_v),
)
assert ip_k.dtype == ip_v.dtype
# On macOS, query can be float16 instead of float32.
# https://github.com/Mikubill/sd-webui-controlnet/issues/2208
if query.dtype != ip_k.dtype:
ip_k = ip_k.to(dtype=query.dtype)
ip_v = ip_v.to(dtype=query.dtype)
ip_out = torch.nn.functional.scaled_dot_product_attention(
query, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
return ip_out
@torch.no_grad()
def patch_forward(self, number: int, transformer_index: int):
@torch.no_grad()
@ -220,27 +291,15 @@ class PlugableIPAdapter(torch.nn.Module):
k_key = f"{number * 2 + 1}_to_k_ip"
v_key = f"{number * 2 + 1}_to_v_ip"
cond_uncond_image_emb = self.image_emb.eval(current_model.cond_mark)
ip_k = self.call_ip(k_key, cond_uncond_image_emb, device=q.device)
ip_v = self.call_ip(v_key, cond_uncond_image_emb, device=q.device)
ip_k, ip_v = map(
lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2),
(ip_k, ip_v),
ip_out = self.attn_eval(
hidden_states=x,
query=q,
cond_uncond_image_emb=self.image_emb.eval(current_model.cond_mark),
attn_heads=h,
head_dim=head_dim,
emb_to_k=lambda emb: self.call_ip(k_key, emb, device=q.device),
emb_to_v=lambda emb: self.call_ip(v_key, emb, device=q.device),
)
assert ip_k.dtype == ip_v.dtype
# On MacOS, q can be float16 instead of float32.
# https://github.com/Mikubill/sd-webui-controlnet/issues/2208
if q.dtype != ip_k.dtype:
ip_k = ip_k.to(dtype=q.dtype)
ip_v = ip_v.to(dtype=q.dtype)
ip_out = torch.nn.functional.scaled_dot_product_attention(
q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
return self.apply_effective_region_mask(ip_out * weight)
return forward
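Editor's note: the refactor above routes both adapters through attn_eval, passing emb_to_k/emb_to_v closures over the cached per-layer to_k_ip/to_v_ip weights, so the PuLID branch can build its own K/V (including zero-padding) without knowing how those weights are stored. A standalone sketch of the default IP-Adapter path with toy sizes (the projections below are stand-ins, not the real cached weights):

import torch
import torch.nn.functional as F

batch, heads, head_dim, seq = 2, 10, 64, 4096
emb_to_k = torch.nn.Linear(2048, heads * head_dim, bias=False)  # stand-in for to_k_ip
emb_to_v = torch.nn.Linear(2048, heads * head_dim, bias=False)  # stand-in for to_v_ip
emb = torch.randn(batch, 16, 2048)            # cond/uncond image embedding
q = torch.randn(batch, heads, seq, head_dim)  # query from the hooked attention layer
ip_k = emb_to_k(emb).view(batch, -1, heads, head_dim).transpose(1, 2)
ip_v = emb_to_v(emb).view(batch, -1, heads, head_dim).transpose(1, 2)
ip_out = F.scaled_dot_product_attention(q, ip_k, ip_v)  # attends over 16 image tokens
assert ip_out.transpose(1, 2).reshape(batch, seq, heads * head_dim).shape == (batch, seq, 640)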


@ -166,6 +166,12 @@ ipadapter_presets: List[IPAdapterPreset] = [
model="ip-adapter-faceid-portrait_sdxl",
sd_version=StableDiffusionVersion.SDXL,
),
IPAdapterPreset(
name="pulid",
module="ip-adapter_pulid",
model="ip-adapter_pulid_sdxl_fp16",
sd_version=StableDiffusionVersion.SDXL,
),
]
_preset_by_model = {p.model: p for p in ipadapter_presets}


@ -0,0 +1,94 @@
import torch
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Callable
@dataclass
class PuLIDAttnSetting:
num_zero: int = 0
ortho: bool = False
ortho_v2: bool = False
def eval(
self,
hidden_states: torch.Tensor,
query: torch.Tensor,
id_embedding: torch.Tensor,
attn_heads: int,
head_dim: int,
id_to_k: Callable[[torch.Tensor], torch.Tensor],
id_to_v: Callable[[torch.Tensor], torch.Tensor],
):
assert hidden_states.ndim == 3
batch_size, sequence_length, inner_dim = hidden_states.shape
if self.num_zero == 0:
id_key = id_to_k(id_embedding).to(query.dtype)
id_value = id_to_v(id_embedding).to(query.dtype)
else:
zero_tensor = torch.zeros(
(id_embedding.size(0), self.num_zero, id_embedding.size(-1)),
dtype=id_embedding.dtype,
device=id_embedding.device,
)
id_key = id_to_k(torch.cat((id_embedding, zero_tensor), dim=1)).to(
query.dtype
)
id_value = id_to_v(torch.cat((id_embedding, zero_tensor), dim=1)).to(
query.dtype
)
id_key = id_key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
id_value = id_value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
id_hidden_states = F.scaled_dot_product_attention(
query, id_key, id_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
id_hidden_states = id_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn_heads * head_dim
)
id_hidden_states = id_hidden_states.to(query.dtype)
if not self.ortho and not self.ortho_v2:
return id_hidden_states
elif self.ortho_v2:
orig_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
id_hidden_states = id_hidden_states.to(torch.float32)
attn_map = query @ id_key.transpose(-2, -1)
attn_mean = attn_map.softmax(dim=-1).mean(dim=1)
attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)
projection = (
torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
/ torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
* hidden_states
)
orthogonal = id_hidden_states + (attn_mean - 1) * projection
return orthogonal.to(orig_dtype)
else:
orig_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
id_hidden_states = id_hidden_states.to(torch.float32)
projection = (
torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
/ torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
* hidden_states
)
orthogonal = id_hidden_states - projection
return orthogonal.to(orig_dtype)
PULID_SETTING_FIDELITY = PuLIDAttnSetting(
num_zero=8,
ortho=False,
ortho_v2=True,
)
PULID_SETTING_STYLE = PuLIDAttnSetting(
num_zero=16,
ortho=True,
ortho_v2=False,
)
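Editor's note on the two presets above: both push the ID tokens away from the component parallel to the base hidden states. STYLE (ortho) removes that projection entirely and pads with more zero tokens to dilute attention, while FIDELITY (ortho_v2) re-weights it by how strongly each query attends to the first five ID tokens. A hedged shape check of the fidelity path (toy tensors; dims chosen so attn_heads * head_dim equals inner_dim):

import torch
from scripts.ipadapter.pulid_attn import PULID_SETTING_FIDELITY

to_k = torch.nn.Linear(2048, 640, bias=False)   # stand-in for the layer's to_k_ip
to_v = torch.nn.Linear(2048, 640, bias=False)   # stand-in for the layer's to_v_ip
out = PULID_SETTING_FIDELITY.eval(
    hidden_states=torch.randn(2, 4096, 640),
    query=torch.randn(2, 10, 4096, 64),         # [batch, heads, seq, head_dim]
    id_embedding=torch.randn(2, 10, 2048),      # 10 PuLID tokens (see PuLIDEncoder)
    attn_heads=10,
    head_dim=64,
    id_to_k=to_k,
    id_to_v=to_v,
)
assert out.shape == (2, 4096, 640)              # same layout as the base attention output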


@ -1,3 +1,4 @@
from .pulid import *
from .inpaint import *
from .lama_inpaint import *
from .ip_adapter_auto import *


@ -1,18 +1,7 @@
import numpy as np
from scripts.utils import visualize_inpaint_mask
from ..supported_preprocessor import Preprocessor, PreprocessorParameter
def visualize_inpaint_mask(img):
if img.ndim == 3 and img.shape[2] == 4:
result = img.copy()
mask = result[:, :, 3]
mask = 255 - mask // 2
result[:, :, 3] = mask
return np.ascontiguousarray(result.copy())
return img
class PreprocessorInpaint(Preprocessor):
def __init__(self):
super().__init__(name="inpaint")
@ -23,9 +12,6 @@ class PreprocessorInpaint(Preprocessor):
self.accepts_mask = True
self.requires_mask = True
def get_display_image(self, input_image: np.ndarray, result):
return visualize_inpaint_mask(result)
def __call__(
self,
input_image,
@ -35,7 +21,10 @@ class PreprocessorInpaint(Preprocessor):
slider_3=None,
**kwargs
):
return input_image
return Preprocessor.Result(
value=input_image,
display_images=visualize_inpaint_mask(input_image)[None, :, :, :],
)
class PreprocessorInpaintOnly(Preprocessor):
@ -47,9 +36,6 @@ class PreprocessorInpaintOnly(Preprocessor):
self.accepts_mask = True
self.requires_mask = True
def get_display_image(self, input_image: np.ndarray, result):
return visualize_inpaint_mask(result)
def __call__(
self,
input_image,
@ -59,7 +45,10 @@ class PreprocessorInpaintOnly(Preprocessor):
slider_3=None,
**kwargs
):
return input_image
return Preprocessor.Result(
value=input_image,
display_images=visualize_inpaint_mask(input_image)[None, :, :, :],
)
Preprocessor.add_supported_preprocessor(PreprocessorInpaint())


@ -2,7 +2,7 @@ import cv2
import numpy as np
from ..supported_preprocessor import Preprocessor, PreprocessorParameter
from ..utils import resize_image_with_pad
from ..utils import resize_image_with_pad, visualize_inpaint_mask
class PreprocessorLamaInpaint(Preprocessor):
@ -15,12 +15,6 @@ class PreprocessorLamaInpaint(Preprocessor):
self.accepts_mask = True
self.requires_mask = True
def get_display_image(self, input_image: np.ndarray, result: np.ndarray):
"""For lama inpaint, display image should not contain mask."""
assert result.ndim == 3
assert result.shape[2] == 4
return result[:, :, :3]
def __call__(
self,
input_image,
@ -56,7 +50,13 @@ class PreprocessorLamaInpaint(Preprocessor):
fin_color = fin_color.clip(0, 255).astype(np.uint8)
result = np.concatenate([fin_color, raw_mask], axis=2)
return result
return Preprocessor.Result(
value=result,
display_images=[
result[:, :, :3],
visualize_inpaint_mask(result),
],
)
Preprocessor.add_supported_preprocessor(PreprocessorLamaInpaint())


@ -93,7 +93,7 @@ class LegacyPreprocessor(Preprocessor):
def __call__(
self,
input_image,
resolution,
resolution=512,
slider_1=None,
slider_2=None,
slider_3=None,


@ -0,0 +1,169 @@
# https://github.com/ToTheBeginning/PuLID
import torch
import cv2
import numpy as np
from typing import Optional, List
from dataclasses import dataclass
from facexlib.parsing import init_parsing_model
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize
from ..supported_preprocessor import Preprocessor, PreprocessorParameter
from scripts.utils import npimg2tensor, tensor2npimg, resize_image_with_pad
def to_gray(img):
x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
x = x.repeat(1, 3, 1, 1)
return x
class PreprocessorFaceXLib(Preprocessor):
def __init__(self):
super().__init__(name="facexlib")
self.tags = []
self.slider_resolution = PreprocessorParameter(visible=False)
self.model: Optional[FaceRestoreHelper] = None
def load_model(self):
if self.model is None:
self.model = FaceRestoreHelper(
upscale_factor=1,
face_size=512,
crop_ratio=(1, 1),
det_model="retinaface_resnet50",
save_ext="png",
device=self.device,
)
self.model.face_parse = init_parsing_model(
model_name="bisenet", device=self.device
)
self.model.face_parse.to(device=self.device)
self.model.face_det.to(device=self.device)
return self.model
def unload(self) -> bool:
"""@Override"""
if self.model is not None:
self.model.face_parse.to(device="cpu")
self.model.face_det.to(device="cpu")
return True
return False
def __call__(
self,
input_image,
resolution=512,
slider_1=None,
slider_2=None,
slider_3=None,
input_mask=None,
return_tensor=False,
**kwargs
):
"""
@Override
Returns a grayscale face-features image with the background removed (filled white).
"""
self.load_model()
self.model.clean_all()
input_image, _ = resize_image_with_pad(input_image, resolution)
# using facexlib to detect and align face
image_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
self.model.read_image(image_bgr)
self.model.get_face_landmarks_5(only_center_face=True)
self.model.align_warp_face()
if len(self.model.cropped_faces) == 0:
raise RuntimeError("facexlib align face fail")
align_face = self.model.cropped_faces[0]
align_face_rgb = cv2.cvtColor(align_face, cv2.COLOR_BGR2RGB)
input = npimg2tensor(align_face_rgb)
input = input.to(self.device)
parsing_out = self.model.face_parse(
normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)[0]
parsing_out = parsing_out.argmax(dim=1, keepdim=True)
bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
bg = sum(parsing_out == i for i in bg_label).bool()
white_image = torch.ones_like(input)
# only keep the face features
face_features_image = torch.where(bg, white_image, to_gray(input))
if return_tensor:
return face_features_image
else:
return tensor2npimg(face_features_image)
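Editor's note: a hedged usage sketch of the standalone facexlib preprocessor above. The input below is a placeholder array; on a real photo the call returns the white-background gray face, and it raises RuntimeError when no face is detected.

import numpy as np
from scripts.supported_preprocessor import Preprocessor

p = Preprocessor.get_preprocessor("facexlib")    # registered at the bottom of this file
photo = np.zeros((768, 768, 3), dtype=np.uint8)  # placeholder; use a real RGB face image
gray_features = p(photo, resolution=512)         # HWC uint8 image, or a tensor with return_tensor=True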
@dataclass
class PuLIDProjInput:
id_ante_embedding: torch.Tensor
id_cond_vit: torch.Tensor
id_vit_hidden: List[torch.Tensor]
class PreprocessorPuLID(Preprocessor):
"""PuLID preprocessor."""
def __init__(self):
super().__init__(name="ip-adapter_pulid")
self.tags = ["IP-Adapter"]
self.slider_resolution = PreprocessorParameter(visible=False)
self.returns_image = False
self.preprocessors_deps = [
"facexlib",
"instant_id_face_embedding",
"EVA02-CLIP-L-14-336",
]
def facexlib_detect(self, input_image: np.ndarray) -> torch.Tensor:
facexlib_preprocessor = Preprocessor.get_preprocessor("facexlib")
return facexlib_preprocessor(input_image, return_tensor=True)
def insightface_antelopev2_detect(self, input_image: np.ndarray) -> torch.Tensor:
antelopev2_preprocessor = Preprocessor.get_preprocessor(
"instant_id_face_embedding"
)
return antelopev2_preprocessor(input_image)
def unload(self) -> bool:
unloaded = False
for p_name in self.preprocessors_deps:
p = Preprocessor.get_preprocessor(p_name)
if p is not None:
unloaded = p.unload() or unloaded  # call unload() first; `or` would otherwise short-circuit past it
return unloaded
def __call__(
self,
input_image,
resolution,
slider_1=None,
slider_2=None,
slider_3=None,
input_mask=None,
**kwargs
) -> Preprocessor.Result:
id_ante_embedding = self.insightface_antelopev2_detect(input_image)
if id_ante_embedding.ndim == 1:
id_ante_embedding = id_ante_embedding.unsqueeze(0)
face_features_image = self.facexlib_detect(input_image)
evaclip_preprocessor = Preprocessor.get_preprocessor("EVA02-CLIP-L-14-336")
assert (
evaclip_preprocessor is not None
), "EVA02-CLIP-L-14-336 preprocessor not found! Please install sd-webui-controlnet-evaclip"
r = evaclip_preprocessor(face_features_image)
return Preprocessor.Result(
value=PuLIDProjInput(
id_ante_embedding=id_ante_embedding,
id_cond_vit=r.id_cond_vit,
id_vit_hidden=r.id_vit_hidden,
),
display_images=[tensor2npimg(face_features_image)],
)
Preprocessor.add_supported_preprocessor(PreprocessorFaceXLib())
Preprocessor.add_supported_preprocessor(PreprocessorPuLID())
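Editor's note: end to end, the PuLID preprocessor composes three models: antelopev2 for the ID embedding, facexlib for the background-removed gray face, and EVA-CLIP for the ViT features consumed by PuLIDEncoder. A hedged invocation sketch (placeholder image; all three dependencies must be installed and a detectable face is required):

import numpy as np
from scripts.supported_preprocessor import Preprocessor

p = Preprocessor.get_preprocessor("ip-adapter_pulid")
face_rgb = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder; a real face photo is needed
result = p.cached_call(face_rgb, resolution=512)
proj_input = result.value                 # PuLIDProjInput consumed by _get_image_embeds_pulid
preview = result.display_images[0]        # gray face-feature image shown in the UI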


@ -4,7 +4,7 @@ from dataclasses import dataclass, field
import numpy as np
import torch
from modules import shared
from modules import shared, devices
from scripts.logging import logger
from scripts.utils import ndarray_lru_cache
@ -101,6 +101,7 @@ class Preprocessor(ABC):
use_soft_projection_in_hr_fix = False
expand_mask_when_resize_and_fill = False
model: Optional[torch.nn.Module] = None
device = devices.get_device_for("controlnet")
all_processors: ClassVar[Dict[str, "Preprocessor"]] = {}
all_processors_by_name: ClassVar[Dict[str, "Preprocessor"]] = {}
@ -183,18 +184,19 @@ class Preprocessor(ABC):
class Result(NamedTuple):
value: Any
# The display image shown on UI.
display_image: np.ndarray
def get_display_image(self, input_image: np.ndarray, result):
return result if self.returns_image else input_image
# The display images shown in the UI.
display_images: List[np.ndarray]
def cached_call(self, input_image, *args, **kwargs) -> "Preprocessor.Result":
"""The function exposed that also returns an image for display."""
result = self._cached_call(input_image, *args, **kwargs)
return Preprocessor.Result(
value=result, display_image=self.get_display_image(input_image, result)
)
if isinstance(result, Preprocessor.Result):
return result
else:
return Preprocessor.Result(
value=result,
display_images=[result if self.returns_image else input_image],
)
@ndarray_lru_cache(max_size=CACHE_SIZE)
def _cached_call(self, *args, **kwargs):
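Editor's note: the change above keeps old preprocessors working unchanged. A raw return value is wrapped into a single-display-image Result, while new-style preprocessors (inpaint, lama, PuLID in this diff) construct Result themselves to control display_images. A sketch with a hypothetical minimal preprocessor:

import numpy as np
from scripts.supported_preprocessor import Preprocessor

class EchoPreprocessor(Preprocessor):          # hypothetical, for illustration only
    def __init__(self):
        super().__init__(name="echo")
        self.returns_image = True

    def __call__(self, input_image, resolution=512, **kwargs):
        return input_image                     # raw return, wrapped by cached_call

r = EchoPreprocessor().cached_call(np.zeros((64, 64, 3), dtype=np.uint8))
assert isinstance(r, Preprocessor.Result) and len(r.display_images) == 1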


@ -1,3 +1,4 @@
from einops import rearrange
import torch
import os
import functools
@ -105,8 +106,9 @@ def timer_decorator(func):
class TimeMeta(type):
""" Metaclass to record execution time on all methods of the
child class. """
"""Metaclass to record execution time on all methods of the
child class."""
def __new__(cls, name, bases, attrs):
for attr_name, attr_value in attrs.items():
if callable(attr_value):
@ -161,7 +163,9 @@ def read_image(img_path: str) -> str:
return encoded_image
def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]:
def read_image_dir(
img_dir: str, suffixes=(".png", ".jpg", ".jpeg", ".webp")
) -> List[str]:
"""Try read all images in given img_dir."""
images = []
for filename in os.listdir(img_dir):
@ -175,7 +179,7 @@ def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) ->
def align_dim_latent(x: int) -> int:
""" Align the pixel dimension (w/h) to latent dimension.
"""Align the pixel dimension (w/h) to latent dimension.
Stable diffusion 1:8 ratio for latent/pixel, i.e.,
1 latent unit == 8 pixel unit."""
return (x // 8) * 8
@ -203,9 +207,34 @@ def resize_image_with_pad(img: np.ndarray, resolution: int):
W_target = int(np.round(float(W_raw) * k))
img = cv2.resize(img, (W_target, H_target), interpolation=interpolation)
H_pad, W_pad = pad64(H_target), pad64(W_target)
img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode='edge')
img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode="edge")
def remove_pad(x):
return safer_memory(x[:H_target, :W_target])
return safer_memory(img_padded), remove_pad
return safer_memory(img_padded), remove_pad
def npimg2tensor(img: np.ndarray) -> torch.Tensor:
"""Convert numpy img ([H, W, C]) to tensor ([1, C, H, W])"""
return rearrange(torch.from_numpy(img).float() / 255.0, "h w c -> 1 c h w")
def tensor2npimg(t: torch.Tensor) -> np.ndarray:
"""Convert tensor ([1, C, H, W]) to numpy RGB img ([H, W, C])"""
return (
(rearrange(t, "1 c h w -> h w c") * 255.0)
.to(dtype=torch.uint8)
.cpu()
.numpy()
)
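Editor's note: one caveat on the pair above is that the round trip is only near-exact, because .to(dtype=torch.uint8) truncates toward zero rather than rounding, so float32 error can shave a value by one. A quick check:

import numpy as np
from scripts.utils import npimg2tensor, tensor2npimg

img = np.arange(256, dtype=np.uint8).reshape(16, 16, 1).repeat(3, axis=2)
t = npimg2tensor(img)                     # [1, 3, 16, 16], floats in [0, 1]
back = tensor2npimg(t)                    # [16, 16, 3] uint8
# Off-by-one dips are possible where k/255*255 lands just below k in float32.
assert np.abs(back.astype(int) - img.astype(int)).max() <= 1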
def visualize_inpaint_mask(img):
if img.ndim == 3 and img.shape[2] == 4:
result = img.copy()
mask = result[:, :, 3]
mask = 255 - mask // 2
result[:, :, 3] = mask
return np.ascontiguousarray(result.copy())
return img
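Editor's note: in the helper above, the inpaint mask rides in the alpha channel, and 255 - mask // 2 maps unmasked pixels (mask 0) to fully opaque (alpha 255) and masked pixels (mask 255) to alpha 128, so the masked region renders half-transparent in the preview. A tiny check:

import numpy as np
from scripts.utils import visualize_inpaint_mask

rgba = np.zeros((2, 2, 4), dtype=np.uint8)
rgba[0, 0, 3] = 255                       # mark one pixel as masked
vis = visualize_inpaint_mask(rgba)
assert vis[0, 0, 3] == 128                # masked -> half-transparent
assert vis[1, 1, 3] == 255                # unmasked -> fully opaque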