import torch
import torch.nn.functional as F
import numpy as np

# Global registry populated by the forward hooks installed through
# register_cross_attention_hook(): maps UNet module name -> the most recent
# IP-Adapter cross-attention probability tensor captured from that module.
attn_maps = {}


def hook_fn(name):
    """Build a forward hook that harvests ``attn_map`` from an attention module.

    The IP-Adapter attention processors stash their image cross-attention
    probabilities on ``module.processor.attn_map``. The returned hook moves
    that tensor into the module-level ``attn_maps`` dict under ``name`` and
    deletes it from the processor so stale maps are not kept alive between
    denoising steps.

    Args:
        name: key under which the captured map is stored in ``attn_maps``.

    Returns:
        A callable with the ``(module, input, output)`` forward-hook signature.
    """
    def forward_hook(module, input, output):
        if hasattr(module.processor, "attn_map"):
            attn_maps[name] = module.processor.attn_map
            del module.processor.attn_map
    return forward_hook


def register_cross_attention_hook(unet):
    """Attach attn-map harvesting hooks to every cross-attention module.

    Cross-attention blocks in diffusers UNets are conventionally named
    ``attn2``; only those modules get a hook.

    Args:
        unet: the UNet whose modules are scanned.

    Returns:
        The same ``unet``, for call-chaining convenience.
    """
    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            module.register_forward_hook(hook_fn(name))
    return unet


def upscale(attn_map, target_size):
    """Average an attention map over heads and bilinearly upscale it.

    Args:
        attn_map: tensor of shape ``(heads, H*W, tokens)`` — one CFG chunk of
            a captured cross-attention map.
        target_size: ``(height, width)`` of the generated image.

    Returns:
        Tensor of shape ``(tokens, *target_size)``, softmax-normalised across
        the token dimension at every pixel.

    Raises:
        AssertionError: if no power-of-two downscale factor relates
            ``target_size`` to the map's spatial resolution.
    """
    attn_map = torch.mean(attn_map, dim=0)  # (H*W, tokens)
    attn_map = attn_map.permute(1, 0)       # (tokens, H*W)

    # Recover the latent spatial resolution: latents are 8x smaller than the
    # image, and the map may come from a further-downsampled UNet level, so
    # try downscale factors 1, 2, 4, 8, 16 until the pixel count matches.
    temp_size = None
    for i in range(5):
        scale = 2 ** i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    # Message fixed ("cannot is" -> "cannot be"); exception type kept so any
    # existing callers catching AssertionError still work.
    assert temp_size is not None, "temp_size cannot be None"

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map


def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Aggregate every hooked attention map into one ``(tokens, H, W)`` tensor.

    Args:
        image_size: ``(height, width)`` each map is upscaled to.
        batch_size: number of CFG chunks along the batch axis
            (negative + positive prompt).
        instance_or_negative: ``True`` selects chunk 0 (negative/instance),
            ``False`` selects chunk 1 (positive).
        detach: move each map to CPU before processing.
            # NOTE(review): the name suggests ``.detach()`` but the code only
            # calls ``.cpu()`` — confirm the intended semantics with callers.

    Returns:
        The mean over all hooked layers of the upscaled per-token maps,
        shape ``(tokens, *image_size)``.
    """
    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    for name, attn_map in attn_maps.items():
        attn_map = attn_map.cpu() if detach else attn_map
        # (batch*heads, H*W, tokens) -> (heads, H*W, tokens): keep one CFG chunk.
        attn_map = torch.chunk(attn_map, batch_size)[idx]
        # (heads, H*W, tokens) -> (tokens, H, W)
        net_attn_maps.append(upscale(attn_map, image_size))

    return torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)


def attnmaps2images(net_attn_maps):
    """Render per-token attention maps as grayscale PIL images.

    Each map is min-max scaled to [0, 255] independently.

    Args:
        net_attn_maps: iterable of 2-D tensors (one per token).

    Returns:
        List of ``PIL.Image`` objects, one per input map.
    """
    # Local import: Pillow is only needed for visualisation, so the rest of
    # this module stays importable when PIL is not installed.
    from PIL import Image

    images = []
    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()
        lo, hi = np.min(attn_map), np.max(attn_map)
        if hi == lo:
            # Constant map: avoid division by zero; render as black.
            normalized = np.zeros_like(attn_map, dtype=np.uint8)
        else:
            normalized = ((attn_map - lo) / (hi - lo) * 255).astype(np.uint8)
        images.append(Image.fromarray(normalized))
    return images


def is_torch2_available():
    """Return True when torch provides scaled_dot_product_attention (>= 2.0)."""
    return hasattr(F, "scaled_dot_product_attention")
Please use torch.utils._pytree.register_pytree_node instead.\n", + " _torch_pytree._register_pytree_node(\n", + "/root/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " _torch_pytree._register_pytree_node(\n", + "/root/anaconda3/lib/python3.9/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " torch.utils._pytree._register_pytree_node(\n", + "/root/anaconda3/lib/python3.9/site-packages/scipy/__init__.py:155: UserWarning: A NumPy version >=1.18.5 and <1.25.0 is required for this version of SciPy (detected version 1.26.3\n", + " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n" + ] + } + ], + "source": [ + "import torch\n", + "from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL\n", + "from PIL import Image\n", + "import copy\n", + "\n", + "from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlus, IPAdapterFaceID\n", + "from insightface.app import FaceAnalysis\n", + "from insightface.model_zoo.arcface_onnx import ArcFaceONNX\n", + "from insightface.utils import face_align\n", + "from numpy.linalg import norm as l2norm\n", + "import cv2\n", + "from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0d290971", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/anaconda3/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:69: UserWarning: Specified provider 'CUDAExecutionProvider' is not in available provider names.Available providers: 'AzureExecutionProvider, CPUExecutionProvider'\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", + "find model: /root/.insightface/models/buffalo_l/1k3d68.onnx landmark_3d_68 ['None', 3, 192, 192] 0.0 1.0\n", + "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", + "find model: /root/.insightface/models/buffalo_l/2d106det.onnx landmark_2d_106 ['None', 3, 192, 192] 0.0 1.0\n", + "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", + "find model: /root/.insightface/models/buffalo_l/det_10g.onnx detection [1, 3, '?', '?'] 127.5 128.0\n", + "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", + "find model: /root/.insightface/models/buffalo_l/genderage.onnx genderage ['None', 3, 96, 96] 0.0 1.0\n", + "Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n", + "find model: /root/.insightface/models/buffalo_l/w600k_r50.onnx recognition ['None', 3, 112, 112] 127.5 127.5\n", + "set det-size: (640, 640)\n" + ] + } + ], + "source": [ + "app = FaceAnalysis(name=\"buffalo_l\", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\n", + "app.prepare(ctx_id=0, det_size=(640, 640))\n", + "\n", + "v2 = False\n", + "base_model_path = \"/dfs/comicai/zhiyuan.shi/models/SG161222/Realistic_Vision_V4.0_noVAE\"\n", + "vae_model_path = \"/dfs/comicai/zhiyuan.shi/models/sd-vae-ft-mse\"\n", + "image_encoder_path = \"/dfs/comicai/zhengbing.yao/models/IP-Adapter/models/image_encoder/\"\n", + "plus_ip_ckpt = \"/dfs/comicai/zhiyuan.shi/models/IP-Adapter-FaceID/ip-adapter-faceid-plusv2_sd15.bin\"\n", + "ip_ckpt = \"/dfs/comicai/zhiyuan.shi/models/IP-Adapter-FaceID/ip-adapter-faceid_sd15.bin\"\n", + "device = \"cuda\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f20eae92", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Cannot 
initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n", + "```\n", + "pip install accelerate\n", + "```\n", + ".\n", + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n", + "```\n", + "pip install accelerate\n", + "```\n", + ".\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0e813051f12a4679aef02732b1baacf9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading pipeline components...: 0%| | 0/5 [00:00