diff --git a/.gitignore b/.gitignore index 48a6db7..5a98b2c 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,9 @@ __pycache__/ # settings region_configs/ + +# test images +deflicker/input_frames/* + +# test features +deflicker/* diff --git a/scripts/tilediffusion.py b/scripts/tilediffusion.py index 6b9d263..624e063 100644 --- a/scripts/tilediffusion.py +++ b/scripts/tilediffusion.py @@ -26,7 +26,7 @@ # - Allows for super large resolutions (2k~8k) for both txt2img and img2img. # - The merged output is completely seamless without any post-processing. # - Training free. No need to train a new model, and you can control the -# text prompt for each tile. +# text prompt for specific regions. # # Drawbacks: # - Depending on your parameter settings, the process can be very slow, @@ -34,17 +34,23 @@ # - The gradient calculation is not compatible with this hack. It # will break any backward() or torch.autograd.grad() that passes UNet. # -# How it works (insanely simple!) -# 1) The latent image x_t is split into tiles -# 2) The tiles are denoised by original sampler to get x_t-1 -# 3) The tiles are added together, but divided by how many times each pixel -# is added. +# How it works: +# 1. The latent image is split into tiles. +# 2. In MultiDiffusion: +# 1. The UNet predicts the noise of each tile. +# 2. The tiles are denoised by the original sampler for one time step. +# 3. The tiles are added together but divided by how many times each pixel is added. +# 3. In Mixture of Diffusers: +# 1. The UNet predicts the noise of each tile +# 2. All noises are fused with a gaussian weight mask. +# 3. The denoiser denoises the whole image for one time step using fused noises. +# 4. Repeat 2-3 until all timesteps are completed. # # Enjoy! # # @author: LI YI @ Nanyang Technological University - Singapore # @date: 2023-03-03 -# @license: MIT License +# @license: CC BY-NC-SA 4.0 # # Please give me a star if you like this project! 
# @@ -293,7 +299,6 @@ class Script(scripts.Script): ''' upscale ''' if is_img2img: # img2img upscaler_name = [x.name for x in shared.sd_upscalers].index(upscaler_index) - init_img = p.init_images[0] init_img = images.flatten(init_img, opts.img2img_background_color) upscaler = shared.sd_upscalers[upscaler_name] @@ -302,14 +307,13 @@ class Script(scripts.Script): image = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path) p.extra_generation_params["Tiled Diffusion upscaler"] = upscaler.name p.extra_generation_params["Tiled Diffusion scale factor"] = scale_factor + # For webui folder based batch processing, the length of init_images is not 1 + # We need to replace all images with the upsampled one + for i in range(len(p.init_images)): + p.init_images[i] = image else: image = init_img - # For webui folder based batch processing, the length of init_images is not 1 - # We need to replace all images with the upsampled one - for i in range(len(p.init_images)): - p.init_images[i] = image - if keep_input_size: p.width = image.width p.height = image.height diff --git a/scripts/vae_optimize.py b/scripts/vae_optimize.py index 27f549f..98890c5 100644 --- a/scripts/vae_optimize.py +++ b/scripts/vae_optimize.py @@ -1,7 +1,7 @@ ''' # ------------------------------------------------------------------------ # -# Ultimate VAE Tile Optimization +# Tiled VAE # # Introducing a revolutionary new optimization designed to make # the VAE work with giant images on limited VRAM! @@ -18,41 +18,31 @@ # - The merged output is completely seamless without any post-processing. # # Drawbacks: -# - Giant RAM needed. To store the intermediate results for a 4096x4096 -# images, you need 32 GB RAM it consumes ~20GB); for 8192x8192 -# you need 128 GB RAM machine (it consumes ~100 GB) # - NaNs always appear in for 8k images when you use fp16 (half) VAE # You must use --no-half-vae to disable half VAE for that giant image. -# - Slow speed. 
With default tile size, it takes around 50/200 seconds -# to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode -# a 8192x8192 image. (The speed is limited by both the GPU and the CPU.) # - The gradient calculation is not compatible with this hack. It # will break any backward() or torch.autograd.grad() that passes VAE. # (But you can still use the VAE to generate training data.) # # How it works: -# 1) The image is split into tiles. -# - To ensure perfect results, each tile is padded with 32 pixels -# on each side. -# - Then the conv2d/silu/upsample/downsample can produce identical -# results to the original image without splitting. -# 2) The original forward is decomposed into a task queue and a task worker. -# - The task queue is a list of functions that will be executed in order. -# - The task worker is a loop that executes the tasks in the queue. -# 3) The task queue is executed for each tile. -# - Current tile is sent to GPU. -# - local operations are directly executed. -# - Group norm calculation is temporarily suspended until the mean -# and var of all tiles are calculated. -# - The residual is pre-calculated and stored and addded back later. -# - When need to go to the next tile, the current tile is send to cpu. -# 4) After all tiles are processed, tiles are merged on cpu and return. +# 1. The image is split into tiles, which are then padded with 11/32 pixels in the decoder/encoder. +# 2. When Fast Mode is disabled: +# 1. The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile. +# 2. When GroupNorm is needed, it suspends, stores the current GroupNorm mean and var, sends everything to RAM, and turns to the next tile. +# 3. After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues. +# 4. A zigzag execution order is used to reduce unnecessary data transfer. +# 3. When Fast Mode is enabled: +# 1. 
The original input is downsampled and passed to a separate task queue. +# 2. Its group norm parameters are recorded and used by all tiles' task queues. +# 3. Each tile is separately processed without any RAM-VRAM data transfer. +# 4. After all tiles are processed, tiles are written to a result buffer and returned. +# Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode. # # Enjoy! # -# @author: LI YI @ Nanyang Technological University - Singapore -# @date: 2023-03-02 -# @license: MIT License +# @Author: LI YI @ Nanyang Technological University - Singapore +# @Date: 2023-03-02 +# @License: CC BY-NC-SA 4.0 # # Please give me a star if you like this project! # @@ -67,7 +57,6 @@ from tqdm import tqdm import torch import torch.version import torch.nn.functional as F -from einops import rearrange import gradio as gr import modules.scripts as scripts @@ -75,11 +64,7 @@ import modules.devices as devices from modules.shared import state from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock -try: - import xformers - import xformers.ops -except ImportError: - pass +from tile_utils.attn import get_attn_func def get_recommend_encoder_tile_size(): @@ -123,78 +108,13 @@ def inplace_nonlinearity(x): # Test: fix for Nans return F.silu(x, inplace=True) -# extracted from ldm.modules.diffusionmodules.model - - -def attn_forward(self, h_): - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h*w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h*w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h*w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = torch.bmm(v, w_) - h_ = h_.reshape(b, c, h, w) - - h_ = 
self.proj_out(h_) - - return h_ - - -def xformer_attn_forward(self, h_): - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) - - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), - (q, k, v), - ) - out = xformers.ops.memory_efficient_attention( - q, k, v, attn_bias=None, op=self.attention_op) - - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) - out = self.proj_out(out) - return out - def attn2task(task_queue, net): - if isinstance(net, AttnBlock): - task_queue.append(('store_res', lambda x: x)) - task_queue.append(('pre_norm', net.norm)) - task_queue.append(('attn', lambda x, net=net: attn_forward(net, x))) - task_queue.append(['add_res', None]) - elif isinstance(net, MemoryEfficientAttnBlock): - task_queue.append(('store_res', lambda x: x)) - task_queue.append(('pre_norm', net.norm)) - task_queue.append( - ('attn', lambda x, net=net: xformer_attn_forward(net, x))) - task_queue.append(['add_res', None]) + attn_forward = get_attn_func(isinstance(net, MemoryEfficientAttnBlock)) + task_queue.append(('store_res', lambda x: x)) + task_queue.append(('pre_norm', net.norm)) + task_queue.append(('attn', lambda x, net=net: attn_forward(net, x))) + task_queue.append(['add_res', None]) def resblock2task(queue, block): diff --git a/tile_utils/attn.py b/tile_utils/attn.py new file mode 100644 index 0000000..2196e86 --- /dev/null +++ b/tile_utils/attn.py @@ -0,0 +1,267 @@ +''' + This file is modified from the sd_hijack_optimizations.py to remove the residual and norm part, + So that the Tiled VAE can support other types of attention. 
+''' +import math +import torch + +from modules import shared, devices, errors +from modules.shared import cmd_opts +from einops import rearrange +from modules.sub_quadratic_attention import efficient_dot_product_attention +from modules.sd_hijack_optimizations import get_available_vram + + +try: + import xformers + import xformers.ops +except ImportError: + pass + + +def get_attn_func(memory_efficient): + attn_forward_method = xformer_attn_forward if memory_efficient else attn_forward + can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp + + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): + attn_forward_method = xformers_attnblock_forward + elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp: + attn_forward_method = sdp_no_mem_attnblock_forward + elif cmd_opts.opt_sdp_attention and can_use_sdp: + attn_forward_method = sdp_attnblock_forward + elif cmd_opts.opt_sub_quad_attention: + attn_forward_method = sub_quad_attnblock_forward + elif cmd_opts.opt_split_attention_v1: + pass + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()): + pass + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): + attn_forward_method = cross_attention_attnblock_forward + + return attn_forward_method + +# The following functions are all copied from modules.sd_hijack_optimizations +# However, the residual & normalization are removed and computed later. 
+ + +def attn_forward(self, h_): + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h*w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h*w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return h_ + + +def xformers_attnblock_forward(self, h_): + try: + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return out + except NotImplementedError: + return cross_attention_attnblock_forward(self, h_) + + +def cross_attention_attnblock_forward(self, h_): + q1 = self.q(h_) + k1 = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + mem_free_total = get_available_vram() + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], 
slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 + + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 + + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + return h3 + + +def sdp_no_mem_attnblock_forward(self, x): + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): + return sdp_attnblock_forward(self, x) + + +def sdp_attnblock_forward(self, h_): + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) + out = out.to(dtype) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return out + +def sub_quad_attnblock_forward(self, h_): + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return out + + +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, 
use_checkpoint=True): + bytes_per_token = torch.finfo(q.dtype).bits//8 + batch_x_heads, q_tokens, _ = q.shape + _, k_tokens, _ = k.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + + if chunk_threshold is None: + chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + elif chunk_threshold == 0: + chunk_threshold_bytes = None + else: + chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) + + if kv_chunk_size_min is None and chunk_threshold_bytes is not None: + kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) + elif kv_chunk_size_min == 0: + kv_chunk_size_min = None + + if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. send it down the unchunked fast-path + query_chunk_size = q_tokens + kv_chunk_size = k_tokens + + with devices.without_autocast(disable=q.dtype == v.dtype): + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) + + +def get_xformers_flash_attention_op(q, k, v): + if not shared.cmd_opts.xformers_flash_attention: + return None + + try: + flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = flash_attention_op + if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): + return flash_attention_op + except Exception as e: + errors.display_once(e, "enabling flash attention") + + return None + + +def xformer_attn_forward(self, h_): + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, 
t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return out \ No newline at end of file