From 6219530507cb4696a0496f10c0e5a4f1dbdc7672 Mon Sep 17 00:00:00 2001 From: xiaohu2015 Date: Wed, 27 Sep 2023 15:30:14 +0800 Subject: [PATCH] add ip-adapter-plus for sdxl --- ip_adapter-plus_sdxl_demo.ipynb | 216 ++++++++++++++++++++++++++++++++ ip_adapter/ip_adapter.py | 83 ++++++++++++ 2 files changed, 299 insertions(+) create mode 100644 ip_adapter-plus_sdxl_demo.ipynb diff --git a/ip_adapter-plus_sdxl_demo.ipynb b/ip_adapter-plus_sdxl_demo.ipynb new file mode 100644 index 0000000..bccba0b --- /dev/null +++ b/ip_adapter-plus_sdxl_demo.ipynb @@ -0,0 +1,216 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "411c59b3-f177-4a10-8925-d931ce572eaa", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from diffusers import StableDiffusionXLPipeline\n", + "from PIL import Image\n", + "\n", + "from ip_adapter import IPAdapterPlusXL" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "6b6dc69c-192d-4d74-8b1e-f0d9ccfbdb49", + "metadata": {}, + "outputs": [], + "source": [ + "base_model_path = \"SG161222/RealVisXL_V1.0\"\n", + "image_encoder_path = \"models/image_encoder\"\n", + "ip_ckpt = \"sdxl_models/ip-adapter-plus_sdxl_vit-h.bin\"\n", + "device = \"cuda\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "63ec542f-8474-4f38-9457-073425578073", + "metadata": {}, + "outputs": [], + "source": [ + "def image_grid(imgs, rows, cols):\n", + " assert len(imgs) == rows*cols\n", + "\n", + " w, h = imgs[0].size\n", + " grid = Image.new('RGB', size=(cols*w, rows*h))\n", + " grid_w, grid_h = grid.size\n", + " \n", + " for i, img in enumerate(imgs):\n", + " grid.paste(img, box=(i%cols*w, i//cols*h))\n", + " return grid" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3849f9d0-5f68-4a49-9190-69dd50720cae", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "04c9fbba762142cfa8862e0b66f02b7e", + "version_major": 2, + 
"version_minor": 0 + }, + "text/plain": [ + "Loading pipeline components...: 0%| | 0/7 [00:00" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# read image prompt\n", + "image = Image.open(\"assets/images/woman.png\")\n", + "image.resize((512, 512))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b77f52de-a9e4-44e1-aeec-8165414f1273", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "11b0c78ca3394241ada86d2f40ccdb65", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/30 [00:00" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# generate image variations with only image prompt\n", + "num_samples = 2\n", + "images = ip_model.generate(pil_image=image, num_samples=num_samples, num_inference_steps=30, seed=42)\n", + "grid = image_grid(images, 1, num_samples)\n", + "grid" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "36ec1dce-7861-4ce2-90de-0de36bb28569", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0c7fc9c0dd04480faf6c448affd53254", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/30 [00:00" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# multimodal prompts\n", + "images = ip_model.generate(pil_image=image, num_samples=num_samples, num_inference_steps=30, seed=42,\n", + " prompt=\"best quality, high quality, wearing sunglasses on the beach\", scale=0.5)\n", + "grid = image_grid(images, 1, num_samples)\n", + "grid" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": 
".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ip_adapter/ip_adapter.py b/ip_adapter/ip_adapter.py index b547ef4..d0781a2 100644 --- a/ip_adapter/ip_adapter.py +++ b/ip_adapter/ip_adapter.py @@ -248,3 +248,86 @@ class IPAdapterPlus(IPAdapter): uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2] uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) return image_prompt_embeds, uncond_image_prompt_embeds + + +class IPAdapterPlusXL(IPAdapter): + """SDXL""" + + def init_proj(self): + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4 + ).to(self.device, dtype=torch.float16) + return image_proj_model + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + if isinstance(pil_image, Image.Image): + num_prompts = 
1 + else: + num_prompts = len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.pipe.encode_prompt( + prompt, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images