Support PuLID (#2838)
* Add preprocessors * Fix resolution param * Fix various issues * Add PuLID attn * Remove unused import * Resize img before passing to facexlib * Safe unload (branch: pull/2842/head)
parent
36a310f599
commit
784b6d01a7
|
|
@ -11,7 +11,7 @@ from modules import scripts, processing, shared
|
|||
from modules.safe import unsafe_torch_load
|
||||
from scripts import global_state
|
||||
from scripts.logging import logger
|
||||
from scripts.enums import HiResFixOption
|
||||
from scripts.enums import HiResFixOption, PuLIDMode
|
||||
from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter
|
||||
|
||||
from modules.api import api
|
||||
|
|
@ -207,6 +207,10 @@ class ControlNetUnit:
|
|||
# The effective region mask that unit's effect should be restricted to.
|
||||
effective_region_mask: Optional[np.ndarray] = None
|
||||
|
||||
# The weight mode for PuLID.
|
||||
# https://github.com/ToTheBeginning/PuLID
|
||||
pulid_mode: PuLIDMode = PuLIDMode.FIDELITY
|
||||
|
||||
# The tensor input for ipadapter. When this field is set in the API,
|
||||
# the base64string will be interpret by torch.load to reconstruct ipadapter
|
||||
# preprocessor output.
|
||||
|
|
@ -243,6 +247,7 @@ class ControlNetUnit:
|
|||
# provide much information when restoring the unit.
|
||||
"inpaint_crop_input_image",
|
||||
"effective_region_mask",
|
||||
"pulid_mode",
|
||||
]
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -6,3 +6,4 @@ addict
|
|||
yapf
|
||||
albumentations==1.4.3
|
||||
matplotlib
|
||||
facexlib
|
||||
|
|
|
|||
|
|
@ -179,7 +179,7 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
|
|||
low_vram=low_vram,
|
||||
)
|
||||
if preprocessor.returns_image:
|
||||
images.append(encode_to_base64(result.display_image))
|
||||
images.append(encode_to_base64(result.display_images[0]))
|
||||
else:
|
||||
tensors.append(encode_tensor_to_base64(result.value))
|
||||
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@ from einops import rearrange
|
|||
import scripts.preprocessor as preprocessor_init # noqa
|
||||
from annotator.util import HWC3
|
||||
from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils
|
||||
from internal_controlnet.external_code import ControlMode
|
||||
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
|
||||
from scripts.controlnet_lllite import clear_all_lllite
|
||||
from scripts.ipadapter.plugable_ipadapter import ImageEmbed, clear_all_ip_adapter
|
||||
from scripts.ipadapter.pulid_attn import PULID_SETTING_FIDELITY, PULID_SETTING_STYLE
|
||||
from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent
|
||||
from scripts.hook import ControlParams, UnetHook, HackedImageRNG
|
||||
from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption
|
||||
from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption, PuLIDMode
|
||||
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
|
||||
from scripts.controlnet_ui.photopea import Photopea
|
||||
from scripts.logging import logger
|
||||
|
|
@ -279,6 +281,7 @@ def get_control(
|
|||
)
|
||||
detected_map = result.value
|
||||
is_image = preprocessor.returns_image
|
||||
# TODO: Refactor img control detection logic.
|
||||
if high_res_fix:
|
||||
if is_image:
|
||||
hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
|
||||
|
|
@ -293,7 +296,8 @@ def get_control(
|
|||
store_detected_map(detected_map, unit.module)
|
||||
else:
|
||||
control = detected_map
|
||||
store_detected_map(input_image, unit.module)
|
||||
for image in result.display_images:
|
||||
store_detected_map(image, unit.module)
|
||||
|
||||
if control_model_type == ControlModelType.T2I_StyleAdapter:
|
||||
control = control['last_hidden_state']
|
||||
|
|
@ -1092,8 +1096,8 @@ class Script(scripts.Script, metaclass=(
|
|||
global_average_pooling=global_average_pooling,
|
||||
hr_hint_cond=hr_control,
|
||||
hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix else HiResFixOption.BOTH,
|
||||
soft_injection=control_mode != external_code.ControlMode.BALANCED,
|
||||
cfg_injection=control_mode == external_code.ControlMode.CONTROL,
|
||||
soft_injection=control_mode != ControlMode.BALANCED,
|
||||
cfg_injection=control_mode == ControlMode.CONTROL,
|
||||
effective_region_mask=(
|
||||
get_pytorch_control(unit.effective_region_mask)[:, 0:1, :, :]
|
||||
if unit.effective_region_mask is not None
|
||||
|
|
@ -1190,7 +1194,7 @@ class Script(scripts.Script, metaclass=(
|
|||
|
||||
is_low_vram = any(unit.low_vram for unit in self.enabled_units)
|
||||
|
||||
for i, param in enumerate(forward_params):
|
||||
for i, (param, unit) in enumerate(zip(forward_params, self.enabled_units)):
|
||||
if param.control_model_type == ControlModelType.IPAdapter:
|
||||
if param.advanced_weighting is not None:
|
||||
logger.info(f"IP-Adapter using advanced weighting {param.advanced_weighting}")
|
||||
|
|
@ -1205,6 +1209,13 @@ class Script(scripts.Script, metaclass=(
|
|||
weight = param.weight
|
||||
|
||||
h, w, hr_y, hr_x = Script.get_target_dimensions(p)
|
||||
pulid_mode = PuLIDMode(unit.pulid_mode)
|
||||
if pulid_mode == PuLIDMode.STYLE:
|
||||
pulid_attn_setting = PULID_SETTING_STYLE
|
||||
else:
|
||||
assert pulid_mode == PuLIDMode.FIDELITY
|
||||
pulid_attn_setting = PULID_SETTING_FIDELITY
|
||||
|
||||
param.control_model.hook(
|
||||
model=unet,
|
||||
preprocessor_outputs=param.hint_cond,
|
||||
|
|
@ -1215,6 +1226,7 @@ class Script(scripts.Script, metaclass=(
|
|||
latent_width=w // 8,
|
||||
latent_height=h // 8,
|
||||
effective_region_mask=param.effective_region_mask,
|
||||
pulid_attn_setting=pulid_attn_setting,
|
||||
)
|
||||
if param.control_model_type == ControlModelType.Controlllite:
|
||||
param.control_model.hook(
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from scripts.controlnet_ui.openpose_editor import OpenposeEditor
|
|||
from scripts.controlnet_ui.preset import ControlNetPresetUI
|
||||
from scripts.controlnet_ui.photopea import Photopea
|
||||
from scripts.controlnet_ui.advanced_weight_control import AdvancedWeightControl
|
||||
from scripts.enums import InputMode
|
||||
from scripts.enums import InputMode, PuLIDMode
|
||||
from modules import shared
|
||||
from modules.ui_components import FormRow, FormHTML, ToolButton
|
||||
|
||||
|
|
@ -287,6 +287,7 @@ class ControlNetUiGroup(object):
|
|||
self.batch_image_dir_state = None
|
||||
self.output_dir_state = None
|
||||
self.advanced_weighting = gr.State(None)
|
||||
self.pulid_mode = None
|
||||
|
||||
# API-only fields
|
||||
self.ipadapter_input = gr.State(None)
|
||||
|
|
@ -626,6 +627,15 @@ class ControlNetUiGroup(object):
|
|||
visible=False,
|
||||
)
|
||||
|
||||
self.pulid_mode = gr.Radio(
|
||||
choices=[e.value for e in PuLIDMode],
|
||||
value=self.default_unit.pulid_mode.value,
|
||||
label="PuLID Mode",
|
||||
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_pulid_mode_radio",
|
||||
elem_classes="controlnet_pulid_mode_radio",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
self.loopback = gr.Checkbox(
|
||||
label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation",
|
||||
value=self.default_unit.loopback,
|
||||
|
|
@ -673,6 +683,7 @@ class ControlNetUiGroup(object):
|
|||
self.save_detected_map,
|
||||
self.advanced_weighting,
|
||||
self.effective_region_mask,
|
||||
self.pulid_mode,
|
||||
)
|
||||
|
||||
unit = gr.State(self.default_unit)
|
||||
|
|
@ -947,7 +958,7 @@ class ControlNetUiGroup(object):
|
|||
|
||||
return (
|
||||
# Update to `generated_image`
|
||||
gr.update(value=result.display_image, visible=True, interactive=False),
|
||||
gr.update(value=result.display_images[0], visible=True, interactive=False),
|
||||
# preprocessor_preview
|
||||
gr.update(value=True),
|
||||
# openpose editor
|
||||
|
|
@ -1118,6 +1129,14 @@ class ControlNetUiGroup(object):
|
|||
show_progress=False,
|
||||
)
|
||||
|
||||
def register_shift_pulid_mode(self):
|
||||
self.model.change(
|
||||
fn=lambda model: gr.update(visible="pulid" in model.lower()),
|
||||
inputs=[self.model],
|
||||
outputs=[self.pulid_mode],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
def register_sync_batch_dir(self):
|
||||
def determine_batch_dir(batch_dir, fallback_dir, fallback_fallback_dir):
|
||||
if batch_dir:
|
||||
|
|
@ -1220,6 +1239,7 @@ class ControlNetUiGroup(object):
|
|||
self.register_build_sliders()
|
||||
self.register_shift_preview()
|
||||
self.register_shift_upload_mask()
|
||||
self.register_shift_pulid_mode()
|
||||
self.register_create_canvas()
|
||||
self.register_clear_preview()
|
||||
self.register_multi_images_upload()
|
||||
|
|
|
|||
|
|
@ -247,3 +247,8 @@ class InputMode(Enum):
|
|||
# Input is a directory. 1 generation. Each generation takes N input image
|
||||
# from the directory.
|
||||
MERGE = "merge"
|
||||
|
||||
|
||||
class PuLIDMode(Enum):
    """Weight mode for PuLID (https://github.com/ToTheBeginning/PuLID)."""

    # Prioritize identity fidelity of the reference face.
    FIDELITY = "Fidelity"
    # Prioritize style transfer over identity fidelity.
    STYLE = "Extremely style"
|
||||
|
|
|
|||
|
|
@ -269,3 +269,65 @@ class Resampler(nn.Module):
|
|||
|
||||
latents = self.proj_out(latents)
|
||||
return self.norm_out(latents)
|
||||
|
||||
|
||||
class PuLIDEncoder(nn.Module):
    """Project PuLID ID embeddings into context tokens.

    ``forward`` maps a flat ID embedding to ``num_token`` tokens of size
    ``context_dim`` and appends one projected token per ViT hidden state.
    The attribute names (``body``, ``mapping_{i}``, ``mapping_patch_{i}``)
    are part of the checkpoint layout and must not be renamed.
    """

    def __init__(self, width=1280, context_dim=2048, num_token=5):
        super().__init__()
        self.num_token = num_token
        self.context_dim = context_dim
        hidden1 = min((context_dim * num_token) // 4, 1024)
        hidden2 = min((context_dim * num_token) // 2, 1024)
        self.body = nn.Sequential(
            nn.Linear(width, hidden1),
            nn.LayerNorm(hidden1),
            nn.LeakyReLU(),
            nn.Linear(hidden1, hidden2),
            nn.LayerNorm(hidden2),
            nn.LeakyReLU(),
            nn.Linear(hidden2, context_dim * num_token),
        )

        def make_mapping():
            # Same MLP shape is used for both the CLS-token and the
            # patch-token projections of every ViT layer.
            return nn.Sequential(
                nn.Linear(1024, 1024),
                nn.LayerNorm(1024),
                nn.LeakyReLU(),
                nn.Linear(1024, 1024),
                nn.LayerNorm(1024),
                nn.LeakyReLU(),
                nn.Linear(1024, context_dim),
            )

        # Keep the flat attribute naming scheme expected by the state dict;
        # creation order (mapping_i then mapping_patch_i) matches the
        # original so parameter initialization order is unchanged.
        for idx in range(5):
            setattr(self, f"mapping_{idx}", make_mapping())
            setattr(self, f"mapping_patch_{idx}", make_mapping())

    def forward(self, x, y):
        """x: [N, width] ID embedding; y: 5 ViT hidden states, each
        [N, 1 + num_patches, 1024] — assumed layout, confirm with caller."""
        tokens = self.body(x).reshape(-1, self.num_token, self.context_dim)

        projected = []
        for idx, emb in enumerate(y):
            cls_part = getattr(self, f"mapping_{idx}")(emb[:, :1])
            patch_part = getattr(self, f"mapping_patch_{idx}")(emb[:, 1:]).mean(
                dim=1, keepdim=True
            )
            projected.append(cls_part + patch_part)

        return torch.cat([tokens] + projected, dim=1)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from .image_proj_models import (
|
|||
MLPProjModel,
|
||||
MLPProjModelFaceId,
|
||||
ProjModelFaceIdPlus,
|
||||
PuLIDEncoder,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -71,6 +72,7 @@ class IPAdapterModel(torch.nn.Module):
|
|||
is_faceid: bool,
|
||||
is_portrait: bool,
|
||||
is_instantid: bool,
|
||||
is_pulid: bool,
|
||||
is_v2: bool,
|
||||
):
|
||||
super().__init__()
|
||||
|
|
@ -85,9 +87,12 @@ class IPAdapterModel(torch.nn.Module):
|
|||
self.is_v2 = is_v2
|
||||
self.is_faceid = is_faceid
|
||||
self.is_instantid = is_instantid
|
||||
self.is_pulid = is_pulid
|
||||
self.clip_extra_context_tokens = 16 if (self.is_plus or is_portrait) else 4
|
||||
|
||||
if is_instantid:
|
||||
if self.is_pulid:
|
||||
self.image_proj_model = PuLIDEncoder()
|
||||
elif self.is_instantid:
|
||||
self.image_proj_model = self.init_proj_instantid()
|
||||
elif is_faceid:
|
||||
self.image_proj_model = self.init_proj_faceid()
|
||||
|
|
@ -235,6 +240,34 @@ class IPAdapterModel(torch.nn.Module):
|
|||
self.image_proj_model(torch.zeros_like(prompt_image_emb)),
|
||||
)
|
||||
|
||||
def _get_image_embeds_pulid(self, pulid_proj_input) -> ImageEmbed:
|
||||
"""Get image embeds for pulid."""
|
||||
id_cond = torch.cat(
|
||||
[
|
||||
pulid_proj_input.id_ante_embedding.to(
|
||||
device=self.device, dtype=torch.float32
|
||||
),
|
||||
pulid_proj_input.id_cond_vit.to(
|
||||
device=self.device, dtype=torch.float32
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
id_vit_hidden = [
|
||||
t.to(device=self.device, dtype=torch.float32)
|
||||
for t in pulid_proj_input.id_vit_hidden
|
||||
]
|
||||
return ImageEmbed(
|
||||
self.image_proj_model(
|
||||
id_cond,
|
||||
id_vit_hidden,
|
||||
),
|
||||
self.image_proj_model(
|
||||
torch.zeros_like(id_cond),
|
||||
[torch.zeros_like(t) for t in id_vit_hidden],
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load(state_dict: dict, model_name: str) -> IPAdapterModel:
|
||||
"""
|
||||
|
|
@ -245,6 +278,7 @@ class IPAdapterModel(torch.nn.Module):
|
|||
is_v2 = "v2" in model_name
|
||||
is_faceid = "faceid" in model_name
|
||||
is_instantid = "instant_id" in model_name
|
||||
is_pulid = "pulid" in model_name.lower()
|
||||
is_portrait = "portrait" in model_name
|
||||
is_full = "proj.3.weight" in state_dict["image_proj"]
|
||||
is_plus = (
|
||||
|
|
@ -256,8 +290,8 @@ class IPAdapterModel(torch.nn.Module):
|
|||
sdxl = cross_attention_dim == 2048
|
||||
sdxl_plus = sdxl and is_plus
|
||||
|
||||
if is_instantid:
|
||||
# InstantID does not use clip embedding.
|
||||
if is_instantid or is_pulid:
|
||||
# InstantID/PuLID does not use clip embedding.
|
||||
clip_embeddings_dim = None
|
||||
elif is_faceid:
|
||||
if is_plus:
|
||||
|
|
@ -291,10 +325,13 @@ class IPAdapterModel(torch.nn.Module):
|
|||
is_portrait=is_portrait,
|
||||
is_instantid=is_instantid,
|
||||
is_v2=is_v2,
|
||||
is_pulid=is_pulid,
|
||||
)
|
||||
|
||||
def get_image_emb(self, preprocessor_output) -> ImageEmbed:
|
||||
if self.is_instantid:
|
||||
if self.is_pulid:
|
||||
return self._get_image_embeds_pulid(preprocessor_output)
|
||||
elif self.is_instantid:
|
||||
return self._get_image_embeds_instantid(preprocessor_output)
|
||||
elif self.is_faceid and self.is_plus:
|
||||
# Note: FaceID plus uses both face_embed and clip_embed.
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import itertools
|
||||
import torch
|
||||
import math
|
||||
from typing import Union, Dict, Optional
|
||||
from typing import Union, Dict, Optional, Callable
|
||||
|
||||
from .pulid_attn import PuLIDAttnSetting
|
||||
from .ipadapter_model import ImageEmbed, IPAdapterModel
|
||||
from ..enums import StableDiffusionVersion, TransformerID
|
||||
|
||||
|
|
@ -93,7 +94,7 @@ def clear_all_ip_adapter():
|
|||
class PlugableIPAdapter(torch.nn.Module):
|
||||
def __init__(self, ipadapter: IPAdapterModel):
|
||||
super().__init__()
|
||||
self.ipadapter = ipadapter
|
||||
self.ipadapter: IPAdapterModel = ipadapter
|
||||
self.disable_memory_management = True
|
||||
self.dtype = None
|
||||
self.weight: Union[float, Dict[int, float]] = 1.0
|
||||
|
|
@ -103,6 +104,7 @@ class PlugableIPAdapter(torch.nn.Module):
|
|||
self.latent_width: int = 0
|
||||
self.latent_height: int = 0
|
||||
self.effective_region_mask = None
|
||||
self.pulid_attn_setting: Optional[PuLIDAttnSetting] = None
|
||||
|
||||
def reset(self):
|
||||
self.cache = {}
|
||||
|
|
@ -118,6 +120,7 @@ class PlugableIPAdapter(torch.nn.Module):
|
|||
latent_width: int,
|
||||
latent_height: int,
|
||||
effective_region_mask: Optional[torch.Tensor],
|
||||
pulid_attn_setting: Optional[PuLIDAttnSetting] = None,
|
||||
dtype=torch.float32,
|
||||
):
|
||||
global current_model
|
||||
|
|
@ -128,6 +131,7 @@ class PlugableIPAdapter(torch.nn.Module):
|
|||
self.latent_width = latent_width
|
||||
self.latent_height = latent_height
|
||||
self.effective_region_mask = effective_region_mask
|
||||
self.pulid_attn_setting = pulid_attn_setting
|
||||
|
||||
self.cache = {}
|
||||
|
||||
|
|
@ -186,7 +190,9 @@ class PlugableIPAdapter(torch.nn.Module):
|
|||
# sequence_length = (latent_height * factor) * (latent_height * factor)
|
||||
# sequence_length = (latent_height * latent_height) * factor ^ 2
|
||||
factor = math.sqrt(sequence_length / (self.latent_width * self.latent_height))
|
||||
assert factor > 0, f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}"
|
||||
assert (
|
||||
factor > 0
|
||||
), f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}"
|
||||
mask_h = int(self.latent_height * factor)
|
||||
mask_w = int(self.latent_width * factor)
|
||||
|
||||
|
|
@ -199,6 +205,71 @@ class PlugableIPAdapter(torch.nn.Module):
|
|||
mask = mask.view(mask.shape[0], -1, 1).repeat(1, 1, out.shape[2])
|
||||
return out * mask
|
||||
|
||||
def attn_eval(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
cond_uncond_image_emb: torch.Tensor,
|
||||
attn_heads: int,
|
||||
head_dim: int,
|
||||
emb_to_k: Callable[[torch.Tensor], torch.Tensor],
|
||||
emb_to_v: Callable[[torch.Tensor], torch.Tensor],
|
||||
):
|
||||
if self.ipadapter.is_pulid:
|
||||
assert self.pulid_attn_setting is not None
|
||||
return self.pulid_attn_setting.eval(
|
||||
hidden_states,
|
||||
query,
|
||||
cond_uncond_image_emb,
|
||||
attn_heads,
|
||||
head_dim,
|
||||
emb_to_k,
|
||||
emb_to_v,
|
||||
)
|
||||
else:
|
||||
return self._attn_eval_ipadapter(
|
||||
hidden_states,
|
||||
query,
|
||||
cond_uncond_image_emb,
|
||||
attn_heads,
|
||||
head_dim,
|
||||
emb_to_k,
|
||||
emb_to_v,
|
||||
)
|
||||
|
||||
def _attn_eval_ipadapter(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
cond_uncond_image_emb: torch.Tensor,
|
||||
attn_heads: int,
|
||||
head_dim: int,
|
||||
emb_to_k: Callable[[torch.Tensor], torch.Tensor],
|
||||
emb_to_v: Callable[[torch.Tensor], torch.Tensor],
|
||||
):
|
||||
assert hidden_states.ndim == 3
|
||||
batch_size, sequence_length, inner_dim = hidden_states.shape
|
||||
ip_k = emb_to_k(cond_uncond_image_emb)
|
||||
ip_v = emb_to_v(cond_uncond_image_emb)
|
||||
|
||||
ip_k, ip_v = map(
|
||||
lambda t: t.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2),
|
||||
(ip_k, ip_v),
|
||||
)
|
||||
assert ip_k.dtype == ip_v.dtype
|
||||
|
||||
# On MacOS, q can be float16 instead of float32.
|
||||
# https://github.com/Mikubill/sd-webui-controlnet/issues/2208
|
||||
if query.dtype != ip_k.dtype:
|
||||
ip_k = ip_k.to(dtype=query.dtype)
|
||||
ip_v = ip_v.to(dtype=query.dtype)
|
||||
|
||||
ip_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
|
||||
return ip_out
|
||||
|
||||
@torch.no_grad()
|
||||
def patch_forward(self, number: int, transformer_index: int):
|
||||
@torch.no_grad()
|
||||
|
|
@ -220,27 +291,15 @@ class PlugableIPAdapter(torch.nn.Module):
|
|||
|
||||
k_key = f"{number * 2 + 1}_to_k_ip"
|
||||
v_key = f"{number * 2 + 1}_to_v_ip"
|
||||
cond_uncond_image_emb = self.image_emb.eval(current_model.cond_mark)
|
||||
ip_k = self.call_ip(k_key, cond_uncond_image_emb, device=q.device)
|
||||
ip_v = self.call_ip(v_key, cond_uncond_image_emb, device=q.device)
|
||||
|
||||
ip_k, ip_v = map(
|
||||
lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2),
|
||||
(ip_k, ip_v),
|
||||
ip_out = self.attn_eval(
|
||||
hidden_states=x,
|
||||
query=q,
|
||||
cond_uncond_image_emb=self.image_emb.eval(current_model.cond_mark),
|
||||
attn_heads=h,
|
||||
head_dim=head_dim,
|
||||
emb_to_k=lambda emb: self.call_ip(k_key, emb, device=q.device),
|
||||
emb_to_v=lambda emb: self.call_ip(v_key, emb, device=q.device),
|
||||
)
|
||||
assert ip_k.dtype == ip_v.dtype
|
||||
|
||||
# On MacOS, q can be float16 instead of float32.
|
||||
# https://github.com/Mikubill/sd-webui-controlnet/issues/2208
|
||||
if q.dtype != ip_k.dtype:
|
||||
ip_k = ip_k.to(dtype=q.dtype)
|
||||
ip_v = ip_v.to(dtype=q.dtype)
|
||||
|
||||
ip_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
|
||||
|
||||
return self.apply_effective_region_mask(ip_out * weight)
|
||||
|
||||
return forward
|
||||
|
|
|
|||
|
|
@ -166,6 +166,12 @@ ipadapter_presets: List[IPAdapterPreset] = [
|
|||
model="ip-adapter-faceid-portrait_sdxl",
|
||||
sd_version=StableDiffusionVersion.SDXL,
|
||||
),
|
||||
IPAdapterPreset(
|
||||
name="pulid",
|
||||
module="ip-adapter_pulid",
|
||||
model="ip-adapter_pulid_sdxl_fp16",
|
||||
sd_version=StableDiffusionVersion.SDXL,
|
||||
),
|
||||
]
|
||||
|
||||
_preset_by_model = {p.model: p for p in ipadapter_presets}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,94 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
|
||||
@dataclass
class PuLIDAttnSetting:
    """Configuration of PuLID's ID-embedding cross-attention.

    num_zero: number of all-zero "null" tokens appended to the ID embedding
        before computing keys/values (lets attention dilute the ID signal).
    ortho: apply the v1 orthogonalization (subtract the projection of the
        ID output onto the hidden states).
    ortho_v2: apply the v2 orthogonalization (projection scaled by the mean
        attention mass on the leading ID tokens). Mutually exclusive with
        ``ortho`` in the shipped presets.
    """

    num_zero: int = 0
    ortho: bool = False
    ortho_v2: bool = False

    def eval(
        self,
        hidden_states: torch.Tensor,
        query: torch.Tensor,
        id_embedding: torch.Tensor,
        attn_heads: int,
        head_dim: int,
        id_to_k: Callable[[torch.Tensor], torch.Tensor],
        id_to_v: Callable[[torch.Tensor], torch.Tensor],
    ):
        """Compute the ID cross-attention contribution.

        hidden_states: [batch, seq_len, inner_dim] current attention output.
        query: attention query, already split into heads.
        id_embedding: ID embedding from which keys/values are derived via
            the ``id_to_k`` / ``id_to_v`` projection callables.
        Returns a tensor of shape [batch, seq_len, attn_heads * head_dim]
        in ``hidden_states``' dtype (query dtype when no orthogonalization).
        """
        assert hidden_states.ndim == 3
        batch_size, sequence_length, inner_dim = hidden_states.shape

        if self.num_zero == 0:
            id_key = id_to_k(id_embedding).to(query.dtype)
            id_value = id_to_v(id_embedding).to(query.dtype)
        else:
            # Append num_zero null tokens so the softmax can "look away"
            # from the ID tokens, weakening their influence.
            zero_tensor = torch.zeros(
                (id_embedding.size(0), self.num_zero, id_embedding.size(-1)),
                dtype=id_embedding.dtype,
                device=id_embedding.device,
            )
            id_key = id_to_k(torch.cat((id_embedding, zero_tensor), dim=1)).to(
                query.dtype
            )
            id_value = id_to_v(torch.cat((id_embedding, zero_tensor), dim=1)).to(
                query.dtype
            )

        # [B, tokens, inner] -> [B, heads, tokens, head_dim]
        id_key = id_key.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
        id_value = id_value.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        id_hidden_states = F.scaled_dot_product_attention(
            query, id_key, id_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        id_hidden_states = id_hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn_heads * head_dim
        )
        id_hidden_states = id_hidden_states.to(query.dtype)

        if not self.ortho and not self.ortho_v2:
            return id_hidden_states
        elif self.ortho_v2:
            # v2: scale the projection of the ID output onto hidden_states by
            # the mean attention mass on the first 5 ID tokens (presumably
            # the num_token ID tokens of the PuLID encoder — confirm).
            orig_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            id_hidden_states = id_hidden_states.to(torch.float32)
            attn_map = query @ id_key.transpose(-2, -1)
            attn_mean = attn_map.softmax(dim=-1).mean(dim=1)
            attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)
            projection = (
                torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
                / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
                * hidden_states
            )
            orthogonal = id_hidden_states + (attn_mean - 1) * projection
            return orthogonal.to(orig_dtype)
        else:
            # v1: remove the component of the ID output parallel to
            # hidden_states, keeping only the orthogonal residual.
            orig_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            id_hidden_states = id_hidden_states.to(torch.float32)
            projection = (
                torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
                / torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
                * hidden_states
            )
            orthogonal = id_hidden_states - projection
            return orthogonal.to(orig_dtype)
|
||||
|
||||
|
||||
# Preset for PuLIDMode.FIDELITY: moderate null-token dilution with the
# attention-weighted (v2) orthogonalization.
PULID_SETTING_FIDELITY = PuLIDAttnSetting(
    num_zero=8,
    ortho=False,
    ortho_v2=True,
)

# Preset for PuLIDMode.STYLE ("Extremely style"): stronger dilution with
# the plain (v1) orthogonalization.
PULID_SETTING_STYLE = PuLIDAttnSetting(
    num_zero=16,
    ortho=True,
    ortho_v2=False,
)
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from .pulid import *
|
||||
from .inpaint import *
|
||||
from .lama_inpaint import *
|
||||
from .ip_adapter_auto import *
|
||||
|
|
|
|||
|
|
@ -1,18 +1,7 @@
|
|||
import numpy as np
|
||||
|
||||
from scripts.utils import visualize_inpaint_mask
|
||||
from ..supported_preprocessor import Preprocessor, PreprocessorParameter
|
||||
|
||||
|
||||
def visualize_inpaint_mask(img):
|
||||
if img.ndim == 3 and img.shape[2] == 4:
|
||||
result = img.copy()
|
||||
mask = result[:, :, 3]
|
||||
mask = 255 - mask // 2
|
||||
result[:, :, 3] = mask
|
||||
return np.ascontiguousarray(result.copy())
|
||||
return img
|
||||
|
||||
|
||||
class PreprocessorInpaint(Preprocessor):
|
||||
def __init__(self):
|
||||
super().__init__(name="inpaint")
|
||||
|
|
@ -23,9 +12,6 @@ class PreprocessorInpaint(Preprocessor):
|
|||
self.accepts_mask = True
|
||||
self.requires_mask = True
|
||||
|
||||
def get_display_image(self, input_image: np.ndarray, result):
|
||||
return visualize_inpaint_mask(result)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_image,
|
||||
|
|
@ -35,7 +21,10 @@ class PreprocessorInpaint(Preprocessor):
|
|||
slider_3=None,
|
||||
**kwargs
|
||||
):
|
||||
return input_image
|
||||
return Preprocessor.Result(
|
||||
value=input_image,
|
||||
display_images=visualize_inpaint_mask(input_image)[None, :, :, :],
|
||||
)
|
||||
|
||||
|
||||
class PreprocessorInpaintOnly(Preprocessor):
|
||||
|
|
@ -47,9 +36,6 @@ class PreprocessorInpaintOnly(Preprocessor):
|
|||
self.accepts_mask = True
|
||||
self.requires_mask = True
|
||||
|
||||
def get_display_image(self, input_image: np.ndarray, result):
|
||||
return visualize_inpaint_mask(result)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_image,
|
||||
|
|
@ -59,7 +45,10 @@ class PreprocessorInpaintOnly(Preprocessor):
|
|||
slider_3=None,
|
||||
**kwargs
|
||||
):
|
||||
return input_image
|
||||
return Preprocessor.Result(
|
||||
value=input_image,
|
||||
display_images=visualize_inpaint_mask(input_image)[None, :, :, :],
|
||||
)
|
||||
|
||||
|
||||
Preprocessor.add_supported_preprocessor(PreprocessorInpaint())
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import cv2
|
|||
import numpy as np
|
||||
|
||||
from ..supported_preprocessor import Preprocessor, PreprocessorParameter
|
||||
from ..utils import resize_image_with_pad
|
||||
from ..utils import resize_image_with_pad, visualize_inpaint_mask
|
||||
|
||||
|
||||
class PreprocessorLamaInpaint(Preprocessor):
|
||||
|
|
@ -15,12 +15,6 @@ class PreprocessorLamaInpaint(Preprocessor):
|
|||
self.accepts_mask = True
|
||||
self.requires_mask = True
|
||||
|
||||
def get_display_image(self, input_image: np.ndarray, result: np.ndarray):
|
||||
"""For lama inpaint, display image should not contain mask."""
|
||||
assert result.ndim == 3
|
||||
assert result.shape[2] == 4
|
||||
return result[:, :, :3]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_image,
|
||||
|
|
@ -56,7 +50,13 @@ class PreprocessorLamaInpaint(Preprocessor):
|
|||
fin_color = fin_color.clip(0, 255).astype(np.uint8)
|
||||
|
||||
result = np.concatenate([fin_color, raw_mask], axis=2)
|
||||
return result
|
||||
return Preprocessor.Result(
|
||||
value=result,
|
||||
display_images=[
|
||||
result[:, :, :3],
|
||||
visualize_inpaint_mask(result),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
Preprocessor.add_supported_preprocessor(PreprocessorLamaInpaint())
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ class LegacyPreprocessor(Preprocessor):
|
|||
def __call__(
|
||||
self,
|
||||
input_image,
|
||||
resolution,
|
||||
resolution=512,
|
||||
slider_1=None,
|
||||
slider_2=None,
|
||||
slider_3=None,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,169 @@
|
|||
# https://github.com/ToTheBeginning/PuLID
|
||||
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass
|
||||
from facexlib.parsing import init_parsing_model
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from torchvision.transforms.functional import normalize
|
||||
|
||||
from ..supported_preprocessor import Preprocessor, PreprocessorParameter
|
||||
from scripts.utils import npimg2tensor, tensor2npimg, resize_image_with_pad
|
||||
|
||||
|
||||
def to_gray(img):
    """Collapse an NCHW RGB tensor to 3-channel grayscale (BT.601 luma weights)."""
    red, green, blue = img[:, 0:1], img[:, 1:2], img[:, 2:3]
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    return luma.repeat(1, 3, 1, 1)
|
||||
|
||||
|
||||
class PreprocessorFaceXLib(Preprocessor):
    """Face alignment/parsing preprocessor backed by facexlib.

    Produces a grayscale face-feature image with the background masked to
    white; PuLID feeds this image into EVA-CLIP.
    """

    def __init__(self):
        super().__init__(name="facexlib")
        self.tags = []
        self.slider_resolution = PreprocessorParameter(visible=False)
        # Lazily-initialized facexlib helper (face detector + parser).
        self.model: Optional[FaceRestoreHelper] = None

    def load_model(self):
        """Create the facexlib helper on first use and move it to self.device.

        The .to(device) calls are outside the lazy-init guard on purpose, so
        a helper previously offloaded by unload() is moved back to the device.
        """
        if self.model is None:
            self.model = FaceRestoreHelper(
                upscale_factor=1,
                face_size=512,
                crop_ratio=(1, 1),
                det_model="retinaface_resnet50",
                save_ext="png",
                device=self.device,
            )
            self.model.face_parse = init_parsing_model(
                model_name="bisenet", device=self.device
            )
        self.model.face_parse.to(device=self.device)
        self.model.face_det.to(device=self.device)
        return self.model

    def unload(self) -> bool:
        """@Override

        Offload the facexlib submodules to CPU. Returns True if a model was
        loaded (and therefore moved), False otherwise.
        """
        if self.model is not None:
            self.model.face_parse.to(device="cpu")
            self.model.face_det.to(device="cpu")
            return True
        return False

    def __call__(
        self,
        input_image,
        resolution=512,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        input_mask=None,
        return_tensor=False,
        **kwargs
    ):
        """
        @Override
        Returns black and white face features image with background removed.

        Raises RuntimeError when facexlib cannot align a face. When
        ``return_tensor`` is True the result is returned as a tensor instead
        of a numpy image.
        """
        self.load_model()
        self.model.clean_all()
        # Resize before detection to bound facexlib's work.
        input_image, _ = resize_image_with_pad(input_image, resolution)
        # using facexlib to detect and align face
        image_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
        self.model.read_image(image_bgr)
        self.model.get_face_landmarks_5(only_center_face=True)
        self.model.align_warp_face()
        if len(self.model.cropped_faces) == 0:
            raise RuntimeError("facexlib align face fail")
        align_face = self.model.cropped_faces[0]
        align_face_rgb = cv2.cvtColor(align_face, cv2.COLOR_BGR2RGB)
        # Renamed from `input` to avoid shadowing the builtin.
        face_tensor = npimg2tensor(align_face_rgb)
        face_tensor = face_tensor.to(self.device)
        parsing_out = self.model.face_parse(
            normalize(face_tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        )[0]
        parsing_out = parsing_out.argmax(dim=1, keepdim=True)
        # bisenet class ids treated as background — TODO confirm label map.
        bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
        bg = sum(parsing_out == i for i in bg_label).bool()
        white_image = torch.ones_like(face_tensor)
        # only keep the face features
        face_features_image = torch.where(bg, white_image, to_gray(face_tensor))
        if return_tensor:
            return face_features_image
        else:
            return tensor2npimg(face_features_image)
|
||||
|
||||
|
||||
@dataclass
class PuLIDProjInput:
    """Bundle of embeddings passed from PreprocessorPuLID to the PuLID encoder."""

    # Face-ID embedding produced by the insightface antelopev2 preprocessor.
    id_ante_embedding: torch.Tensor
    # EVA-CLIP CLS embedding of the aligned face image.
    id_cond_vit: torch.Tensor
    # EVA-CLIP ViT hidden states (one tensor per tapped layer).
    id_vit_hidden: List[torch.Tensor]
|
||||
|
||||
|
||||
class PreprocessorPuLID(Preprocessor):
    """PuLID preprocessor.

    Composes three dependency preprocessors: insightface antelopev2 for the
    face-ID embedding, facexlib for the aligned face-feature image, and
    EVA02-CLIP for the ViT embeddings consumed by the PuLID encoder.
    """

    def __init__(self):
        super().__init__(name="ip-adapter_pulid")
        self.tags = ["IP-Adapter"]
        self.slider_resolution = PreprocessorParameter(visible=False)
        self.returns_image = False
        # Preprocessors this one delegates to; also the targets of unload().
        self.preprocessors_deps = [
            "facexlib",
            "instant_id_face_embedding",
            "EVA02-CLIP-L-14-336",
        ]

    def facexlib_detect(self, input_image: np.ndarray) -> torch.Tensor:
        """Aligned, background-removed grayscale face image as a tensor."""
        facexlib_preprocessor = Preprocessor.get_preprocessor("facexlib")
        return facexlib_preprocessor(input_image, return_tensor=True)

    def insightface_antelopev2_detect(self, input_image: np.ndarray) -> torch.Tensor:
        """Face-ID embedding from the insightface antelopev2 preprocessor."""
        antelopev2_preprocessor = Preprocessor.get_preprocessor(
            "instant_id_face_embedding"
        )
        return antelopev2_preprocessor(input_image)

    def unload(self) -> bool:
        """Unload every dependency preprocessor; True if any reported unloading.

        Fix: the original `unloaded = unloaded or p.unload()` short-circuits
        once any dependency returns True, so the remaining dependencies were
        never unloaded. Call unload() unconditionally for each one.
        """
        unloaded = False
        for p_name in self.preprocessors_deps:
            p = Preprocessor.get_preprocessor(p_name)
            if p is not None:
                unloaded = p.unload() or unloaded
        return unloaded

    def __call__(
        self,
        input_image,
        resolution=512,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        input_mask=None,
        **kwargs
    ) -> Preprocessor.Result:
        """Compute the PuLID projection input for ``input_image``.

        ``resolution`` now defaults to 512 for consistency with the other
        preprocessors (the slider is hidden for this preprocessor anyway).
        Asserts if the EVA02-CLIP preprocessor extension is not installed.
        """
        id_ante_embedding = self.insightface_antelopev2_detect(input_image)
        if id_ante_embedding.ndim == 1:
            # Ensure a leading batch dimension.
            id_ante_embedding = id_ante_embedding.unsqueeze(0)

        face_features_image = self.facexlib_detect(input_image)
        evaclip_preprocessor = Preprocessor.get_preprocessor("EVA02-CLIP-L-14-336")
        assert (
            evaclip_preprocessor is not None
        ), "EVA02-CLIP-L-14-336 preprocessor not found! Please install sd-webui-controlnet-evaclip"
        r = evaclip_preprocessor(face_features_image)

        return Preprocessor.Result(
            value=PuLIDProjInput(
                id_ante_embedding=id_ante_embedding,
                id_cond_vit=r.id_cond_vit,
                id_vit_hidden=r.id_vit_hidden,
            ),
            display_images=[tensor2npimg(face_features_image)],
        )
|
||||
|
||||
|
||||
# Register the PuLID-related preprocessors in the global registry so they can
# be looked up by name via Preprocessor.get_preprocessor.
Preprocessor.add_supported_preprocessor(PreprocessorFaceXLib())
Preprocessor.add_supported_preprocessor(PreprocessorPuLID())
|
@ -4,7 +4,7 @@ from dataclasses import dataclass, field
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
from modules import shared, devices
|
||||
from scripts.logging import logger
|
||||
from scripts.utils import ndarray_lru_cache
|
||||
|
||||
|
|
@ -101,6 +101,7 @@ class Preprocessor(ABC):
|
|||
use_soft_projection_in_hr_fix = False
|
||||
expand_mask_when_resize_and_fill = False
|
||||
model: Optional[torch.nn.Module] = None
|
||||
device = devices.get_device_for("controlnet")
|
||||
|
||||
all_processors: ClassVar[Dict[str, "Preprocessor"]] = {}
|
||||
all_processors_by_name: ClassVar[Dict[str, "Preprocessor"]] = {}
|
||||
|
|
@ -183,18 +184,19 @@ class Preprocessor(ABC):
|
|||
|
||||
class Result(NamedTuple):
|
||||
value: Any
|
||||
# The display image shown on UI.
|
||||
display_image: np.ndarray
|
||||
|
||||
def get_display_image(self, input_image: np.ndarray, result):
|
||||
return result if self.returns_image else input_image
|
||||
# The display images shown on UI.
|
||||
display_images: List[np.ndarray]
|
||||
|
||||
def cached_call(self, input_image, *args, **kwargs) -> "Preprocessor.Result":
|
||||
"""The function exposed that also returns an image for display."""
|
||||
result = self._cached_call(input_image, *args, **kwargs)
|
||||
return Preprocessor.Result(
|
||||
value=result, display_image=self.get_display_image(input_image, result)
|
||||
)
|
||||
if isinstance(result, Preprocessor.Result):
|
||||
return result
|
||||
else:
|
||||
return Preprocessor.Result(
|
||||
value=result,
|
||||
display_images=[result if self.returns_image else input_image],
|
||||
)
|
||||
|
||||
@ndarray_lru_cache(max_size=CACHE_SIZE)
|
||||
def _cached_call(self, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from einops import rearrange
|
||||
import torch
|
||||
import os
|
||||
import functools
|
||||
|
|
@ -105,8 +106,9 @@ def timer_decorator(func):
|
|||
|
||||
|
||||
class TimeMeta(type):
|
||||
""" Metaclass to record execution time on all methods of the
|
||||
child class. """
|
||||
"""Metaclass to record execution time on all methods of the
|
||||
child class."""
|
||||
|
||||
def __new__(cls, name, bases, attrs):
|
||||
for attr_name, attr_value in attrs.items():
|
||||
if callable(attr_value):
|
||||
|
|
@ -161,7 +163,9 @@ def read_image(img_path: str) -> str:
|
|||
return encoded_image
|
||||
|
||||
|
||||
def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]:
|
||||
def read_image_dir(
|
||||
img_dir: str, suffixes=(".png", ".jpg", ".jpeg", ".webp")
|
||||
) -> List[str]:
|
||||
"""Try read all images in given img_dir."""
|
||||
images = []
|
||||
for filename in os.listdir(img_dir):
|
||||
|
|
@ -175,7 +179,7 @@ def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) ->
|
|||
|
||||
|
||||
def align_dim_latent(x: int) -> int:
    """Align the pixel dimension (w/h) to latent dimension.

    Stable diffusion uses a 1:8 ratio for latent/pixel, i.e.,
    1 latent unit == 8 pixel unit.
    """
    # Round down to the nearest multiple of 8 (fixes the duplicated/garbled
    # docstring lines left over in the source; logic unchanged).
    return (x // 8) * 8
|
@ -203,9 +207,34 @@ def resize_image_with_pad(img: np.ndarray, resolution: int):
|
|||
W_target = int(np.round(float(W_raw) * k))
|
||||
img = cv2.resize(img, (W_target, H_target), interpolation=interpolation)
|
||||
H_pad, W_pad = pad64(H_target), pad64(W_target)
|
||||
img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode='edge')
|
||||
img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode="edge")
|
||||
|
||||
def remove_pad(x):
|
||||
return safer_memory(x[:H_target, :W_target])
|
||||
|
||||
return safer_memory(img_padded), remove_pad
|
||||
return safer_memory(img_padded), remove_pad
|
||||
|
||||
|
||||
def npimg2tensor(img: np.ndarray) -> torch.Tensor:
    """Convert numpy img ([H, W, C]) to tensor ([1, C, H, W])"""
    # Scale uint8 [0, 255] to float [0, 1], move channels first, and add a
    # leading batch dimension.
    scaled = torch.from_numpy(img).float() / 255.0
    return scaled.permute(2, 0, 1).unsqueeze(0)
||||
def tensor2npimg(t: torch.Tensor) -> np.ndarray:
    """Convert tensor ([1, C, H, W]) to numpy RGB img ([H, W, C])"""
    # Drop the batch dim, move channels last, rescale to [0, 255] and
    # truncate to uint8 on the CPU.
    hwc = t.squeeze(0).permute(1, 2, 0)
    return (hwc * 255.0).to(dtype=torch.uint8).cpu().numpy()
||||
def visualize_inpaint_mask(img):
    """Remap the alpha channel of an RGBA image so the inpaint mask stays
    visible on display; non-RGBA images pass through unchanged."""
    # Only 3-D images with 4 channels carry an inpaint mask in the alpha plane.
    if img.ndim != 3 or img.shape[2] != 4:
        return img
    result = img.copy()
    alpha = result[:, :, 3]
    # alpha 0 -> 255 (fully opaque), alpha 255 -> 128 (half transparent).
    result[:, :, 3] = 255 - alpha // 2
    return np.ascontiguousarray(result.copy())
|
|
|
|||
Loading…
Reference in New Issue