diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py index fecc7ed..07fd6d8 100644 --- a/ip_adapter/attention_processor.py +++ b/ip_adapter/attention_processor.py @@ -162,6 +162,7 @@ class IPAttnProcessor(nn.Module): ip_value = attn.head_to_batch_dim(ip_value) ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + self.attn_map = ip_attention_probs ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) @@ -378,6 +379,9 @@ class IPAttnProcessor2_0(torch.nn.Module): ip_hidden_states = F.scaled_dot_product_attention( query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) + with torch.no_grad(): + self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) + #print(self.attn_map.shape) ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) ip_hidden_states = ip_hidden_states.to(query.dtype) diff --git a/ip_adapter/utils.py b/ip_adapter/utils.py index 09df7ab..3c99cd1 100644 --- a/ip_adapter/utils.py +++ b/ip_adapter/utils.py @@ -50,10 +50,9 @@ def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detac for name, attn_map in attn_maps.items(): attn_map = attn_map.cpu() if detach else attn_map - attn_map = torch.chunk(attn_map, batch_size)[idx] # (20, 32*32, 77) -> (10, 32*32, 77) # negative & positive CFG - - attn_map = upscale(attn_map, image_size) # (10,32*32,77) -> (77,64*64) - net_attn_maps.append(attn_map) # (10,32*32,77) -> (77,64*64) + attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze() + attn_map = upscale(attn_map, image_size) + net_attn_maps.append(attn_map) net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0) diff --git a/visual_attnmap.ipynb b/visual_attnmap.ipynb deleted file mode 100644 index bb1338d..0000000 --- a/visual_attnmap.ipynb +++ /dev/null @@ -1,367 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "39f9cc7a", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/root/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", - " _torch_pytree._register_pytree_node(\n", - "/root/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", - " _torch_pytree._register_pytree_node(\n", - "/root/anaconda3/lib/python3.9/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", - " torch.utils._pytree._register_pytree_node(\n", - "/root/anaconda3/lib/python3.9/site-packages/scipy/__init__.py:155: UserWarning: A NumPy version >=1.18.5 and <1.25.0 is required for this version of SciPy (detected version 1.26.3\n", - " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n" - ] - } - ], - "source": [ - "import torch\n", - "from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL\n", - "from PIL import Image\n", - "import copy\n", - "\n", - "from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlus, IPAdapterFaceID\n", - "from insightface.app import FaceAnalysis\n", - "from insightface.model_zoo.arcface_onnx import ArcFaceONNX\n", - "from insightface.utils import face_align\n", - "from numpy.linalg import norm as l2norm\n", - "import cv2\n", - "from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "0d290971", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/root/anaconda3/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:69: UserWarning: Specified provider 'CUDAExecutionProvider' is not in available provider names.Available providers: 'AzureExecutionProvider, CPUExecutionProvider'\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", - "find model: /root/.insightface/models/buffalo_l/1k3d68.onnx landmark_3d_68 ['None', 3, 192, 192] 0.0 1.0\n", - "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", - "find model: /root/.insightface/models/buffalo_l/2d106det.onnx landmark_2d_106 ['None', 3, 192, 192] 0.0 1.0\n", - "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", - "find model: /root/.insightface/models/buffalo_l/det_10g.onnx detection [1, 3, '?', '?'] 127.5 128.0\n", - "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", - "find model: /root/.insightface/models/buffalo_l/genderage.onnx genderage ['None', 3, 96, 96] 0.0 1.0\n", - "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", - "find model: /root/.insightface/models/buffalo_l/w600k_r50.onnx recognition ['None', 3, 112, 112] 127.5 127.5\n", - "set det-size: (640, 640)\n" - ] - } - ], - "source": [ - "app = FaceAnalysis(name=\"buffalo_l\", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\n", - "app.prepare(ctx_id=0, det_size=(640, 640))\n", - "\n", - "v2 = False\n", - "base_model_path = \"SG161222/Realistic_Vision_V4.0_noVAE\"\n", - "vae_model_path = \"models/sd-vae-ft-mse\"\n", - "image_encoder_path = \"IP-Adapter/models/image_encoder/\"\n", - "plus_ip_ckpt = \"IP-Adapter-FaceID/ip-adapter-faceid-plusv2_sd15.bin\"\n", - "ip_ckpt = \"IP-Adapter-FaceID/ip-adapter-faceid_sd15.bin\"\n", - "device = \"cuda\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "f20eae92", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n", - "```\n", - "pip install accelerate\n", - "```\n", - ".\n", - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n", - "```\n", - "pip install accelerate\n", - "```\n", - ".\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b67a335943814268bb1b587b8582154d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading pipeline components...: 0%| | 0/5 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "#axes[0].imshow(attn_hot[0], cmap='gray')\n", - "display_images = [cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)] + attn_hot + [images[0]]\n", - "fig, axes = plt.subplots(1, len(display_images), figsize=(12, 4))\n", - "for axe, image in zip(axes, display_images):\n", - " axe.imshow(image, cmap='gray')\n", - " axe.axis('off')\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "fead1786", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b1741aa4e3624e66b07d181531d14596", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/30 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "\n", - "display_images = [cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)] + attn_hot + [images[0]]\n", - "fig, axes = plt.subplots(1, len(display_images), figsize=(12, 4))\n", - "for axe, image in zip(axes, display_images):\n", - " axe.imshow(image, cmap='gray')\n", - " axe.axis('off')\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}