diff --git a/README.md b/README.md index aefb5fb..6f0bc37 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ we present IP-Adapter, an effective and lightweight adapter to achieve image pro ![arch](assets/figs/fig1.png) ## Release +- [2023/11/10] 🔥 Add an updated version of IP-Adapter-Face. The demo is [here](ip_adapter-full-face_demo.ipynb). - [2023/11/05] 🔥 Add text-to-image [demo](ip_adapter_t2i_demo.ipynb) with IP-Adapter and [Kandinsky 2.2 Prior](https://huggingface.co/kandinsky-community/kandinsky-2-2-prior) - [2023/11/02] Support [safetensors](https://github.com/huggingface/safetensors) - [2023/9/08] 🔥 Update a new version of IP-Adapter with SDXL_1.0. More information can be found [here](#sdxl_10). diff --git a/ip_adapter-full-face_demo.ipynb b/ip_adapter-full-face_demo.ipynb new file mode 100644 index 0000000..e8e9e24 --- /dev/null +++ b/ip_adapter-full-face_demo.ipynb @@ -0,0 +1,340 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "411c59b3-f177-4a10-8925-d931ce572eaa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-11-10 10:44:18,479] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + ], + "source": [ + "import torch\n", + "from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL\n", + "from PIL import Image\n", + "\n", + "from ip_adapter import IPAdapterFull" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "6b6dc69c-192d-4d74-8b1e-f0d9ccfbdb49", + "metadata": {}, + "outputs": [], + "source": [ + "base_model_path = \"SG161222/Realistic_Vision_V4.0_noVAE\"\n", + "vae_model_path = \"stabilityai/sd-vae-ft-mse\"\n", + "image_encoder_path = \"models/image_encoder/\"\n", + "ip_ckpt = \"models/ip-adapter-full-face_sd15.bin\"\n", + "device = \"cuda\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "63ec542f-8474-4f38-9457-073425578073", + "metadata": {}, + "outputs": [], + "source": [ + "def image_grid(imgs, rows, cols):\n", + " assert len(imgs) == rows*cols\n", + "\n", + " w, h = imgs[0].size\n", + " grid = Image.new('RGB', size=(cols*w, rows*h))\n", + " grid_w, grid_h = grid.size\n", + " \n", + " for i, img in enumerate(imgs):\n", + " grid.paste(img, box=(i%cols*w, i//cols*h))\n", + " return grid\n", + "\n", + "noise_scheduler = DDIMScheduler(\n", + " num_train_timesteps=1000,\n", + " beta_start=0.00085,\n", + " beta_end=0.012,\n", + " beta_schedule=\"scaled_linear\",\n", + " clip_sample=False,\n", + " set_alpha_to_one=False,\n", + " steps_offset=1,\n", + ")\n", + "vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3849f9d0-5f68-4a49-9190-69dd50720cae", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b46fa33e617649a5b8ecd972568eec5b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading pipeline components...: 0%| | 0/5 [00:00" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# read image prompt\n", + "image = Image.open(\"assets/images/ai_face2.png\")\n", + "image.resize((256, 256))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "70e77d65-262f-415f-9cbd-057d57c4222d", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fba0b8cbeffb4a5e87aed5b7eecb0612", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/50 [00:00" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# use face as image prompt\n", + "images = ip_model.generate(\n", + " pil_image=image, num_samples=4, prompt=\"A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower\",\n", + " scale=0.7, width=512, height=704, num_inference_steps=50, seed=42)\n", + "grid = image_grid(images, 1, 4)\n", + "grid" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d83df45f-717d-4bb3-a5fd-0ea30930a431", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6fa892b865bc41f295aecd75f9b8c734", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/50 [00:00" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Use a lower scale to mix faces\n", + "images = ip_model.generate(\n", + " pil_image=image, num_samples=4, prompt=\"photo of Einstein wearing colorful casual shirt in a garden\",\n", + " scale=0.4, width=512, height=704, num_inference_steps=50, seed=42)\n", + "grid = image_grid(images, 1, 4)\n", + "grid" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "10d6359e-6eb3-432a-a890-b814c505d005", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "afaebbadeb9b4e7d9bebb0dd59035905", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading pipeline components...: 0%| | 0/5 [00:00" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# use face as image prompt\n", + "images = ip_model.generate(\n", + " pil_image=image, num_samples=4, prompt=\"A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower\",\n", + " scale=0.6, width=512, height=704, num_inference_steps=50, seed=42)\n", + "grid = image_grid(images, 1, 4)\n", + "grid" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ip_adapter/__init__.py b/ip_adapter/__init__.py index 301128c..3b1f1ff 100644 --- a/ip_adapter/__init__.py +++ b/ip_adapter/__init__.py @@ -1,8 +1,9 @@ -from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL +from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull __all__ = [ "IPAdapter", "IPAdapterPlus", "IPAdapterPlusXL", "IPAdapterXL", + "IPAdapterFull", ] diff --git a/ip_adapter/ip_adapter.py b/ip_adapter/ip_adapter.py index 0060d7f..f357b17 100644 --- a/ip_adapter/ip_adapter.py +++ b/ip_adapter/ip_adapter.py @@ -45,6 +45,23 @@ class ImageProjModel(torch.nn.Module): return clip_extra_context_tokens +class MLPProjModel(torch.nn.Module): + """SD model with image prompt""" + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), + torch.nn.GELU(), + torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), + torch.nn.LayerNorm(cross_attention_dim) + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + class IPAdapter: def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): self.device = device @@ -176,14 +193,13 @@ class IPAdapter: uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) with torch.inference_mode(): - prompt_embeds = self.pipe._encode_prompt( + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt, ) - negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2) prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) @@ -295,6 +311,17 @@ class IPAdapterPlus(IPAdapter): return image_prompt_embeds, uncond_image_prompt_embeds +class IPAdapterFull(IPAdapterPlus): + """IP-Adapter with full features""" + + def init_proj(self): + image_proj_model = MLPProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + class IPAdapterPlusXL(IPAdapter): """SDXL"""