add visual
parent
9d8960cbe8
commit
06d94b2535
|
|
@ -183,6 +183,7 @@ class LoRAIPAttnProcessor(nn.Module):
|
|||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
self.attn_map = ip_attention_probs
|
||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||
|
||||
|
|
@ -201,4 +202,4 @@ class LoRAIPAttnProcessor(nn.Module):
|
|||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
return hidden_states
|
||||
|
|
|
|||
|
|
@ -1,5 +1,82 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
# Global registry of cross-attention maps captured by the forward hooks.
attn_maps = {}

def hook_fn(name):
    """Build a forward hook that stashes a module's attention map.

    The returned hook checks whether the module's processor stored an
    ``attn_map`` attribute during the forward pass; if so, it moves that
    tensor into the module-level ``attn_maps`` dict under *name* and
    removes the attribute so a stale map is never read twice.
    """
    def forward_hook(module, input, output):
        processor = module.processor
        if hasattr(processor, "attn_map"):
            attn_maps[name] = processor.attn_map
            del processor.attn_map

    return forward_hook
|
||||
|
||||
def register_cross_attention_hook(unet):
    """Attach attention-map capture hooks to every cross-attention module.

    Walks the UNet's module tree and registers a forward hook (built by
    ``hook_fn``) on each module whose final path component starts with
    ``attn2`` — the cross-attention blocks in diffusers UNets.
    Returns the same UNet so the call can be chained.
    """
    for full_name, submodule in unet.named_modules():
        leaf = full_name.rsplit('.', 1)[-1]
        if not leaf.startswith('attn2'):
            continue
        submodule.register_forward_hook(hook_fn(full_name))

    return unet
|
||||
|
||||
def upscale(attn_map, target_size):
    """Average an attention map over heads and resize it to *target_size*.

    Parameters
    ----------
    attn_map : torch.Tensor
        Attention probabilities of shape ``(heads, h*w, tokens)``.
    target_size : tuple[int, int]
        Output ``(height, width)`` in pixels.

    Returns
    -------
    torch.Tensor
        Per-token maps of shape ``(tokens, height, width)``,
        softmax-normalized across tokens at each pixel.

    Raises
    ------
    AssertionError
        If no latent resolution matching ``attn_map`` can be derived
        from ``target_size``.
    """
    # Collapse the head dimension, then put tokens first: (tokens, h*w).
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1, 0)
    temp_size = None

    # Recover the latent (h, w) this map was produced at: the pixel area
    # shrinks by 8x (VAE) plus a power-of-two per attention level, so
    # probe downscale factors 1, 2, 4, 8, 16.
    for i in range(0, 5):
        scale = 2 ** i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    # Fixed the ungrammatical message ("temp_size cannot is None").
    assert temp_size is not None, "temp_size cannot be None"

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    # Bilinear upsample each token's map to the requested pixel size.
    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    # Normalize across tokens so each pixel's scores sum to 1.
    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map
|
||||
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Aggregate the hooked attention maps into one per-token map stack.

    For every map captured in the global ``attn_maps`` registry, selects
    the CFG chunk of interest (the positive/conditional half by default),
    upscales it to *image_size*, and averages the results across layers.

    ``batch_size`` is the number of CFG chunks, ``instance_or_negative``
    selects chunk 0 instead of 1, and ``detach`` moves each map to the
    CPU before processing. Returns a tensor of shape
    ``(tokens, height, width)``.
    """
    # Chunk 1 is the positive (conditional) half of the CFG batch.
    chunk_idx = 0 if instance_or_negative else 1
    upscaled = []

    for layer_name, layer_map in attn_maps.items():
        if detach:
            layer_map = layer_map.cpu()
        # e.g. (20, 32*32, 77) -> (10, 32*32, 77): negative vs. positive CFG
        selected = torch.chunk(layer_map, batch_size)[chunk_idx]
        # (10, 32*32, 77) -> (77, H, W)
        upscaled.append(upscale(selected, image_size))

    # Mean over layers gives the network-wide attention map.
    return torch.mean(torch.stack(upscaled, dim=0), dim=0)
|
||||
|
||||
def attnmaps2images(net_attn_maps):
    """Convert per-token attention maps into grayscale PIL images.

    Each map is min-max normalized to the 0-255 range independently.
    A constant (zero-range) map is rendered as all-black instead of
    dividing by zero, which previously produced NaNs.

    Parameters
    ----------
    net_attn_maps : iterable of torch.Tensor
        2-D attention maps, one per token.

    Returns
    -------
    list[PIL.Image.Image]
        One grayscale image per input map.
    """
    images = []

    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()

        lo = np.min(attn_map)
        span = np.max(attn_map) - lo
        if span > 0:
            normalized_attn_map = (attn_map - lo) / span * 255
        else:
            # Degenerate map: every pixel equal — avoid 0/0 -> NaN.
            normalized_attn_map = np.zeros_like(attn_map)

        images.append(Image.fromarray(normalized_attn_map.astype(np.uint8)))

    return images
|
||||
def is_torch2_available():
    """Report whether the installed torch provides fused attention.

    Feature-probes for ``F.scaled_dot_product_attention``, which was
    introduced in PyTorch 2.x.
    """
    return getattr(F, "scaled_dot_product_attention", None) is not None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,310 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "39f9cc7a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/root/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
|
||||
" _torch_pytree._register_pytree_node(\n",
|
||||
"/root/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
|
||||
" _torch_pytree._register_pytree_node(\n",
|
||||
"/root/anaconda3/lib/python3.9/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
|
||||
" torch.utils._pytree._register_pytree_node(\n",
|
||||
"/root/anaconda3/lib/python3.9/site-packages/scipy/__init__.py:155: UserWarning: A NumPy version >=1.18.5 and <1.25.0 is required for this version of SciPy (detected version 1.26.3\n",
|
||||
" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL\n",
|
||||
"from PIL import Image\n",
|
||||
"import copy\n",
|
||||
"\n",
|
||||
"from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlus, IPAdapterFaceID\n",
|
||||
"from insightface.app import FaceAnalysis\n",
|
||||
"from insightface.model_zoo.arcface_onnx import ArcFaceONNX\n",
|
||||
"from insightface.utils import face_align\n",
|
||||
"from numpy.linalg import norm as l2norm\n",
|
||||
"import cv2\n",
|
||||
"from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "0d290971",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/root/anaconda3/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:69: UserWarning: Specified provider 'CUDAExecutionProvider' is not in available provider names.Available providers: 'AzureExecutionProvider, CPUExecutionProvider'\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n",
|
||||
"find model: /root/.insightface/models/buffalo_l/1k3d68.onnx landmark_3d_68 ['None', 3, 192, 192] 0.0 1.0\n",
|
||||
"Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n",
|
||||
"find model: /root/.insightface/models/buffalo_l/2d106det.onnx landmark_2d_106 ['None', 3, 192, 192] 0.0 1.0\n",
|
||||
"Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n",
|
||||
"find model: /root/.insightface/models/buffalo_l/det_10g.onnx detection [1, 3, '?', '?'] 127.5 128.0\n",
|
||||
"Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n",
|
||||
"find model: /root/.insightface/models/buffalo_l/genderage.onnx genderage ['None', 3, 96, 96] 0.0 1.0\n",
|
||||
"Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}\n",
|
||||
"find model: /root/.insightface/models/buffalo_l/w600k_r50.onnx recognition ['None', 3, 112, 112] 127.5 127.5\n",
|
||||
"set det-size: (640, 640)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"app = FaceAnalysis(name=\"buffalo_l\", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\n",
|
||||
"app.prepare(ctx_id=0, det_size=(640, 640))\n",
|
||||
"\n",
|
||||
"v2 = False\n",
|
||||
"base_model_path = \"/dfs/comicai/zhiyuan.shi/models/SG161222/Realistic_Vision_V4.0_noVAE\"\n",
|
||||
"vae_model_path = \"/dfs/comicai/zhiyuan.shi/models/sd-vae-ft-mse\"\n",
|
||||
"image_encoder_path = \"/dfs/comicai/zhengbing.yao/models/IP-Adapter/models/image_encoder/\"\n",
|
||||
"plus_ip_ckpt = \"/dfs/comicai/zhiyuan.shi/models/IP-Adapter-FaceID/ip-adapter-faceid-plusv2_sd15.bin\"\n",
|
||||
"ip_ckpt = \"/dfs/comicai/zhiyuan.shi/models/IP-Adapter-FaceID/ip-adapter-faceid_sd15.bin\"\n",
|
||||
"device = \"cuda\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "f20eae92",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n",
|
||||
"```\n",
|
||||
"pip install accelerate\n",
|
||||
"```\n",
|
||||
".\n",
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: \n",
|
||||
"```\n",
|
||||
"pip install accelerate\n",
|
||||
"```\n",
|
||||
".\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "0e813051f12a4679aef02732b1baacf9",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Loading pipeline components...: 0%| | 0/5 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"noise_scheduler = DDIMScheduler(\n",
|
||||
" num_train_timesteps=1000,\n",
|
||||
" beta_start=0.00085,\n",
|
||||
" beta_end=0.012,\n",
|
||||
" beta_schedule=\"scaled_linear\",\n",
|
||||
" clip_sample=False,\n",
|
||||
" set_alpha_to_one=False,\n",
|
||||
" steps_offset=1,\n",
|
||||
")\n",
|
||||
"vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)\n",
|
||||
"pipe = StableDiffusionPipeline.from_pretrained(\n",
|
||||
" base_model_path,\n",
|
||||
" torch_dtype=torch.float16,\n",
|
||||
" scheduler=noise_scheduler,\n",
|
||||
" vae=vae,\n",
|
||||
" feature_extractor=None,\n",
|
||||
" safety_checker=None\n",
|
||||
")\n",
|
||||
"pipe.unet = register_cross_attention_hook(pipe.unet)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "40372f70",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/root/anaconda3/lib/python3.9/site-packages/insightface/utils/transform.py:68: FutureWarning: `rcond` parameter will change to the default of machine precision times ``max(M, N)`` where M and N are the input matrix dimensions.\n",
|
||||
"To use the future default and silence this warning we advise to pass `rcond=None`, to keep using the old, explicitly pass `rcond=-1`.\n",
|
||||
" P = np.linalg.lstsq(X_homo, Y)[0].T # Affine matrix. 3 x 4\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# generate image\n",
|
||||
"prompt = \"photo of a woman in red dress in a garden, white hair, happy\"\n",
|
||||
"negative_prompt = \"monochrome, lowres, bad anatomy, worst quality, low quality, blurry\"\n",
|
||||
"\n",
|
||||
"import wandb\n",
|
||||
"table = wandb.Table(columns=[\"prompt\", \"scale\", \"face\", \"gen\"])\n",
|
||||
"\n",
|
||||
"def rtn_face_get(self, img, face):\n",
|
||||
" aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0])\n",
|
||||
" #print(cv2.imwrite(\"aimg.png\", aimg))\n",
|
||||
" face.embedding = self.get_feat(aimg).flatten()\n",
|
||||
" face.crop_face = aimg\n",
|
||||
" return face.embedding\n",
|
||||
"\n",
|
||||
"ArcFaceONNX.get = rtn_face_get\n",
|
||||
"image = cv2.imread(\"assets/images/woman.png\")\n",
|
||||
"faces = app.get(image)\n",
|
||||
"faceid_embeds = faces[0].normed_embedding\n",
|
||||
"faceid_embeds = torch.from_numpy(faceid_embeds).unsqueeze(0)\n",
|
||||
"face_image = faces[0].crop_face"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "504c9305",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c307c379534c4a82a72addc7a2e4af33",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/30 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"plus_ip_model = IPAdapterFaceIDPlus(copy.deepcopy(pipe), image_encoder_path, plus_ip_ckpt, device)\n",
|
||||
"images = plus_ip_model.generate(\n",
|
||||
" prompt=prompt,\n",
|
||||
" negative_prompt=negative_prompt,\n",
|
||||
" face_image=face_image,\n",
|
||||
" faceid_embeds=faceid_embeds,\n",
|
||||
" shortcut=v2,\n",
|
||||
" s_scale=1,\n",
|
||||
" num_samples=1,\n",
|
||||
" width=512, height=768,\n",
|
||||
" num_inference_steps=30, seed=2023\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "e5aeccf0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([4, 768, 512])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"attn_maps = get_net_attn_map((768, 512))\n",
|
||||
"print(attn_maps.shape)\n",
|
||||
"attn_hot = attnmaps2images(attn_maps)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f2838572",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"#axes[0].imshow(attn_hot[0], cmap='gray')\n",
|
||||
"display_images = [cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)] + attn_hot + [images[0]]\n",
|
||||
"fig, axes = plt.subplots(1, len(display_images), figsize=(12, 4))\n",
|
||||
"for axe, image in zip(axes, display_images):\n",
|
||||
" axe.imshow(image, cmap='gray')\n",
|
||||
" axe.axis('off')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fead1786",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ip_model = IPAdapterFaceID(copy.deepcopy(pipe), ip_ckpt, device)\n",
|
||||
"images = ip_model.generate(\n",
|
||||
" prompt=prompt, negative_prompt=negative_prompt,\n",
|
||||
" faceid_embeds=faceid_embeds,\n",
|
||||
" num_samples=1,\n",
|
||||
" width=512, height=768,\n",
|
||||
" num_inference_steps=30, seed=2023\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "124b3672",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"attn_maps = get_net_attn_map((768, 512))\n",
|
||||
"print(attn_maps.shape)\n",
|
||||
"attn_hot = attnmaps2images(attn_maps)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f409acde",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"display_images = [cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB)] + attn_hot + [images[0]]\n",
|
||||
"fig, axes = plt.subplots(1, len(display_images), figsize=(12, 4))\n",
|
||||
"for axe, image in zip(axes, display_images):\n",
|
||||
" axe.imshow(image, cmap='gray')\n",
|
||||
" axe.axis('off')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Loading…
Reference in New Issue