mirror of https://github.com/vladmandic/automatic
887 lines
42 KiB
Python
887 lines
42 KiB
Python
### original <https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/pipeline.py>
|
|
|
|
import inspect
|
|
from typing import Any, Union
|
|
from collections.abc import Callable
|
|
import PIL
|
|
import torch
|
|
from transformers import CLIPImageProcessor
|
|
from safetensors import safe_open
|
|
from huggingface_hub.utils import validate_hf_hub_args
|
|
from diffusers import StableDiffusionXLPipeline
|
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
|
from diffusers.loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
|
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
|
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
|
from diffusers.utils import _get_model_file, USE_PEFT_BACKEND, deprecate, is_torch_xla_available, scale_lora_layers, unscale_lora_layers
|
|
|
|
# Optional PyTorch/XLA support: XLA_AVAILABLE records whether torch_xla can be used, and `xm` is only bound when it can.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
|
|
|
|
from modules.face.photomaker_model_v1 import PhotoMakerIDEncoder
|
|
from modules.face.photomaker_model_v2 import PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken
|
|
|
|
# Image inputs accepted by this pipeline (used for PhotoMaker ID images and IP-Adapter images):
# a single PIL image or tensor, or a list of PIL images, or a list of tensors.
PipelineImageInput = Union[PIL.Image.Image, torch.FloatTensor, list[PIL.Image.Image], list[torch.FloatTensor]]
|
|
|
|
|
|
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend a classifier-free-guidance noise prediction with a std-matched copy of itself.

    Implements the guidance rescale of "Common Diffusion Noise Schedules and Sample Steps are Flawed"
    (https://arxiv.org/pdf/2305.08891.pdf, Section 3.4). Standard deviations are computed per sample,
    over every dimension except the first.

    Args:
        noise_cfg: Noise prediction after classifier-free guidance.
        noise_pred_text: Text-conditioned noise prediction whose per-sample std is the target.
        guidance_rescale: Interpolation weight between the raw guided prediction (0.0) and the
            std-rescaled one (1.0).

    Returns:
        The blended noise prediction.
    """
    def _per_sample_std(values):
        return values.std(dim=list(range(1, values.ndim)), keepdim=True)

    scale = _per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg)
    # matching the text branch's std counteracts the over-exposure caused by strong guidance
    rescaled = noise_cfg * scale
    # keep a share of the raw guided output so results do not look "plain"
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
|
|
|
|
|
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
|
def retrieve_timesteps(
|
|
scheduler,
|
|
num_inference_steps: int | None = None,
|
|
device: str | torch.device | None = None,
|
|
timesteps: list[int] | None = None,
|
|
sigmas: list[float] | None = None,
|
|
**kwargs,
|
|
):
|
|
"""
|
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
|
|
|
Args:
|
|
scheduler (`SchedulerMixin`):
|
|
The scheduler to get timesteps from.
|
|
num_inference_steps (`int`):
|
|
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
|
must be `None`.
|
|
device (`str` or `torch.device`, *optional*):
|
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
|
timesteps (`List[int]`, *optional*):
|
|
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
|
`num_inference_steps` and `sigmas` must be `None`.
|
|
sigmas (`List[float]`, *optional*):
|
|
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
|
`num_inference_steps` and `timesteps` must be `None`.
|
|
|
|
Returns:
|
|
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
|
second element is the number of inference steps.
|
|
"""
|
|
if timesteps is not None and sigmas is not None:
|
|
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
|
if timesteps is not None:
|
|
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
|
if not accepts_timesteps:
|
|
raise ValueError(
|
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
|
f" timestep schedules. Please check whether you are using the correct scheduler."
|
|
)
|
|
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
|
timesteps = scheduler.timesteps
|
|
num_inference_steps = len(timesteps)
|
|
elif sigmas is not None:
|
|
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
|
if not accept_sigmas:
|
|
raise ValueError(
|
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
|
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
|
)
|
|
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
|
timesteps = scheduler.timesteps
|
|
num_inference_steps = len(timesteps)
|
|
else:
|
|
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
|
timesteps = scheduler.timesteps
|
|
return timesteps, num_inference_steps
|
|
|
|
|
|
class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|
@validate_hf_hub_args
def load_photomaker_adapter(
    self,
    pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
    weight_name: str,
    subfolder: str = '',
    trigger_word: str = 'img',
    pm_version: str = 'v2',
    device: torch.device = None,
    **kwargs,
):
    """
    Load PhotoMaker weights: build the ID encoder, attach the PhotoMaker LoRA and register the trigger word.

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
                  with top-level entries `id_encoder` and `lora_weights`.

        weight_name (`str`):
            The weight file name, NOT the path to the weight. Names ending in `.safetensors` are read with
            `safetensors`; anything else is read with `torch.load(..., weights_only=True)`.
        subfolder (`str`, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        trigger_word (`str`, *optional*, defaults to `"img"`):
            Token that marks the position of the class word in the text prompt. It is added to the tokenizers as a
            special token, so it should not be a common word. It must be placed immediately after the class word,
            otherwise the personalized generation is affected.
        pm_version (`str`, *optional*, defaults to `"v2"`):
            PhotoMaker version, `"v1"` or `"v2"`; selects the ID encoder class.
        device (`torch.device`, *optional*):
            Device the ID encoder is moved to; it is also cast to the UNet dtype.
        kwargs:
            Hub download options read here: `cache_dir`, `force_download`, `proxies`, `local_files_only`,
            `token`, `revision`. Other keys are ignored.

    Raises:
        NotImplementedError: if `pm_version` is not `"v1"` or `"v2"`.
        ValueError: if the loaded state dict lacks the `id_encoder` or `lora_weights` entry.
    """
    # Resolve the ID encoder class up front so an unknown version fails before anything is downloaded.
    encoder_classes = {
        "v1": PhotoMakerIDEncoder,  # PhotoMaker v1
        "v2": PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken,  # PhotoMaker v2
    }
    if pm_version not in encoder_classes:
        raise NotImplementedError(f"The PhotoMaker version [{pm_version}] is not supported")

    # Hub download options.
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)

    user_agent = {
        "file_type": "attn_procs_weights",
        "framework": "pytorch",
    }

    # Load the main state dict first.
    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        state_dict = pretrained_model_name_or_path_or_dict
    else:
        model_file = _get_model_file(
            pretrained_model_name_or_path_or_dict,
            weights_name=weight_name,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
        )
        if weight_name.endswith(".safetensors"):
            # Flat safetensors keys are grouped by their leading "id_encoder." / "lora_weights." prefix;
            # keys with any other prefix are skipped.
            state_dict = {"id_encoder": {}, "lora_weights": {}}
            with safe_open(model_file, framework="pt", device="cpu") as f:
                for key in f.keys():
                    for group, group_weights in state_dict.items():
                        prefix = f"{group}."
                        if key.startswith(prefix):
                            # removeprefix strips only the leading group name; str.replace would also
                            # rewrite any later occurrence of it inside the key.
                            group_weights[key.removeprefix(prefix)] = f.get_tensor(key)
                            break
        else:
            # torch.load unpickles the file; weights_only=True restricts it to tensors and plain containers so a
            # downloaded checkpoint cannot run arbitrary code while loading.
            state_dict = torch.load(model_file, map_location="cpu", weights_only=True)

    # Only the presence of both groups matters, not their order or any additional entries.
    if any(required not in state_dict for required in ("id_encoder", "lora_weights")):
        raise ValueError("Required keys are (`id_encoder` and `lora_weights`) missing from the state dict.")

    # NOTE(review): num_tokens is 2 for both versions; it is the number of class-token copies inserted per ID
    # image when the prompt is encoded with the trigger word — confirm the v1 encoder expects this count.
    self.num_tokens = 2  # pylint: disable=attribute-defined-outside-init
    self.pm_version = pm_version  # pylint: disable=attribute-defined-outside-init
    self.trigger_word = trigger_word  # pylint: disable=attribute-defined-outside-init
    # load finetuned CLIP image encoder and fuse module here if it has not been registered to the pipeline yet
    self.id_image_processor = CLIPImageProcessor()  # pylint: disable=attribute-defined-outside-init

    id_encoder = encoder_classes[pm_version]()
    id_encoder.load_state_dict(state_dict["id_encoder"], strict=True)
    self.id_encoder = id_encoder.to(device, dtype=self.unet.dtype)  # pylint: disable=attribute-defined-outside-init

    # load lora into models
    self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")

    # Register the trigger word as a special token so it is never split into sub-words.
    if self.tokenizer is not None:
        self.tokenizer.add_tokens([self.trigger_word], special_tokens=True)
    self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True)
|
|
|
|
|
|
def encode_prompt_with_trigger_word(
    self,
    prompt: str,
    prompt_2: str | None = None,
    device: torch.device | None = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: str | None = None,
    negative_prompt_2: str | None = None,
    prompt_embeds: torch.Tensor | None = None,
    negative_prompt_embeds: torch.Tensor | None = None,
    pooled_prompt_embeds: torch.Tensor | None = None,
    negative_pooled_prompt_embeds: torch.Tensor | None = None,
    lora_scale: float | None = None,
    clip_skip: int | None = None,
    # PhotoMaker-specific arguments
    num_id_images: int = 1,
    class_tokens_mask: torch.LongTensor | None = None,
):
    """
    Encode a prompt like SDXL's `encode_prompt`, with PhotoMaker trigger-word handling.

    The trigger word (`self.trigger_word`) is removed from the token sequence. The token immediately before
    it (the class word) is repeated `num_id_images * self.num_tokens` times. `class_tokens_mask` marks those
    repeated positions so the ID encoder can later fuse identity embeddings into them.

    NOTE: only the first prompt of a batch is tokenized this way (the id list of row 0 is used and the result
    has batch size 1).

    Args:
        prompt / prompt_2: Prompts for the first and second text encoder; `prompt_2` defaults to `prompt`.
        device: Target device; defaults to `self._execution_device`.
        num_images_per_prompt: Repeat factor applied to the pooled and negative embeddings.
        do_classifier_free_guidance: Whether negative embeddings are produced.
        negative_prompt / negative_prompt_2: Negative prompts; `negative_prompt_2` defaults to `negative_prompt`.
        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds:
            Precomputed embeddings. When `prompt_embeds` is given, the prompt is not tokenized and
            `class_tokens_mask` must be supplied by the caller.
        lora_scale: LoRA scale applied to the text encoders for the duration of this call.
        clip_skip: Number of extra final layers to skip for the positive prompt embeddings.
        num_id_images: Number of ID images; controls how many class-token copies are inserted.
        class_tokens_mask: Precomputed class-token mask, used only when `prompt_embeds` is given.

    Returns:
        Tuple `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds,
        class_tokens_mask)`. `prompt_embeds` is NOT yet duplicated for `num_images_per_prompt`.

    Raises:
        ValueError: if the trigger word is missing from, or appears more than once in, a tokenized prompt; or
            if the negative prompt batch size does not match the prompt batch size.
        TypeError: if `negative_prompt` and `prompt` have different types.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale  # pylint: disable=attribute-defined-outside-init

        # dynamically adjust the LoRA scale
        for encoder in (self.text_encoder, self.text_encoder_2):
            if encoder is None:
                continue
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(encoder, lora_scale)
            else:
                scale_lora_layers(encoder, lora_scale)

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
        expanded_len = num_id_images * self.num_tokens

        prompt_embeds_list = []
        for text, tokenizer, text_encoder in zip([prompt, prompt_2], tokenizers, text_encoders, strict=False):
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                text = self.maybe_convert_prompt(text, tokenizer)

            text_input_ids = tokenizer(
                text,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids

            # Look the trigger id up in this tokenizer; the two tokenizers are not guaranteed to share ids.
            image_token_id = tokenizer.convert_tokens_to_ids(self.trigger_word)

            # Drop every trigger-word token; each one marks the token right before it as the class word.
            clean_input_ids = []
            class_token_positions = []
            for token_id in text_input_ids[0].tolist():
                if token_id == image_token_id:
                    class_token_positions.append(len(clean_input_ids) - 1)
                else:
                    clean_input_ids.append(token_id)

            if not class_token_positions:
                raise ValueError(
                    "PhotoMaker could not find the trigger word in the prompt (it may have been truncated). "
                    f"Trigger word: {self.trigger_word}, Prompt: {text}."
                )
            if len(class_token_positions) > 1:
                raise ValueError(
                    "PhotoMaker currently does not support multiple trigger words in a single prompt. "
                    f"Trigger word: {self.trigger_word}, Prompt: {text}."
                )
            class_token_index = class_token_positions[0]

            # Expand the class word token to one copy per ID-embedding slot.
            class_token = clean_input_ids[class_token_index]
            clean_input_ids = (
                clean_input_ids[:class_token_index]
                + [class_token] * expanded_len
                + clean_input_ids[class_token_index + 1:]
            )

            # Truncate or pad back to the tokenizer's fixed length.
            max_len = tokenizer.model_max_length
            clean_input_ids = clean_input_ids[:max_len]
            clean_input_ids += [tokenizer.pad_token_id] * (max_len - len(clean_input_ids))

            class_tokens_mask = [
                class_token_index <= i < class_token_index + expanded_len for i in range(len(clean_input_ids))
            ]

            clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0)
            class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0)

            encoder_output = text_encoder(clean_input_ids.to(device), output_hidden_states=True)
            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = encoder_output[0]
            # "2" because SDXL always indexes from the penultimate layer.
            hidden_index = -2 if clip_skip is None else -(clip_skip + 2)
            prompt_embeds_list.append(encoder_output.hidden_states[hidden_index])

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # Cast to the second text encoder's dtype, falling back to the UNet dtype when that encoder is absent.
    target_dtype = self.text_encoder_2.dtype if self.text_encoder_2 is not None else self.unet.dtype
    prompt_embeds = prompt_embeds.to(dtype=target_dtype, device=device)
    class_tokens_mask = class_tokens_mask.to(device=device)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt  # pylint: disable=no-member
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # normalize str to list
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        for negative_text, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders, strict=False):
            if isinstance(self, TextualInversionLoaderMixin):
                negative_text = self.maybe_convert_prompt(negative_text, tokenizer)

            uncond_input = tokenizer(
                negative_text,
                padding="max_length",
                max_length=prompt_embeds.shape[1],
                truncation=True,
                return_tensors="pt",
            )
            negative_output = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True)
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_output[0]
            negative_prompt_embeds_list.append(negative_output.hidden_states[-2])

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    bs_embed = prompt_embeds.shape[0]

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]
        negative_prompt_embeds = negative_prompt_embeds.to(dtype=target_dtype, device=device)
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        for encoder in (self.text_encoder, self.text_encoder_2):
            if encoder is not None:
                unscale_lora_layers(encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, class_tokens_mask
|
|
|
|
@torch.no_grad()
|
|
def __call__(
|
|
self,
|
|
prompt: str | list[str] = None,
|
|
prompt_2: str | list[str] | None = None,
|
|
height: int | None = None,
|
|
width: int | None = None,
|
|
num_inference_steps: int = 50,
|
|
timesteps: list[int] = None,
|
|
sigmas: list[float] = None,
|
|
denoising_end: float | None = None,
|
|
guidance_scale: float = 5.0,
|
|
negative_prompt: str | list[str] | None = None,
|
|
negative_prompt_2: str | list[str] | None = None,
|
|
num_images_per_prompt: int | None = 1,
|
|
eta: float = 0.0,
|
|
generator: torch.Generator | list[torch.Generator] | None = None,
|
|
latents: torch.Tensor | None = None,
|
|
prompt_embeds: torch.Tensor | None = None,
|
|
negative_prompt_embeds: torch.Tensor | None = None,
|
|
pooled_prompt_embeds: torch.Tensor | None = None,
|
|
negative_pooled_prompt_embeds: torch.Tensor | None = None,
|
|
ip_adapter_image: PipelineImageInput | None = None,
|
|
ip_adapter_image_embeds: list[torch.Tensor] | None = None,
|
|
output_type: str | None = "pil",
|
|
return_dict: bool = True,
|
|
cross_attention_kwargs: dict[str, Any] | None = None,
|
|
guidance_rescale: float = 0.0,
|
|
original_size: tuple[int, int] | None = None,
|
|
crops_coords_top_left: tuple[int, int] = (0, 0),
|
|
target_size: tuple[int, int] | None = None,
|
|
negative_original_size: tuple[int, int] | None = None,
|
|
negative_crops_coords_top_left: tuple[int, int] = (0, 0),
|
|
negative_target_size: tuple[int, int] | None = None,
|
|
clip_skip: int | None = None,
|
|
callback_on_step_end: Callable[[int, int, dict], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
|
|
callback_on_step_end_tensor_inputs: list[str] = None,
|
|
# Added parameters (for PhotoMaker)
|
|
input_id_images: PipelineImageInput = None,
|
|
start_merge_step: int = 10,
|
|
class_tokens_mask: torch.LongTensor | None = None,
|
|
id_embeds: torch.FloatTensor | None = None,
|
|
prompt_embeds_text_only: torch.FloatTensor | None = None,
|
|
pooled_prompt_embeds_text_only: torch.FloatTensor | None = None,
|
|
**kwargs,
|
|
):
|
|
r"""
|
|
Function invoked when calling the pipeline for generation.
|
|
Only the parameters introduced by PhotoMaker are discussed here.
|
|
For explanations of the previous parameters in StableDiffusionXLPipeline, please refer to https://github.com/huggingface/diffusers/blob/v0.25.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
|
|
|
|
Args:
|
|
input_id_images (`PipelineImageInput`, *optional*):
|
|
Input ID Image to work with PhotoMaker.
|
|
class_tokens_mask (`torch.LongTensor`, *optional*):
|
|
Pre-generated class token. When the `prompt_embeds` parameter is provided in advance, it is necessary to prepare the `class_tokens_mask` beforehand for marking out the position of class word.
|
|
prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
|
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
|
provided, text embeddings will be generated from `prompt` input argument.
|
|
pooled_prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
|
|
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
|
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
|
|
|
Returns:
|
|
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
|
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
|
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
|
"""
|
|
|
|
if callback_on_step_end_tensor_inputs is None:
|
|
callback_on_step_end_tensor_inputs = ["latents"]
|
|
callback = kwargs.pop("callback", None)
|
|
callback_steps = kwargs.pop("callback_steps", None)
|
|
|
|
if callback is not None:
|
|
deprecate(
|
|
"callback",
|
|
"1.0.0",
|
|
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
|
)
|
|
if callback_steps is not None:
|
|
deprecate(
|
|
"callback_steps",
|
|
"1.0.0",
|
|
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
|
)
|
|
|
|
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
|
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
|
|
|
# 0. Default height and width to unet
|
|
height = height or self.default_sample_size * self.vae_scale_factor
|
|
width = width or self.default_sample_size * self.vae_scale_factor
|
|
|
|
original_size = original_size or (height, width)
|
|
target_size = target_size or (height, width)
|
|
|
|
# 1. Check inputs. Raise error if not correct
|
|
self.check_inputs(
|
|
prompt,
|
|
prompt_2,
|
|
height,
|
|
width,
|
|
callback_steps,
|
|
negative_prompt,
|
|
negative_prompt_2,
|
|
prompt_embeds,
|
|
negative_prompt_embeds,
|
|
pooled_prompt_embeds,
|
|
negative_pooled_prompt_embeds,
|
|
ip_adapter_image,
|
|
ip_adapter_image_embeds,
|
|
callback_on_step_end_tensor_inputs,
|
|
)
|
|
|
|
self._guidance_scale = guidance_scale # pylint: disable=attribute-defined-outside-init
|
|
self._guidance_rescale = guidance_rescale # pylint: disable=attribute-defined-outside-init
|
|
self._clip_skip = clip_skip # pylint: disable=attribute-defined-outside-init
|
|
self._cross_attention_kwargs = cross_attention_kwargs # pylint: disable=attribute-defined-outside-init
|
|
self._denoising_end = denoising_end # pylint: disable=attribute-defined-outside-init
|
|
self._interrupt = False # pylint: disable=attribute-defined-outside-init
|
|
|
|
if prompt_embeds is not None and class_tokens_mask is None:
|
|
raise ValueError(
|
|
"If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`."
|
|
)
|
|
# check the input id images
|
|
if input_id_images is None:
|
|
raise ValueError(
|
|
"Provide `input_id_images`. Cannot leave `input_id_images` undefined for PhotoMaker pipeline."
|
|
)
|
|
if not isinstance(input_id_images, list):
|
|
input_id_images = [input_id_images]
|
|
|
|
# 2. Define call parameters
|
|
if prompt is not None and isinstance(prompt, str):
|
|
batch_size = 1
|
|
elif prompt is not None and isinstance(prompt, list):
|
|
batch_size = len(prompt)
|
|
else:
|
|
batch_size = prompt_embeds.shape[0]
|
|
|
|
device = self._execution_device
|
|
|
|
# 3. Encode input prompt
|
|
lora_scale = (
|
|
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
|
)
|
|
|
|
num_id_images = len(input_id_images)
|
|
(
|
|
prompt_embeds,
|
|
_,
|
|
pooled_prompt_embeds,
|
|
_,
|
|
class_tokens_mask,
|
|
) = self.encode_prompt_with_trigger_word(
|
|
prompt=prompt,
|
|
prompt_2=prompt_2,
|
|
device=device,
|
|
num_id_images=num_id_images,
|
|
class_tokens_mask=class_tokens_mask,
|
|
num_images_per_prompt=num_images_per_prompt,
|
|
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
|
negative_prompt=negative_prompt,
|
|
negative_prompt_2=negative_prompt_2,
|
|
prompt_embeds=prompt_embeds,
|
|
negative_prompt_embeds=negative_prompt_embeds,
|
|
pooled_prompt_embeds=pooled_prompt_embeds,
|
|
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
|
lora_scale=lora_scale,
|
|
clip_skip=self.clip_skip,
|
|
)
|
|
|
|
# 4. Encode input prompt without the trigger word for delayed conditioning
|
|
# encode, remove trigger word token, then decode
|
|
tokens_text_only = self.tokenizer.encode(prompt, add_special_tokens=False)
|
|
trigger_word_token = self.tokenizer.convert_tokens_to_ids(self.trigger_word)
|
|
tokens_text_only.remove(trigger_word_token)
|
|
prompt_text_only = self.tokenizer.decode(tokens_text_only, add_special_tokens=False)
|
|
(
|
|
prompt_embeds_text_only,
|
|
negative_prompt_embeds,
|
|
pooled_prompt_embeds_text_only,
|
|
negative_pooled_prompt_embeds,
|
|
) = self.encode_prompt(
|
|
prompt=prompt_text_only,
|
|
prompt_2=prompt_2,
|
|
device=device,
|
|
num_images_per_prompt=num_images_per_prompt,
|
|
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
|
negative_prompt=negative_prompt,
|
|
negative_prompt_2=negative_prompt_2,
|
|
prompt_embeds=prompt_embeds_text_only,
|
|
negative_prompt_embeds=negative_prompt_embeds,
|
|
pooled_prompt_embeds=pooled_prompt_embeds_text_only,
|
|
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
|
lora_scale=lora_scale,
|
|
clip_skip=self.clip_skip,
|
|
)
|
|
|
|
# 5. Prepare timesteps
|
|
timesteps, num_inference_steps = retrieve_timesteps(
|
|
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
|
)
|
|
|
|
# 6. Prepare the input ID images
|
|
dtype = next(self.id_encoder.parameters()).dtype
|
|
if not isinstance(input_id_images[0], torch.Tensor):
|
|
id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values # pylint: disable=used-before-assignment
|
|
|
|
id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # pylint: disable=used-before-assignment
|
|
|
|
# 7. Get the update text embedding with the stacked ID embedding
|
|
if id_embeds is not None:
|
|
id_embeds = id_embeds.unsqueeze(0).to(device=device, dtype=dtype)
|
|
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds)
|
|
else:
|
|
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)
|
|
|
|
bs_embed, seq_len, _ = prompt_embeds.shape
|
|
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
|
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
|
|
|
# 8. Prepare latent variables
|
|
num_channels_latents = self.unet.config.in_channels
|
|
latents = self.prepare_latents(
|
|
batch_size * num_images_per_prompt,
|
|
num_channels_latents,
|
|
height,
|
|
width,
|
|
prompt_embeds.dtype,
|
|
device,
|
|
generator,
|
|
latents,
|
|
)
|
|
|
|
# 9. Prepare extra step kwargs.
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
|
|
|
# 10. Prepare added time ids & embeddings
|
|
add_text_embeds = pooled_prompt_embeds
|
|
if self.text_encoder_2 is None:
|
|
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
|
else:
|
|
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
|
|
|
add_time_ids = self._get_add_time_ids(
|
|
original_size,
|
|
crops_coords_top_left,
|
|
target_size,
|
|
dtype=prompt_embeds.dtype,
|
|
text_encoder_projection_dim=text_encoder_projection_dim,
|
|
)
|
|
if negative_original_size is not None and negative_target_size is not None:
|
|
negative_add_time_ids = self._get_add_time_ids(
|
|
negative_original_size,
|
|
negative_crops_coords_top_left,
|
|
negative_target_size,
|
|
dtype=prompt_embeds.dtype,
|
|
text_encoder_projection_dim=text_encoder_projection_dim,
|
|
)
|
|
else:
|
|
negative_add_time_ids = add_time_ids
|
|
|
|
if self.do_classifier_free_guidance:
|
|
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
|
|
|
prompt_embeds = prompt_embeds.to(device)
|
|
add_text_embeds = add_text_embeds.to(device)
|
|
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
|
|
|
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
|
image_embeds = self.prepare_ip_adapter_image_embeds(
|
|
ip_adapter_image,
|
|
ip_adapter_image_embeds,
|
|
device,
|
|
batch_size * num_images_per_prompt,
|
|
self.do_classifier_free_guidance,
|
|
)
|
|
|
|
# 11. Denoising loop
|
|
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
|
|
|
# 11.1 Apply denoising_end
|
|
if (
|
|
self.denoising_end is not None
|
|
and isinstance(self.denoising_end, float)
|
|
and self.denoising_end > 0
|
|
and self.denoising_end < 1
|
|
):
|
|
discrete_timestep_cutoff = int(
|
|
round(
|
|
self.scheduler.config.num_train_timesteps # pylint: disable=no-member
|
|
- (self.denoising_end * self.scheduler.config.num_train_timesteps) # pylint: disable=no-member
|
|
)
|
|
)
|
|
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
|
timesteps = timesteps[:num_inference_steps]
|
|
|
|
# 12. Optionally get Guidance Scale Embedding
|
|
timestep_cond = None
|
|
if self.unet.config.time_cond_proj_dim is not None:
|
|
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
|
timestep_cond = self.get_guidance_scale_embedding(
|
|
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
|
).to(device=device, dtype=latents.dtype)
|
|
|
|
self._num_timesteps = len(timesteps) # pylint: disable=attribute-defined-outside-init
|
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
|
for i, t in enumerate(timesteps):
|
|
if self.interrupt:
|
|
continue
|
|
|
|
# expand the latents if we are doing classifier free guidance
|
|
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
|
|
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
|
|
|
if i <= start_merge_step:
|
|
current_prompt_embeds = torch.cat(
|
|
[negative_prompt_embeds, prompt_embeds_text_only], dim=0
|
|
) if self.do_classifier_free_guidance else prompt_embeds_text_only
|
|
add_text_embeds = torch.cat(
|
|
[negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0
|
|
) if self.do_classifier_free_guidance else pooled_prompt_embeds_text_only
|
|
else:
|
|
current_prompt_embeds = torch.cat(
|
|
[negative_prompt_embeds, prompt_embeds], dim=0
|
|
) if self.do_classifier_free_guidance else prompt_embeds
|
|
add_text_embeds = torch.cat(
|
|
[negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
|
|
) if self.do_classifier_free_guidance else pooled_prompt_embeds
|
|
|
|
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
|
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
|
added_cond_kwargs["image_embeds"] = image_embeds
|
|
|
|
# predict the noise residual
|
|
noise_pred = self.unet(
|
|
latent_model_input,
|
|
t,
|
|
encoder_hidden_states=current_prompt_embeds,
|
|
timestep_cond=timestep_cond,
|
|
cross_attention_kwargs=self.cross_attention_kwargs,
|
|
added_cond_kwargs=added_cond_kwargs,
|
|
return_dict=False,
|
|
)[0]
|
|
|
|
# perform guidance
|
|
if self.do_classifier_free_guidance:
|
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
|
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
|
|
|
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
|
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
|
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
|
|
|
# compute the previous noisy sample x_t -> x_t-1
|
|
latents_dtype = latents.dtype
|
|
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
|
if latents.dtype != latents_dtype:
|
|
if torch.backends.mps.is_available():
|
|
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
|
latents = latents.to(latents_dtype)
|
|
|
|
if callback_on_step_end is not None:
|
|
callback_kwargs = {}
|
|
for k in callback_on_step_end_tensor_inputs:
|
|
callback_kwargs[k] = locals()[k]
|
|
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
|
|
|
latents = callback_outputs.pop("latents", latents)
|
|
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
|
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
|
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
|
negative_pooled_prompt_embeds = callback_outputs.pop(
|
|
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
|
)
|
|
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
|
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
|
|
|
|
# call the callback, if provided
|
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
|
progress_bar.update()
|
|
if callback is not None and i % callback_steps == 0:
|
|
step_idx = i // getattr(self.scheduler, "order", 1)
|
|
callback(step_idx, t, latents)
|
|
|
|
if XLA_AVAILABLE:
|
|
xm.mark_step() # pylint: disable=possibly-used-before-assignment
|
|
|
|
if output_type != "latent":
|
|
# make sure the VAE is in float32 mode, as it overflows in float16
|
|
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
|
|
|
if needs_upcasting:
|
|
self.upcast_vae()
|
|
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
|
elif latents.dtype != self.vae.dtype:
|
|
if torch.backends.mps.is_available():
|
|
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
|
self.vae = self.vae.to(latents.dtype) # pylint: disable=attribute-defined-outside-init
|
|
|
|
# unscale/denormalize the latents
|
|
# denormalize with the mean and std if available and not None
|
|
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
|
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
|
if has_latents_mean and has_latents_std:
|
|
latents_mean = (
|
|
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
|
)
|
|
latents_std = (
|
|
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
|
)
|
|
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
|
else:
|
|
latents = latents / self.vae.config.scaling_factor
|
|
|
|
image = self.vae.decode(latents, return_dict=False)[0]
|
|
|
|
# cast back to fp16 if needed
|
|
if needs_upcasting:
|
|
self.vae.to(dtype=torch.float16)
|
|
else:
|
|
image = latents
|
|
return StableDiffusionXLPipelineOutput(images=image)
|
|
|
|
# apply watermark if available
|
|
# if self.watermark is not None:
|
|
# image = self.watermark.apply_watermark(image)
|
|
|
|
image = self.image_processor.postprocess(image, output_type=output_type)
|
|
|
|
# Offload all models
|
|
self.maybe_free_model_hooks()
|
|
|
|
if not return_dict:
|
|
return (image,)
|
|
|
|
return StableDiffusionXLPipelineOutput(images=image)
|