support other attn optimizations

pull/157/head
pkuliyi2015 2023-04-15 13:04:45 +00:00
parent e46cf8eb4a
commit 62280878e5
4 changed files with 312 additions and 115 deletions

6
.gitignore vendored
View File

@ -5,3 +5,9 @@ __pycache__/
# settings
region_configs/
# test images
deflicker/input_frames/*
# test features
deflicker/*

View File

@ -26,7 +26,7 @@
# - Allows for super large resolutions (2k~8k) for both txt2img and img2img.
# - The merged output is completely seamless without any post-processing.
# - Training free. No need to train a new model, and you can control the
# text prompt for each tile.
# text prompt for specific regions.
#
# Drawbacks:
# - Depending on your parameter settings, the process can be very slow,
@ -34,17 +34,23 @@
# - The gradient calculation is not compatible with this hack. It
# will break any backward() or torch.autograd.grad() that passes UNet.
#
# How it works (insanely simple!)
# 1) The latent image x_t is split into tiles
# 2) The tiles are denoised by original sampler to get x_t-1
# 3) The tiles are added together, but divided by how many times each pixel
# is added.
# How it works:
# 1. The latent image is split into tiles.
# 2. In MultiDiffusion:
# 1. The UNet predicts the noise of each tile.
# 2. The tiles are denoised by the original sampler for one time step.
# 3. The tiles are added together but divided by how many times each pixel is added.
# 3. In Mixture of Diffusers:
# 1. The UNet predicts the noise of each tile
# 2. All noises are fused with a gaussian weight mask.
# 3. The denoiser denoises the whole image for one time step using fused noises.
# 4. Repeat 2-3 until all timesteps are completed.
#
# Enjoy!
#
# @author: LI YI @ Nanyang Technological University - Singapore
# @date: 2023-03-03
# @license: MIT License
# @license: CC BY-NC-SA 4.0
#
# Please give me a star if you like this project!
#
@ -293,7 +299,6 @@ class Script(scripts.Script):
''' upscale '''
if is_img2img: # img2img
upscaler_name = [x.name for x in shared.sd_upscalers].index(upscaler_index)
init_img = p.init_images[0]
init_img = images.flatten(init_img, opts.img2img_background_color)
upscaler = shared.sd_upscalers[upscaler_name]
@ -302,13 +307,12 @@ class Script(scripts.Script):
image = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
p.extra_generation_params["Tiled Diffusion upscaler"] = upscaler.name
p.extra_generation_params["Tiled Diffusion scale factor"] = scale_factor
else:
image = init_img
# For webui folder based batch processing, the length of init_images is not 1
# We need to replace all images with the upsampled one
for i in range(len(p.init_images)):
p.init_images[i] = image
else:
image = init_img
if keep_input_size:
p.width = image.width

View File

@ -1,7 +1,7 @@
'''
# ------------------------------------------------------------------------
#
# Ultimate VAE Tile Optimization
# Tiled VAE
#
# Introducing a revolutionary new optimization designed to make
# the VAE work with giant images on limited VRAM!
@ -18,41 +18,31 @@
# - The merged output is completely seamless without any post-processing.
#
# Drawbacks:
# - Giant RAM needed. To store the intermediate results for a 4096x4096
# images, you need a 32 GB RAM machine (it consumes ~20 GB); for 8192x8192
# you need 128 GB RAM machine (it consumes ~100 GB)
# - NaNs always appear for 8k images when you use a fp16 (half) VAE.
# You must use --no-half-vae to disable half VAE for that giant image.
# - Slow speed. With default tile size, it takes around 50/200 seconds
# to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode
# a 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
# - The gradient calculation is not compatible with this hack. It
# will break any backward() or torch.autograd.grad() that passes VAE.
# (But you can still use the VAE to generate training data.)
#
# How it works:
# 1) The image is split into tiles.
# - To ensure perfect results, each tile is padded with 32 pixels
# on each side.
# - Then the conv2d/silu/upsample/downsample can produce identical
# results to the original image without splitting.
# 2) The original forward is decomposed into a task queue and a task worker.
# - The task queue is a list of functions that will be executed in order.
# - The task worker is a loop that executes the tasks in the queue.
# 3) The task queue is executed for each tile.
# - Current tile is sent to GPU.
# - local operations are directly executed.
# - Group norm calculation is temporarily suspended until the mean
# and var of all tiles are calculated.
# - The residual is pre-calculated and stored, and added back later.
# - When moving to the next tile, the current tile is sent to the CPU.
# 4) After all tiles are processed, tiles are merged on cpu and return.
# 1. The image is split into tiles, which are then padded with 11/32 pixels in the decoder/encoder.
# 2. When Fast Mode is disabled:
# 1. The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile.
# 2. When GroupNorm is needed, it suspends, stores the current GroupNorm mean and var, sends everything to RAM, and moves on to the next tile.
# 3. After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues.
# 4. A zigzag execution order is used to reduce unnecessary data transfer.
# 3. When Fast Mode is enabled:
# 1. The original input is downsampled and passed to a separate task queue.
# 2. Its group norm parameters are recorded and used by all tiles' task queues.
# 3. Each tile is separately processed without any RAM-VRAM data transfer.
# 4. After all tiles are processed, tiles are written to a result buffer and returned.
# Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode.
#
# Enjoy!
#
# @author: LI YI @ Nanyang Technological University - Singapore
# @date: 2023-03-02
# @license: MIT License
# @Author: LI YI @ Nanyang Technological University - Singapore
# @Date: 2023-03-02
# @License: CC BY-NC-SA 4.0
#
# Please give me a star if you like this project!
#
@ -67,7 +57,6 @@ from tqdm import tqdm
import torch
import torch.version
import torch.nn.functional as F
from einops import rearrange
import gradio as gr
import modules.scripts as scripts
@ -75,11 +64,7 @@ import modules.devices as devices
from modules.shared import state
from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock
try:
import xformers
import xformers.ops
except ImportError:
pass
from tile_utils.attn import get_attn_func
def get_recommend_encoder_tile_size():
@ -123,78 +108,13 @@ def inplace_nonlinearity(x):
# Test: fix for Nans
return F.silu(x, inplace=True)
# extracted from ldm.modules.diffusionmodules.model
def attn_forward(self, h_):
    # Vanilla single-head self-attention from ldm's AttnBlock.forward, with the
    # pre-norm and residual add stripped out (the Tiled VAE task queue applies
    # those separately around this call).
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q.shape
    q = q.reshape(b, c, h*w)
    q = q.permute(0, 2, 1)    # b,hw,c
    k = k.reshape(b, c, h*w)  # b,c,hw
    w_ = torch.bmm(q, k)      # b,hw,hw   w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
    w_ = w_ * (int(c)**(-0.5))  # scale scores by 1/sqrt(channels)
    w_ = torch.nn.functional.softmax(w_, dim=2)

    # attend to values
    v = v.reshape(b, c, h*w)
    w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
    # b,c,hw (hw of q): h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
    h_ = torch.bmm(v, w_)
    h_ = h_.reshape(b, c, h, w)

    h_ = self.proj_out(h_)
    return h_
def xformer_attn_forward(self, h_):
    # xformers memory-efficient attention for ldm's MemoryEfficientAttnBlock,
    # with pre-norm and residual add stripped out (handled by the task queue).
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # compute attention
    B, C, H, W = q.shape
    q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

    # Layout dance kept verbatim from upstream: with a single head (the "1"
    # dimension) these reshapes are effectively just .contiguous().
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(B, t.shape[1], 1, C)
        .permute(0, 2, 1, 3)
        .reshape(B * 1, t.shape[1], C)
        .contiguous(),
        (q, k, v),
    )
    out = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=None, op=self.attention_op)

    # Undo the single-head layout and restore the b,c,h,w image shape.
    out = (
        out.unsqueeze(0)
        .reshape(B, 1, out.shape[1], C)
        .permute(0, 2, 1, 3)
        .reshape(B, out.shape[1], C)
    )
    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
    out = self.proj_out(out)
    return out
def attn2task(task_queue, net):
if isinstance(net, AttnBlock):
attn_forward = get_attn_func(isinstance(net, MemoryEfficientAttnBlock))
task_queue.append(('store_res', lambda x: x))
task_queue.append(('pre_norm', net.norm))
task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
task_queue.append(['add_res', None])
elif isinstance(net, MemoryEfficientAttnBlock):
task_queue.append(('store_res', lambda x: x))
task_queue.append(('pre_norm', net.norm))
task_queue.append(
('attn', lambda x, net=net: xformer_attn_forward(net, x)))
task_queue.append(['add_res', None])
def resblock2task(queue, block):

267
tile_utils/attn.py Normal file
View File

@ -0,0 +1,267 @@
'''
This file is modified from sd_hijack_optimizations.py to remove the residual and norm parts,
so that the Tiled VAE can support other types of attention.
'''
import math
import torch
from modules import shared, devices, errors
from modules.shared import cmd_opts
from einops import rearrange
from modules.sub_quadratic_attention import efficient_dot_product_attention
from modules.sd_hijack_optimizations import get_available_vram
try:
import xformers
import xformers.ops
except ImportError:
pass
def get_attn_func(memory_efficient):
    """Pick the VAE attention forward that matches webui's cross-attention setting.

    Follows the same option precedence as the webui hijack:
    xformers > sdp-no-mem > sdp > sub-quadratic > split-v1 / InvokeAI
    (plain fallback) > Doggettx split attention.

    :param memory_efficient: selects the default forward (xformers block vs.
        vanilla) when no optimization flag applies.
    :return: a function with signature ``(self, h_) -> tensor``.
    """
    # torch < 2.0 has no scaled_dot_product_attention
    sdp = getattr(torch.nn.functional, "scaled_dot_product_attention", None)
    can_use_sdp = sdp is not None and callable(sdp)

    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
        return xformers_attnblock_forward
    if cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
        return sdp_no_mem_attnblock_forward
    if cmd_opts.opt_sdp_attention and can_use_sdp:
        return sdp_attnblock_forward
    if cmd_opts.opt_sub_quad_attention:
        return sub_quad_attnblock_forward

    # Fallback chosen by the caller: xformers block forward or the vanilla one.
    default = xformer_attn_forward if memory_efficient else attn_forward
    if cmd_opts.opt_split_attention_v1:
        return default
    if not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
        return default
    if not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        return cross_attention_attnblock_forward
    return default
# The following functions are all copied from modules.sd_hijack_optimizations
# However, the residual & normalization are removed and computed later.
def attn_forward(self, h_):
    """Plain single-head attention over spatial positions.

    The residual add and pre-norm of the original AttnBlock are stripped;
    the Tiled VAE task queue applies them separately around this call.
    """
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    b, c, h, w = q.shape
    n = h * w
    # scores[b, i, j] = sum_c q[b, i, c] * k[b, c, j], scaled by 1/sqrt(c)
    scores = torch.bmm(q.reshape(b, c, n).permute(0, 2, 1), k.reshape(b, c, n))
    scores = torch.nn.functional.softmax(scores * (int(c) ** (-0.5)), dim=2)

    # out[b, c, j] = sum_i v[b, c, i] * scores[b, j, i]
    out = torch.bmm(v.reshape(b, c, n), scores.permute(0, 2, 1))
    return self.proj_out(out.reshape(b, c, h, w))
def xformers_attnblock_forward(self, h_):
    """xformers memory-efficient attention (residual/norm stripped).

    Falls back to the Doggettx split-attention path when xformers rejects
    the input shape/dtype with NotImplementedError.
    """
    try:
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        # (b, c, h, w) -> (b, h*w, c)
        q, k, v = (t.reshape(b, c, h * w).permute(0, 2, 1) for t in (q, k, v))
        dtype = q.dtype
        if shared.opts.upcast_attn:
            # compute the score matmul in fp32 to avoid fp16 overflow
            q, k = q.float(), k.float()
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
        out = out.to(dtype)
        # (b, h*w, c) -> (b, c, h, w)
        out = out.permute(0, 2, 1).reshape(b, c, h, w)
        return self.proj_out(out)
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, h_)
def cross_attention_attnblock_forward(self, h_):
    # Doggettx split attention: the (hw x hw) score matrix is computed in
    # slices along the query dimension so it never has to fit in VRAM whole.
    # Pre-norm and residual add are stripped out (handled by the task queue).
    # The aggressive `del`s keep peak VRAM low; do not reorder them.
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape
    q2 = q1.reshape(b, c, h*w)
    del q1
    q = q2.permute(0, 2, 1)   # b,hw,c
    del q2
    k = k1.reshape(b, c, h*w) # b,c,hw
    del k1

    # output accumulator, filled slice by slice
    h_ = torch.zeros_like(k, device=q.device)

    # Decide the number of slices from free VRAM.
    mem_free_total = get_available_vram()
    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    # 2.5x headroom for softmax temporaries — empirical factor inherited from
    # upstream; NOTE(review): confirm it still holds for current torch builds.
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    # If hw is not divisible by the step count, fall back to a single slice.
    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw  w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))      # scale by 1/sqrt(channels)
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)  # b,c,hw (hw of q): h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    return h3
def sdp_no_mem_attnblock_forward(self, x):
    # Same as sdp_attnblock_forward, but with the memory-efficient SDP backend
    # disabled so only the flash / math kernels may be used (webui's
    # "--opt-sdp-no-mem-attention" option).
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return sdp_attnblock_forward(self, x)
def sdp_attnblock_forward(self, h_):
    """torch 2.x scaled_dot_product_attention forward (residual/norm stripped)."""
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    b, c, h, w = q.shape
    # (b, c, h, w) -> (b, h*w, c)
    q, k, v = (t.reshape(b, c, h * w).permute(0, 2, 1) for t in (q, k, v))
    dtype = q.dtype
    if shared.opts.upcast_attn:
        # compute scores in fp32 to avoid fp16 overflow
        q, k = q.float(), k.float()
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.to(dtype)
    # (b, h*w, c) -> (b, c, h, w)
    out = out.permute(0, 2, 1).reshape(b, c, h, w)
    return self.proj_out(out)
def sub_quad_attnblock_forward(self, h_):
    """Sub-quadratic (chunked, memory-bounded) attention forward.

    Residual add and pre-norm are stripped out; chunk sizes come from the
    webui --sub-quad-* command-line options.
    """
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    b, c, h, w = q.shape
    # (b, c, h, w) -> (b, h*w, c), contiguous for the chunked kernel
    q, k, v = (t.reshape(b, c, h * w).permute(0, 2, 1).contiguous() for t in (q, k, v))

    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    # (b, h*w, c) -> (b, c, h, w)
    out = out.permute(0, 2, 1).reshape(b, c, h, w)
    return self.proj_out(out)
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    """Memory-bounded sub-quadratic attention (from modules.sd_hijack_optimizations).

    :param q, k, v: (batch*heads, tokens, channels) tensors.
    :param q_chunk_size: query chunk size when chunking is needed.
    :param kv_chunk_size: key/value chunk size (None = let the kernel decide).
    :param chunk_threshold: percent of free VRAM the q@k matmul may occupy
        before chunking kicks in (None = auto, 0 = never chunk).
    :return: attention output with the same shape as ``q``.
    """
    bytes_per_token = torch.finfo(q.dtype).bits // 8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    # Derive the VRAM budget for the attention matmul.
    if chunk_threshold is None:
        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    query_chunk_size = q_chunk_size
    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # The big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path.
        # BUGFIX: the original assigned query_chunk_size here but still passed
        # q_chunk_size to the kernel, so the fast path never applied to the
        # query dimension.
        query_chunk_size = q_tokens
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
        return efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=query_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )
def get_xformers_flash_attention_op(q, k, v):
    """Return xformers' FlashAttention op when enabled and supported.

    Returns None (letting xformers auto-select an op) unless the user passed
    --xformers-flash-attention and the inputs are supported by the flash
    forward kernel.
    """
    if not shared.cmd_opts.xformers_flash_attention:
        return None
    try:
        op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        fw, _bw = op
        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
            return op
    except Exception as e:
        # Best effort: log once and fall back to automatic op selection.
        errors.display_once(e, "enabling flash attention")
    return None
def xformer_attn_forward(self, h_):
    # xformers memory-efficient attention for ldm's MemoryEfficientAttnBlock,
    # with pre-norm and residual add stripped out (handled by the task queue).
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # compute attention
    B, C, H, W = q.shape
    q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

    # Layout dance kept verbatim from upstream: with a single head (the "1"
    # dimension) these reshapes are effectively just .contiguous().
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(B, t.shape[1], 1, C)
        .permute(0, 2, 1, 3)
        .reshape(B * 1, t.shape[1], C)
        .contiguous(),
        (q, k, v),
    )
    out = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=None, op=self.attention_op)

    # Undo the single-head layout and restore the b,c,h,w image shape.
    out = (
        out.unsqueeze(0)
        .reshape(B, 1, out.shape[1], C)
        .permute(0, 2, 1, 3)
        .reshape(B, out.shape[1], C)
    )
    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
    out = self.proj_out(out)
    return out