diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py index 07fd6d8..fecc7ed 100644 --- a/ip_adapter/attention_processor.py +++ b/ip_adapter/attention_processor.py @@ -162,7 +162,6 @@ class IPAttnProcessor(nn.Module): ip_value = attn.head_to_batch_dim(ip_value) ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - self.attn_map = ip_attention_probs ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) @@ -379,9 +378,6 @@ class IPAttnProcessor2_0(torch.nn.Module): ip_hidden_states = F.scaled_dot_product_attention( query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) - with torch.no_grad(): - self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) - #print(self.attn_map.shape) ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) ip_hidden_states = ip_hidden_states.to(query.dtype) diff --git a/ip_adapter/utils.py b/ip_adapter/utils.py index 3c99cd1..09df7ab 100644 --- a/ip_adapter/utils.py +++ b/ip_adapter/utils.py @@ -50,9 +50,10 @@ def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detac for name, attn_map in attn_maps.items(): attn_map = attn_map.cpu() if detach else attn_map - attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze() - attn_map = upscale(attn_map, image_size) - net_attn_maps.append(attn_map) + attn_map = torch.chunk(attn_map, batch_size)[idx] # (20, 32*32, 77) -> (10, 32*32, 77) # negative & positive CFG + + attn_map = upscale(attn_map, image_size) # (10,32*32,77) -> (77,64*64) + net_attn_maps.append(attn_map) # (10,32*32,77) -> (77,64*64) net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0) diff --git a/visual_attnmap.ipynb b/visual_attnmap.ipynb new file mode 100644 index 0000000..bb1338d --- /dev/null +++ b/visual_attnmap.ipynb @@ -0,0 +1,367 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "39f9cc7a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " _torch_pytree._register_pytree_node(\n", + "/root/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " _torch_pytree._register_pytree_node(\n", + "/root/anaconda3/lib/python3.9/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " torch.utils._pytree._register_pytree_node(\n", + "/root/anaconda3/lib/python3.9/site-packages/scipy/__init__.py:155: UserWarning: A NumPy version >=1.18.5 and <1.25.0 is required for this version of SciPy (detected version 1.26.3\n", + " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n" + ] + } + ], + "source": [ + "import torch\n", + "from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL\n", + "from PIL import Image\n", + "import copy\n", + "\n", + "from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlus, IPAdapterFaceID\n", + "from insightface.app import FaceAnalysis\n", + "from insightface.model_zoo.arcface_onnx import ArcFaceONNX\n", + "from insightface.utils import face_align\n", + "from numpy.linalg import norm as l2norm\n", + "import cv2\n", + "from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0d290971", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/anaconda3/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:69: UserWarning: Specified provider 'CUDAExecutionProvider' is not in available provider names.Available providers: 'AzureExecutionProvider, CPUExecutionProvider'\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", + "find model: /root/.insightface/models/buffalo_l/1k3d68.onnx landmark_3d_68 ['None', 3, 192, 192] 0.0 1.0\n", + "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", + "find model: /root/.insightface/models/buffalo_l/2d106det.onnx landmark_2d_106 ['None', 3, 192, 192] 0.0 1.0\n", + "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", + "find model: /root/.insightface/models/buffalo_l/det_10g.onnx detection [1, 3, '?', '?'] 127.5 128.0\n", + "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", + "find model: /root/.insightface/models/buffalo_l/genderage.onnx genderage ['None', 3, 96, 96] 0.0 1.0\n", + "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", + "find model: /root/.insightface/models/buffalo_l/w600k_r50.onnx recognition ['None', 3, 112, 112] 127.5 127.5\n", + "set det-size: (640, 640)\n" + ] + } + ], + "source": [ + "app = FaceAnalysis(name=\"buffalo_l\", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\n", + "app.prepare(ctx_id=0, det_size=(640, 640))\n", + "\n", + "v2 = False\n", + "base_model_path = \"SG161222/Realistic_Vision_V4.0_noVAE\"\n", + "vae_model_path = \"models/sd-vae-ft-mse\"\n", + "image_encoder_path = \"IP-Adapter/models/image_encoder/\"\n", + "plus_ip_ckpt = \"IP-Adapter-FaceID/ip-adapter-faceid-plusv2_sd15.bin\"\n", + "ip_ckpt = \"IP-Adapter-FaceID/ip-adapter-faceid_sd15.bin\"\n", + "device = \"cuda\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f20eae92", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n", + "```\n", + "pip install accelerate\n", + "```\n", + ".\n", + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n", + "```\n", + "pip install accelerate\n", + "```\n", + ".\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b67a335943814268bb1b587b8582154d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading pipeline components...: 0%| | 0/5 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "#axes[0].imshow(attn_hot[0], cmap='gray')\n", + "display_images = [cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)] + attn_hot + [images[0]]\n", + "fig, axes = plt.subplots(1, len(display_images), figsize=(12, 4))\n", + "for axe, image in zip(axes, display_images):\n", + " axe.imshow(image, cmap='gray')\n", + " axe.axis('off')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fead1786", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b1741aa4e3624e66b07d181531d14596", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/30 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "display_images = [cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)] + attn_hot + [images[0]]\n", + "fig, axes = plt.subplots(1, len(display_images), figsize=(12, 4))\n", + "for axe, image in zip(axes, display_images):\n", + " axe.imshow(image, cmap='gray')\n", + " axe.axis('off')\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}