add visual

pull/221/head
root 2024-01-03 21:23:40 +09:00
parent 9d8960cbe8
commit 06d94b2535
3 changed files with 389 additions and 1 deletion


@@ -183,6 +183,7 @@ class LoRAIPAttnProcessor(nn.Module):
         ip_value = attn.head_to_batch_dim(ip_value)
         ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+        self.attn_map = ip_attention_probs
         ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
         ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
@@ -201,4 +202,4 @@ class LoRAIPAttnProcessor(nn.Module):
         hidden_states = hidden_states / attn.rescale_output_factor
-        return hidden_states
+        return hidden_states
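
The `attn_map` cached on the processor above is meant to be harvested with a standard PyTorch forward hook right after each cross-attention (`attn2`) call. A minimal sketch of that retrieval pattern, assuming `unet` is a diffusers UNet2DConditionModel with IP-Adapter processors installed (the `hook_fn`/`register_cross_attention_hook` helpers added in the next file implement exactly this):

    # Sketch: collect the maps LoRAIPAttnProcessor caches on itself during a forward pass.
    captured = {}

    def make_hook(name):
        def forward_hook(module, inputs, output):
            # Only IP-Adapter processors set attn_map, so guard with hasattr.
            if hasattr(module.processor, "attn_map"):
                captured[name] = module.processor.attn_map.detach().cpu()
        return forward_hook

    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            module.register_forward_hook(make_hook(name))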

ip_adapter/utils.py

@@ -1,5 +1,82 @@
 import torch
 import torch.nn.functional as F
+import numpy as np
+from PIL import Image
+
+attn_maps = {}
+
+
+def hook_fn(name):
+    def forward_hook(module, input, output):
+        if hasattr(module.processor, "attn_map"):
+            attn_maps[name] = module.processor.attn_map
+            del module.processor.attn_map
+    return forward_hook
+
+
+def register_cross_attention_hook(unet):
+    for name, module in unet.named_modules():
+        if name.split('.')[-1].startswith('attn2'):
+            module.register_forward_hook(hook_fn(name))
+    return unet
+
+
+def upscale(attn_map, target_size):
+    attn_map = torch.mean(attn_map, dim=0)  # average over attention heads
+    attn_map = attn_map.permute(1, 0)       # (query_len, tokens) -> (tokens, query_len)
+
+    # Find the latent resolution this map was computed at: a power-of-two
+    # downscale of the target size, on top of the VAE's factor of 8.
+    temp_size = None
+    for i in range(0, 5):
+        scale = 2 ** i
+        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
+            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
+            break
+    assert temp_size is not None, "temp_size cannot be None"
+
+    attn_map = attn_map.view(attn_map.shape[0], *temp_size)
+    attn_map = F.interpolate(
+        attn_map.unsqueeze(0).to(dtype=torch.float32),
+        size=target_size,
+        mode='bilinear',
+        align_corners=False
+    )[0]
+    attn_map = torch.softmax(attn_map, dim=0)
+    return attn_map
+
+
+def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
+    idx = 0 if instance_or_negative else 1  # keep the positive half of the CFG batch by default
+    net_attn_maps = []
+    for name, attn_map in attn_maps.items():
+        attn_map = attn_map.cpu() if detach else attn_map
+        attn_map = torch.chunk(attn_map, batch_size)[idx]  # e.g. (20, 32*32, 77) -> (10, 32*32, 77)
+        attn_map = upscale(attn_map, image_size)           # e.g. (10, 32*32, 77) -> (77, H, W)
+        net_attn_maps.append(attn_map)
+    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)
+    return net_attn_maps
+
+
+def attnmaps2images(net_attn_maps):
+    images = []
+    for attn_map in net_attn_maps:
+        attn_map = attn_map.cpu().numpy()
+        # Min-max normalize each map to [0, 255] and render it as a grayscale image.
+        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
+        normalized_attn_map = normalized_attn_map.astype(np.uint8)
+        images.append(Image.fromarray(normalized_attn_map))
+    return images
+
+
 def is_torch2_available():
     return hasattr(F, "scaled_dot_product_attention")
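
Taken together, a condensed usage sketch of the new helpers, mirroring the notebook below. The checkpoint paths are placeholders and the random embedding is a stand-in, not values from this commit:

    import copy
    import torch
    from diffusers import StableDiffusionPipeline
    from ip_adapter.ip_adapter_faceid import IPAdapterFaceID
    from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images

    pipe = StableDiffusionPipeline.from_pretrained("<base_model_path>", torch_dtype=torch.float16)
    pipe.unet = register_cross_attention_hook(pipe.unet)  # cache attn2 maps while sampling

    # Stand-in embedding; the notebook extracts a real ArcFace embedding with insightface.
    faceid_embeds = torch.randn(1, 512)

    ip_model = IPAdapterFaceID(copy.deepcopy(pipe), "<ip_ckpt>", "cuda")
    images = ip_model.generate(
        prompt="photo of a woman in a garden", faceid_embeds=faceid_embeds,
        num_samples=1, width=512, height=768,
        num_inference_steps=30, seed=2023,
    )

    net_attn_maps = get_net_attn_map((768, 512))  # (num_ip_tokens, H, W), averaged over layers
    for i, im in enumerate(attnmaps2images(net_attn_maps)):
        im.save(f"attn_token_{i:02d}.png")        # one grayscale heatmap per IP token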

visual_attnmap.ipynb (new file, 310 lines)

@@ -0,0 +1,310 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "39f9cc7a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
" _torch_pytree._register_pytree_node(\n",
"/root/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
" _torch_pytree._register_pytree_node(\n",
"/root/anaconda3/lib/python3.9/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
" torch.utils._pytree._register_pytree_node(\n",
"/root/anaconda3/lib/python3.9/site-packages/scipy/__init__.py:155: UserWarning: A NumPy version >=1.18.5 and <1.25.0 is required for this version of SciPy (detected version 1.26.3\n",
" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
]
}
],
"source": [
"import torch\n",
"from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL\n",
"from PIL import Image\n",
"import copy\n",
"\n",
"from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlus, IPAdapterFaceID\n",
"from insightface.app import FaceAnalysis\n",
"from insightface.model_zoo.arcface_onnx import ArcFaceONNX\n",
"from insightface.utils import face_align\n",
"from numpy.linalg import norm as l2norm\n",
"import cv2\n",
"from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0d290971",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/anaconda3/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:69: UserWarning: Specified provider 'CUDAExecutionProvider' is not in available provider names.Available providers: 'AzureExecutionProvider, CPUExecutionProvider'\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n",
"find model: /root/.insightface/models/buffalo_l/1k3d68.onnx landmark_3d_68 ['None', 3, 192, 192] 0.0 1.0\n",
"Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n",
"find model: /root/.insightface/models/buffalo_l/2d106det.onnx landmark_2d_106 ['None', 3, 192, 192] 0.0 1.0\n",
"Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n",
"find model: /root/.insightface/models/buffalo_l/det_10g.onnx detection [1, 3, '?', '?'] 127.5 128.0\n",
"Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n",
"find model: /root/.insightface/models/buffalo_l/genderage.onnx genderage ['None', 3, 96, 96] 0.0 1.0\n",
"Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n",
"find model: /root/.insightface/models/buffalo_l/w600k_r50.onnx recognition ['None', 3, 112, 112] 127.5 127.5\n",
"set det-size: (640, 640)\n"
]
}
],
"source": [
"app = FaceAnalysis(name=\"buffalo_l\", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\n",
"app.prepare(ctx_id=0, det_size=(640, 640))\n",
"\n",
"v2 = False\n",
"base_model_path = \"/dfs/comicai/zhiyuan.shi/models/SG161222/Realistic_Vision_V4.0_noVAE\"\n",
"vae_model_path = \"/dfs/comicai/zhiyuan.shi/models/sd-vae-ft-mse\"\n",
"image_encoder_path = \"/dfs/comicai/zhengbing.yao/models/IP-Adapter/models/image_encoder/\"\n",
"plus_ip_ckpt = \"/dfs/comicai/zhiyuan.shi/models/IP-Adapter-FaceID/ip-adapter-faceid-plusv2_sd15.bin\"\n",
"ip_ckpt = \"/dfs/comicai/zhiyuan.shi/models/IP-Adapter-FaceID/ip-adapter-faceid_sd15.bin\"\n",
"device = \"cuda\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f20eae92",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n",
"```\n",
"pip install accelerate\n",
"```\n",
".\n",
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n",
"```\n",
"pip install accelerate\n",
"```\n",
".\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0e813051f12a4679aef02732b1baacf9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading pipeline components...: 0%| | 0/5 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"noise_scheduler = DDIMScheduler(\n",
" num_train_timesteps=1000,\n",
" beta_start=0.00085,\n",
" beta_end=0.012,\n",
" beta_schedule=\"scaled_linear\",\n",
" clip_sample=False,\n",
" set_alpha_to_one=False,\n",
" steps_offset=1,\n",
")\n",
"vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)\n",
"pipe = StableDiffusionPipeline.from_pretrained(\n",
" base_model_path,\n",
" torch_dtype=torch.float16,\n",
" scheduler=noise_scheduler,\n",
" vae=vae,\n",
" feature_extractor=None,\n",
" safety_checker=None\n",
")\n",
"pipe.unet = register_cross_attention_hook(pipe.unet)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "40372f70",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/anaconda3/lib/python3.9/site-packages/insightface/utils/transform.py:68: FutureWarning: `rcond` parameter will change to the default of machine precision times ``max(M, N)`` where M and N are the input matrix dimensions.\n",
"To use the future default and silence this warning we advise to pass `rcond=None`, to keep using the old, explicitly pass `rcond=-1`.\n",
" P = np.linalg.lstsq(X_homo, Y)[0].T # Affine matrix. 3 x 4\n"
]
}
],
"source": [
"# generate image\n",
"prompt = \"photo of a woman in red dress in a garden, white hair, happy\"\n",
"negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality, blurry\"\n",
"\n",
"import wandb\n",
"table = wandb.Table(columns=[\"prompt\", \"scale\", \"face\", \"gen\"])\n",
"\n",
"def rtn_face_get(self, img, face):\n",
" aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0])\n",
" #print(cv2.imwrite(\"aimg.png\", aimg))\n",
" face.embedding = self.get_feat(aimg).flatten()\n",
" face.crop_face = aimg\n",
" return face.embedding\n",
"\n",
"ArcFaceONNX.get = rtn_face_get\n",
"image = cv2.imread(\"assets/images/woman.png\")\n",
"faces = app.get(image)\n",
"faceid_embeds = faces[0].normed_embedding\n",
"faceid_embeds = torch.from_numpy(faceid_embeds).unsqueeze(0)\n",
"face_image = faces[0].crop_face"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "504c9305",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c307c379534c4a82a72addc7a2e4af33",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/30 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plus_ip_model = IPAdapterFaceIDPlus(copy.deepcopy(pipe), image_encoder_path, plus_ip_ckpt, device)\n",
"images = plus_ip_model.generate(\n",
" prompt=prompt,\n",
" negative_prompt=negative_prompt,\n",
" face_image=face_image,\n",
" faceid_embeds=faceid_embeds,\n",
" shortcut=v2,\n",
" s_scale=1,\n",
" num_samples=1,\n",
" width=512, height=768,\n",
" num_inference_steps=30, seed=2023\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e5aeccf0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([4, 768, 512])\n"
]
}
],
"source": [
"attn_maps = get_net_attn_map((768, 512))\n",
"print(attn_maps.shape)\n",
"attn_hot = attnmaps2images(attn_maps)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2838572",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"#axes[0].imshow(attn_hot[0], cmap='gray')\n",
"display_images = [cv2.cvtColor(cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)] + attn_hot + [images[0]]\n",
"fig, axes = plt.subplots(1, len(display_images), figsize=(12, 4))\n",
"for axe, image in zip(axes, display_images):\n",
" axe.imshow(image, cmap='gray')\n",
" axe.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fead1786",
"metadata": {},
"outputs": [],
"source": [
"ip_model = IPAdapterFaceID(copy.deepcopy(pipe), ip_ckpt, device)\n",
"images = ip_model.generate(\n",
" prompt=prompt, negative_prompt=negative_prompt,\n",
" faceid_embeds=faceid_embeds,\n",
" num_samples=1,\n",
" width=512, height=768,\n",
" num_inference_steps=30, seed=2023\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "124b3672",
"metadata": {},
"outputs": [],
"source": [
"attn_maps = get_net_attn_map((768, 512))\n",
"print(attn_maps.shape)\n",
"attn_hot = attnmaps2images(attn_maps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f409acde",
"metadata": {},
"outputs": [],
"source": [
"\n",
"display_images = [cv2.cvtColor(cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)] + attn_hot + [images[0]]\n",
"fig, axes = plt.subplots(1, len(display_images), figsize=(12, 4))\n",
"for axe, image in zip(axes, display_images):\n",
" axe.imshow(image, cmap='gray')\n",
" axe.axis('off')\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
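
An optional follow-up to the grayscale panels above (a sketch, not part of this commit): blend each token's heatmap over the generated image instead of showing them side by side. Assumes `attn_hot` and `images` from the cells in this notebook, where each map already matches the 768x512 output.

    import numpy as np
    import matplotlib.pyplot as plt

    base = np.asarray(images[0])
    fig, axes = plt.subplots(1, len(attn_hot), figsize=(4 * len(attn_hot), 4))
    for axe, heat in zip(np.atleast_1d(axes), attn_hot):
        axe.imshow(base)                                     # generated image underneath
        axe.imshow(np.asarray(heat), cmap='jet', alpha=0.5)  # token heatmap blended on top
        axe.axis('off')
    plt.show()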