from typing import Union
import os
import cv2
import insightface
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import normalize, resize
from basicsr.utils import img2tensor, tensor2img
from facexlib.parsing import init_parsing_model
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from insightface.app import FaceAnalysis
from eva_clip import create_model_and_transforms
from eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from encoders_transformer import IDFormer, IDEncoder
from modules.errors import log


# Debug logging goes to log.trace only when SD_PULID_DEBUG is set; otherwise debug() is a no-op.
debug = log.trace if os.environ.get('SD_PULID_DEBUG', None) is not None else lambda *args, **kwargs: None


class StableDiffusionXLPuLIDPipeline:
    """PuLID identity-conditioning wrapper around an SDXL diffusers pipeline.

    The wrapped pipeline's unet gets id-aware attention processors. An identity embedding from
    `get_id_embedding` is passed to them via `cross_attention_kwargs` while a sampler function
    (default `sample_dpmpp_sde`) runs over a karras sigma schedule computed here.
    """

    def __init__(
        self,
        pipe: Union[StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline],
        device: torch.device,
        dtype: torch.dtype = None,
        providers: list = None,
        offload: bool = True,
        sampler=None,
        cache_dir=None,
        sdp: bool = True,
        version: str = 'v1.1',
    ):
        """Wrap `pipe` and download/load every PuLID component.

        Args:
            pipe: base SDXL pipeline; its unet attention processors are replaced in place.
            device: device the PuLID models are placed on.
            dtype: model dtype, defaults to torch.float16.
            providers: onnxruntime execution providers for insightface; defaults to CUDA then CPU.
            offload: move face detector, id adapter and eva-clip back to cpu after each embedding.
            sampler: sampler callable; defaults to `sample_dpmpp_sde` from the bundled sampling module.
            cache_dir: root folder for downloaded models; it is passed to os.path.join so it must not be None.
            sdp: use the `*2_0` (scaled-dot-product) attention processors.
            version: 'v1.1' uses IDFormer + pulid_v1.1.safetensors, anything else IDEncoder + pulid_v1.bin.
        """
        super().__init__()
        self.device = device
        self.dtype = dtype or torch.float16
        self.pipe = pipe
        self.cache_dir = cache_dir
        self.offload = offload
        self.sdp = sdp
        self.version = version
        self.folder = 'models--ToTheBeginning--PuLID'
        debug(f'PulID init: device={self.device} dtype={self.dtype} dir={self.cache_dir} offload={self.offload} sdp={self.sdp} version={self.version}')

        # swap unet attention processors for id-aware ones
        self.hack_unet_attn_layers(self.pipe.unet)

        # id adapter architecture depends on the checkpoint version
        if self.version == 'v1.1':
            self.id_adapter = IDFormer().to(self.device, self.dtype)
        else:
            self.id_adapter = IDEncoder().to(self.device, self.dtype)
        debug(f'PulID load: adapter={self.id_adapter.__class__.__name__}')

        self.providers = providers or ['CUDAExecutionProvider', 'CPUExecutionProvider']
        debug(f'PulID load: providers={self.providers}')

        # preprocessors: face align and parsing
        self.face_helper = FaceRestoreHelper(
            upscale_factor=1,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            device=self.device,
        )
        self.face_helper.face_parse = init_parsing_model(model_name='bisenet', device=self.device)
        debug(f'PulID load: facehelper={self.face_helper.__class__.__name__}')

        # clip-vit backbone; fp16/bf16 dtypes load eva-clip in fp16, anything else in fp32
        eva_precision = 'fp16' if self.dtype == torch.float16 or self.dtype == torch.bfloat16 else 'fp32'
        eva_model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True, precision=eva_precision, device=self.device)
        self.clip_vision_model = eva_model.visual.to(dtype=self.dtype)
        debug(f'PulID load: evaclip={self.clip_vision_model.__class__.__name__} precision={eva_precision}')
        eva_transform_mean = getattr(self.clip_vision_model, 'image_mean', OPENAI_DATASET_MEAN)
        eva_transform_std = getattr(self.clip_vision_model, 'image_std', OPENAI_DATASET_STD)
        # a scalar mean/std is broadcast to three channels
        if not isinstance(eva_transform_mean, (list, tuple)):
            eva_transform_mean = (eva_transform_mean,) * 3
        if not isinstance(eva_transform_std, (list, tuple)):
            eva_transform_std = (eva_transform_std,) * 3
        self.eva_transform_mean = eva_transform_mean
        self.eva_transform_std = eva_transform_std

        # antelopev2
        local_dir = os.path.join(self.cache_dir, self.folder, 'models', 'antelopev2')
        _loc = snapshot_download('DIAMONIK7777/antelopev2', local_dir=local_dir)
        self.app = FaceAnalysis(
            name='antelopev2',
            root=os.path.join(self.cache_dir, self.folder),
            providers=self.providers,
        )
        debug(f'PulID load: faceanalysis={_loc}')
        self.app.prepare(ctx_id=0, det_size=(640, 640))
        # recognition model run on the aligned crop when FaceAnalysis finds no face (see get_id_embedding)
        self.handler_ante = insightface.model_zoo.get_model(os.path.join(local_dir, 'glintr100.onnx'))
        self.handler_ante.prepare(ctx_id=0)
        debug(f'PulID load: handler={self.handler_ante.__class__.__name__}')

        self.load_pretrain()

        # crops collected by get_id_embedding for debugging
        self.debug_img_list = []

        # karras schedule related code, borrow from lllyasviel/Omost
        linear_start = 0.00085
        linear_end = 0.012
        timesteps = 1000
        betas = torch.linspace(linear_start**0.5, linear_end**0.5, timesteps, dtype=torch.float64) ** 2
        alphas = 1.0 - betas
        alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
        self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        self.log_sigmas = self.sigmas.log()
        self.sigma_data = 1.0

        # default scheduler
        if sampler is not None:
            self.sampler = sampler
        else:
            # NOTE(review): set_progress_bar_config patches `pulid_sampling`, not this module - confirm they are the same
            from modules.pulid import sampling
            self.sampler = sampling.sample_dpmpp_sde

    @property
    def sigma_min(self):
        """Smallest sigma of the 1000-step training schedule."""
        return self.sigmas[0]

    @property
    def sigma_max(self):
        """Largest sigma of the 1000-step training schedule."""
        return self.sigmas[-1]

    def timestep(self, sigma):
        """Map each sigma to the index of the closest training sigma (compared in log space)."""
        log_sigma = sigma.log()
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)

    def get_sigmas_karras(self, n, rho=7.0):
        """Return `n` karras-spaced sigmas from sigma_max down to sigma_min, followed by a final 0."""
        ramp = torch.linspace(0, 1, n)
        min_inv_rho = self.sigma_min ** (1 / rho)
        max_inv_rho = self.sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return torch.cat([sigmas, sigmas.new_zeros([1])])

    def hack_unet_attn_layers(self, unet):
        """Replace every unet attention processor.

        Cross-attention layers (names not ending in `attn1.processor`) get an IDAttnProcessor sized
        from the unet block they belong to; self-attention layers get a plain AttnProcessor.
        The resulting processors are kept in `self.id_adapter_attn_layers` so load_pretrain can
        presumably fill their weights from the checkpoint - TODO confirm key prefix.
        """
        if self.sdp:
            from attention_processor import AttnProcessor2_0 as AttnProcessor
            from attention_processor import IDAttnProcessor2_0 as IDAttnProcessor
        else:
            from attention_processor import AttnProcessor
            from attention_processor import IDAttnProcessor

        id_adapter_attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                # parse the whole numeric segment so block ids >= 10 also work
                block_id = int(name.split('.')[1])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name.split('.')[1])
                hidden_size = unet.config.block_out_channels[block_id]
            else:
                hidden_size = None
            if cross_attention_dim is not None:
                id_adapter_attn_procs[name] = IDAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                ).to(unet.device, unet.dtype)
            else:
                id_adapter_attn_procs[name] = AttnProcessor()
        debug(f'PulID attention: cls={IDAttnProcessor} std={AttnProcessor} len={len(id_adapter_attn_procs.keys())}')
        unet.set_attn_processor(id_adapter_attn_procs)
        self.id_adapter_attn_layers = nn.ModuleList(unet.attn_processors.values())

    def load_pretrain(self):
        """Download the PuLID checkpoint for `self.version` and load it into this object's submodules.

        Each checkpoint key is `<attribute>.<rest>`; the weights for every `<attribute>` prefix are
        cast to `self.dtype` and loaded strictly into `getattr(self, attribute)`.
        """
        if self.version == 'v1.1':
            ckpt_path = hf_hub_download('guozinan/PuLID', 'pulid_v1.1.safetensors', local_dir=os.path.join(self.cache_dir, self.folder))
            state_dict = load_file(ckpt_path)
        else:
            ckpt_path = hf_hub_download('guozinan/PuLID', 'pulid_v1.bin', local_dir=os.path.join(self.cache_dir, self.folder))
            # NOTE(review): torch.load unpickles a downloaded file; consider weights_only=True if the torch version allows
            state_dict = torch.load(ckpt_path, map_location="cpu")
        debug(f'PulID load: fn="{ckpt_path}"')
        state_dict_dict = {}
        for k, v in state_dict.items():
            module = k.split('.')[0]
            state_dict_dict.setdefault(module, {})
            new_k = k[len(module) + 1 :]
            state_dict_dict[module][new_k] = v.to(self.dtype)
        for module, module_state in state_dict_dict.items():
            getattr(self, module).load_state_dict(module_state, strict=True)

    def to_gray(self, img):
        """Convert a 4-d rgb tensor (channels on dim 1) to grayscale, replicated back to 3 channels."""
        x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
        x = x.repeat(1, 3, 1, 1)
        return x

    def get_id_embedding(self, image_list):
        """Compute conditional and unconditional PuLID identity embeddings.

        Args:
            image_list: list of numpy rgb images, range [0, 255]; all images are combined into a
                single id adapter call.

        Returns:
            tuple (uncond_id_embedding, id_embedding); the unconditional one is the adapter
            applied to all-zero inputs of the same shapes.

        Raises:
            ValueError: if image_list is empty.
            RuntimeError: if facexlib cannot align a face in one of the images.
        """
        if image_list is None or len(image_list) == 0:
            raise ValueError('PulID: no input images for id embedding')
        id_cond_list = []
        id_vit_hidden_list = []
        # keep only this call's debug crops; appending across calls grew the list without bound
        self.debug_img_list = []

        self.face_helper.face_det.to(self.device)
        self.clip_vision_model.to(self.device)
        for image in image_list:
            self.face_helper.clean_all()
            image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            # get antelopev2 embedding
            face_info = self.app.get(image_bgr)
            if len(face_info) > 0:
                # only use the face with the largest bbox area
                face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
                id_ante_embedding = face_info['embedding']
                self.debug_img_list.append(image[int(face_info['bbox'][1]) : int(face_info['bbox'][3]), int(face_info['bbox'][0]) : int(face_info['bbox'][2])])
            else:
                id_ante_embedding = None

            # using facexlib to detect and align face
            self.face_helper.read_image(image_bgr)
            self.face_helper.get_face_landmarks_5(only_center_face=True)
            self.face_helper.align_warp_face()
            if len(self.face_helper.cropped_faces) == 0:
                raise RuntimeError('facexlib align face fail')
            align_face = self.face_helper.cropped_faces[0]
            # in case insightface didn't detect a face, run the recognition model on the aligned crop
            if id_ante_embedding is None:
                id_ante_embedding = self.handler_ante.get_feat(align_face)

            id_ante_embedding = torch.from_numpy(id_ante_embedding).to(self.device)
            if id_ante_embedding.ndim == 1:
                id_ante_embedding = id_ante_embedding.unsqueeze(0)

            # parsing
            face_tensor = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0
            face_tensor = face_tensor.to(self.device)
            parsing_out = self.face_helper.face_parse(normalize(face_tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
            parsing_out = parsing_out.argmax(dim=1, keepdim=True)
            # parsing classes that are painted white below (everything else is kept in grayscale)
            bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
            bg = sum(parsing_out == i for i in bg_label).bool()
            white_image = torch.ones_like(face_tensor)
            # only keep the face features
            face_features_image = torch.where(bg, white_image, self.to_gray(face_tensor))
            self.debug_img_list.append(tensor2img(face_features_image, rgb2bgr=False))

            # transform img before sending to eva-clip-vit
            face_features_image = resize(face_features_image, self.clip_vision_model.image_size, InterpolationMode.BICUBIC)
            face_features_image = normalize(face_features_image, self.eva_transform_mean, self.eva_transform_std).to(self.dtype)
            id_cond_vit, id_vit_hidden = self.clip_vision_model(face_features_image, return_all_features=False, return_hidden=True, shuffle=False)
            # l2-normalize the clip embedding along dim 1
            id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True)
            id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm)

            id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1)

            id_cond_list.append(id_cond)
            id_vit_hidden_list.append(id_vit_hidden)

        self.id_adapter.to(self.device)
        # zero inputs for the unconditional branch; built before id_vit_hidden_list[0] is extended in place below
        id_uncond = torch.zeros_like(id_cond_list[0]).to(self.dtype)
        id_vit_hidden_uncond = [torch.zeros_like(hidden).to(self.dtype) for hidden in id_vit_hidden_list[0]]

        # stack per-image conditions on dim 1 and concatenate each hidden layer across images on dim 1
        id_cond = torch.stack(id_cond_list, dim=1).to(self.dtype)
        id_vit_hidden = id_vit_hidden_list[0]
        for extra_hidden in id_vit_hidden_list[1:]:
            for j, x in enumerate(extra_hidden):
                id_vit_hidden[j] = torch.cat([id_vit_hidden[j], x], dim=1).to(self.dtype)

        id_embedding = self.id_adapter(id_cond, id_vit_hidden)
        uncond_id_embedding = self.id_adapter(id_uncond, id_vit_hidden_uncond)

        if self.offload:
            self.face_helper.face_det.to('cpu')
            self.id_adapter.to('cpu')
            self.clip_vision_model.to('cpu')

        debug(f'PulID embedding: cond={id_embedding.shape} uncond={uncond_id_embedding.shape}')
        return uncond_id_embedding, id_embedding

    def set_progress_bar_config(self, bar_format: str = None, ncols: int = 80, colour: str = None):
        """Patch `trange` in the `pulid_sampling` module so its progress bar uses these tqdm options."""
        import functools
        from tqdm.auto import trange as trange_orig
        import pulid_sampling
        pulid_sampling.trange = functools.partial(trange_orig, bar_format=bar_format, ncols=ncols, colour=colour)

    def sample(self, x, sigma, **extra_args):
        """Denoiser callback handed to the sampler.

        Scales `x` by 1 / sqrt(sigma^2 + sigma_data^2), runs the unet for the `positive` and
        `negative` kwargs in `extra_args`, combines them with `cfg_scale` (classifier-free
        guidance) and returns `x - noise_pred * sigma`. Invokes `callback_on_step_end`, when set,
        with the result.
        """
        t = self.timestep(sigma)
        x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data**2) ** 0.5
        cfg_scale = extra_args['cfg_scale']
        eps_positive = self.pipe.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0]
        eps_negative = self.pipe.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0]
        noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative)
        latent = x - noise_pred * sigma[:, None, None, None]
        if self.callback_on_step_end is not None:
            self.step += 1
            self.callback_on_step_end(self.pipe, step=self.step, timestep=t, kwargs={ 'latents': latent })
        return latent

    def init_latent(self, seed, size, image, mask_image, strength, width, height):  # pylint: disable=unused-argument
        """Build the clean starting latent and the noise tensor.

        Args:
            seed: noise seed. torch.manual_seed reseeds torch's global rng as well.
                NOTE(review): the sampler may depend on that global seeding for reproducible
                results - do not swap for a private generator without checking.
            size: (batch, height, width) in pixels; noise is (batch, 4, height // 8, width // 8).
            image: optional init image; used only when strength > 0.
            mask_image: optional mask; when given with an image the inpaint path is used.
            strength: img2img strength; 0 means txt2img.
            width, height: output size passed to the inpaint `prepare_latents`.

        Returns:
            (latents, noise): latents from the wrapped pipeline's `prepare_latents` for
            img2img/inpaint, zeros for txt2img; noise on `self.device` in the unet dtype.
        """
        # standard txt2img will full noise
        noise = torch.randn((size[0], 4, size[1] // 8, size[2] // 8), device="cpu", generator=torch.manual_seed(seed))
        noise = noise.to(dtype=self.pipe.unet.dtype, device=self.device)

        if strength > 0 and image is not None:
            image = self.pipe.image_processor.preprocess(image)
            if mask_image is not None:  # Inpaint
                latents = self.pipe.prepare_latents(
                    1,  # batch_size
                    self.pipe.vae.config.latent_channels,  # num_channels_latents
                    height,
                    width,
                    noise.dtype,
                    noise.device,
                    None,  # generator
                    latents=None,
                    image=image,
                    timestep=1000,
                    is_strength_max=False,
                    add_noise=False,
                    return_noise=False,
                    return_image_latents=False,
                )
                latents = latents[0]
                debug(f'PulID noise: op=inpaint latent={latents.shape} image={image} mask={mask_image} dtype={latents.dtype}')
            else:  # img2img
                latents = self.pipe.prepare_latents(
                    image,
                    None,  # timestep (not needed)
                    1,  # batch_size
                    1,  # num_images_per_prompt
                    noise.dtype,
                    noise.device,
                    None,  # generator
                    False,  # add_noise
                )
                debug(f'PulID noise: op=img2img latent={latents.shape} image={image} dtype={latents.dtype}')
        else:
            latents = torch.zeros_like(noise)
            debug(f'PulID noise: op=txt2img latent={latents.shape} dtype={latents.dtype}')
        return latents, noise

    def __call__(
        self,
        prompt: str = '',
        negative_prompt: str = '',
        width: int = 1024,
        height: int = 1024,
        guidance_scale: float = 7.0,
        num_inference_steps: int = 50,
        seed: int = -1,
        image: np.ndarray = None,
        mask_image: np.ndarray = None,
        strength: float = 0.3,
        id_embedding=None,
        uncond_id_embedding=None,
        id_scale: float = 1.0,
        output_type: str = 'pil',
        callback_on_step_end=None,
    ):
        """Generate images conditioned on the prompt and the PuLID identity embeddings.

        When `image` is given with strength > 0 the karras schedule is shortened via the wrapped
        pipeline's `get_timesteps`. `id_embedding` / `uncond_id_embedding` (from
        `get_id_embedding`) and `id_scale` go to the attention processors through
        `cross_attention_kwargs`. `mask_image` is used via `.convert("L")`, so it is expected
        to be a PIL image despite the annotation.

        Returns:
            StableDiffusionXLPipelineOutput with latents ('latent'), a numpy array ('np') or PIL
            images (any other output_type).
        """
        debug(f'PulID call: width={width} height={height} cfg={guidance_scale} steps={num_inference_steps} seed={seed} strength={strength} id_scale={id_scale} output={output_type}')
        self.step = 0  # pylint: disable=attribute-defined-outside-init
        self.callback_on_step_end = callback_on_step_end  # pylint: disable=attribute-defined-outside-init
        if isinstance(image, list) and len(image) > 0 and isinstance(image[0], Image.Image):
            if image[0].width != width or image[0].height != height:  # override width/height if different
                width, height = image[0].width, image[0].height
        size = (1, height, width)

        # sigmas
        sigmas = self.get_sigmas_karras(num_inference_steps).to(self.device)
        if image is not None and strength > 0:
            _timesteps, num_inference_steps = self.pipe.get_timesteps(num_inference_steps, strength, self.device, None)  # denoising_start disabled
            sigmas = sigmas[-(num_inference_steps + 1):].to(self.device)  # shorten sigmas in i2i
        debug(f'PulID sigmas: sigmas={sigmas.shape} dtype={sigmas.dtype}')

        # latents
        latent, noise = self.init_latent(seed, size, image, mask_image, strength, width, height)
        noisy_latent = latent + noise * sigmas[0].to(noise)
        debug(f'PulID noisy: latent={noisy_latent.shape} dtype={noisy_latent.dtype}')

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.pipe.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
        )

        # time ids: (h, w) + (0, 0) + (h, w); the negative branch uses the same values
        add_time_ids = list((size[1], size[2]) + (0, 0) + (size[1], size[2]))
        add_time_ids = torch.tensor([add_time_ids], dtype=self.pipe.unet.dtype, device=self.device)
        add_neg_time_ids = add_time_ids.clone()

        sampler_kwargs = dict(
            cfg_scale=guidance_scale,
            positive=dict(
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
                cross_attention_kwargs={'id_embedding': id_embedding, 'id_scale': id_scale},
            ),
            negative=dict(
                encoder_hidden_states=negative_prompt_embeds,
                added_cond_kwargs={"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids},
                cross_attention_kwargs={'id_embedding': uncond_id_embedding, 'id_scale': id_scale},
            ),
        )

        if mask_image is not None:
            latent_h, latent_w = noisy_latent.shape[-2], noisy_latent.shape[-1]
            latent_mask = torch.Tensor(np.asarray(mask_image.convert("L").resize((latent_w, latent_h)))).reshape((latent_h, latent_w))
            mask_max = latent_mask.max()
            if mask_max > 0:  # an all-black mask would otherwise be 0 / 0 = NaN everywhere
                latent_mask /= mask_max
            mask_args = dict(
                latent=latent,
                latent_mask=latent_mask,
                noise=noise,
                sigmas=sigmas,
            )
        else:
            mask_args = None

        # actual sampling loop
        latents = self.sampler(self.sample, noisy_latent, sigmas, extra_args=sampler_kwargs, disable=False, mask_args=mask_args)

        # process output
        latents = latents.to(dtype=self.pipe.vae.dtype, device=self.device)
        debug(f'PulID output: latent={latents.shape} dtype={latents.dtype}')
        if output_type == 'latent':
            images = self.pipe.image_processor.postprocess(latents, output_type='latent')
        elif output_type == 'np':
            # NOTE(review): this post-processes undecoded latents, not vae-decoded images - confirm intended
            images = self.pipe.image_processor.postprocess(latents, output_type='np')
        else:
            latents = latents / self.pipe.vae.config.scaling_factor
            images = self.pipe.vae.decode(latents).sample
            images = self.pipe.image_processor.postprocess(images, output_type='pil')
        debug(f'PulID output: type={type(images)} images={images.shape if hasattr(images, "shape") else images}')
        return StableDiffusionXLPipelineOutput(images)


class StableDiffusionXLPuLIDPipelineImage(StableDiffusionXLPuLIDPipeline):
    """Img2img flavor of the PuLID pipeline; behavior is inherited unchanged."""

    def __init__(self, pipe: StableDiffusionXLPipeline, device: torch.device, sampler=None, cache_dir=None, **kwargs):
        # sampler/cache_dir must go by keyword: positionally they would bind to the parent's dtype/providers
        super().__init__(pipe, device, sampler=sampler, cache_dir=cache_dir, **kwargs)
        # we dont do anything special here, just having different class so task-type can be detected/assigned


class StableDiffusionXLPuLIDPipelineInpaint(StableDiffusionXLPuLIDPipeline):
    """Inpaint flavor of the PuLID pipeline; behavior is inherited unchanged."""

    def __init__(self, pipe: StableDiffusionXLPipeline, device: torch.device, sampler=None, cache_dir=None, **kwargs):
        # sampler/cache_dir must go by keyword: positionally they would bind to the parent's dtype/providers
        super().__init__(pipe, device, sampler=sampler, cache_dir=cache_dir, **kwargs)
        # we dont do anything special here, just having different class so task-type can be detected/assigned