support other attn optimizations
parent
e46cf8eb4a
commit
62280878e5
|
|
@ -5,3 +5,9 @@ __pycache__/
|
|||
|
||||
# settings
|
||||
region_configs/
|
||||
|
||||
# test images
|
||||
deflicker/input_frames/*
|
||||
|
||||
# test features
|
||||
deflicker/*
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@
|
|||
# - Allows for super large resolutions (2k~8k) for both txt2img and img2img.
|
||||
# - The merged output is completely seamless without any post-processing.
|
||||
# - Training free. No need to train a new model, and you can control the
|
||||
# text prompt for each tile.
|
||||
# text prompt for specific regions.
|
||||
#
|
||||
# Drawbacks:
|
||||
# - Depending on your parameter settings, the process can be very slow,
|
||||
|
|
@ -34,17 +34,23 @@
|
|||
# - The gradient calculation is not compatible with this hack. It
|
||||
# will break any backward() or torch.autograd.grad() that passes UNet.
|
||||
#
|
||||
# How it works (insanely simple!)
|
||||
# 1) The latent image x_t is split into tiles
|
||||
# 2) The tiles are denoised by original sampler to get x_t-1
|
||||
# 3) The tiles are added together, but divided by how many times each pixel
|
||||
# is added.
|
||||
# How it works:
|
||||
# 1. The latent image is split into tiles.
|
||||
# 2. In MultiDiffusion:
|
||||
# 1. The UNet predicts the noise of each tile.
|
||||
# 2. The tiles are denoised by the original sampler for one time step.
|
||||
# 3. The tiles are added together but divided by how many times each pixel is added.
|
||||
# 3. In Mixture of Diffusers:
|
||||
# 1. The UNet predicts the noise of each tile
|
||||
# 2. All noises are fused with a gaussian weight mask.
|
||||
# 3. The denoiser denoises the whole image for one time step using fused noises.
|
||||
# 4. Repeat 2-3 until all timesteps are completed.
|
||||
#
|
||||
# Enjoy!
|
||||
#
|
||||
# @author: LI YI @ Nanyang Technological University - Singapore
|
||||
# @date: 2023-03-03
|
||||
# @license: MIT License
|
||||
# @license: CC BY-NC-SA 4.0
|
||||
#
|
||||
# Please give me a star if you like this project!
|
||||
#
|
||||
|
|
@ -293,7 +299,6 @@ class Script(scripts.Script):
|
|||
''' upscale '''
|
||||
if is_img2img: # img2img
|
||||
upscaler_name = [x.name for x in shared.sd_upscalers].index(upscaler_index)
|
||||
|
||||
init_img = p.init_images[0]
|
||||
init_img = images.flatten(init_img, opts.img2img_background_color)
|
||||
upscaler = shared.sd_upscalers[upscaler_name]
|
||||
|
|
@ -302,13 +307,12 @@ class Script(scripts.Script):
|
|||
image = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
|
||||
p.extra_generation_params["Tiled Diffusion upscaler"] = upscaler.name
|
||||
p.extra_generation_params["Tiled Diffusion scale factor"] = scale_factor
|
||||
else:
|
||||
image = init_img
|
||||
|
||||
# For webui folder based batch processing, the length of init_images is not 1
|
||||
# We need to replace all images with the upsampled one
|
||||
for i in range(len(p.init_images)):
|
||||
p.init_images[i] = image
|
||||
else:
|
||||
image = init_img
|
||||
|
||||
if keep_input_size:
|
||||
p.width = image.width
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
'''
|
||||
# ------------------------------------------------------------------------
|
||||
#
|
||||
# Ultimate VAE Tile Optimization
|
||||
# Tiled VAE
|
||||
#
|
||||
# Introducing a revolutionary new optimization designed to make
|
||||
# the VAE work with giant images on limited VRAM!
|
||||
|
|
@ -18,41 +18,31 @@
|
|||
# - The merged output is completely seamless without any post-processing.
|
||||
#
|
||||
# Drawbacks:
|
||||
# - Giant RAM needed. To store the intermediate results for a 4096x4096
|
||||
# images, you need 32 GB RAM (it consumes ~20GB); for 8192x8192
|
||||
# you need 128 GB RAM machine (it consumes ~100 GB)
|
||||
# - NaNs always appear in for 8k images when you use fp16 (half) VAE
|
||||
# You must use --no-half-vae to disable half VAE for that giant image.
|
||||
# - Slow speed. With default tile size, it takes around 50/200 seconds
|
||||
# to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode
|
||||
# a 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
|
||||
# - The gradient calculation is not compatible with this hack. It
|
||||
# will break any backward() or torch.autograd.grad() that passes VAE.
|
||||
# (But you can still use the VAE to generate training data.)
|
||||
#
|
||||
# How it works:
|
||||
# 1) The image is split into tiles.
|
||||
# - To ensure perfect results, each tile is padded with 32 pixels
|
||||
# on each side.
|
||||
# - Then the conv2d/silu/upsample/downsample can produce identical
|
||||
# results to the original image without splitting.
|
||||
# 2) The original forward is decomposed into a task queue and a task worker.
|
||||
# - The task queue is a list of functions that will be executed in order.
|
||||
# - The task worker is a loop that executes the tasks in the queue.
|
||||
# 3) The task queue is executed for each tile.
|
||||
# - Current tile is sent to GPU.
|
||||
# - local operations are directly executed.
|
||||
# - Group norm calculation is temporarily suspended until the mean
|
||||
# and var of all tiles are calculated.
|
||||
# - The residual is pre-calculated and stored and added back later.
|
||||
# - When moving on to the next tile, the current tile is sent to the CPU.
|
||||
# 4) After all tiles are processed, tiles are merged on cpu and return.
|
||||
# 1. The image is split into tiles, which are then padded with 11/32 pixels in the decoder/encoder.
|
||||
# 2. When Fast Mode is disabled:
|
||||
# 1. The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile.
|
||||
# 2. When GroupNorm is needed, it suspends, stores current GroupNorm mean and var, send everything to RAM, and turns to the next tile.
|
||||
# 3. After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues.
|
||||
# 4. A zigzag execution order is used to reduce unnecessary data transfer.
|
||||
# 3. When Fast Mode is enabled:
|
||||
# 1. The original input is downsampled and passed to a separate task queue.
|
||||
# 2. Its group norm parameters are recorded and used by all tiles' task queues.
|
||||
# 3. Each tile is separately processed without any RAM-VRAM data transfer.
|
||||
# 4. After all tiles are processed, tiles are written to a result buffer and returned.
|
||||
# Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode.
|
||||
#
|
||||
# Enjoy!
|
||||
#
|
||||
# @author: LI YI @ Nanyang Technological University - Singapore
|
||||
# @date: 2023-03-02
|
||||
# @license: MIT License
|
||||
# @Author: LI YI @ Nanyang Technological University - Singapore
|
||||
# @Date: 2023-03-02
|
||||
# @License: CC BY-NC-SA 4.0
|
||||
#
|
||||
# Please give me a star if you like this project!
|
||||
#
|
||||
|
|
@ -67,7 +57,6 @@ from tqdm import tqdm
|
|||
import torch
|
||||
import torch.version
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import gradio as gr
|
||||
|
||||
import modules.scripts as scripts
|
||||
|
|
@ -75,11 +64,7 @@ import modules.devices as devices
|
|||
from modules.shared import state
|
||||
from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
pass
|
||||
from tile_utils.attn import get_attn_func
|
||||
|
||||
|
||||
def get_recommend_encoder_tile_size():
|
||||
|
|
@ -123,78 +108,13 @@ def inplace_nonlinearity(x):
|
|||
# Test: fix for Nans
|
||||
return F.silu(x, inplace=True)
|
||||
|
||||
# extracted from ldm.modules.diffusionmodules.model
|
||||
|
||||
|
||||
def attn_forward(self, h_):
    """Plain O(n^2) self-attention for a VAE attn block.

    The residual add and the pre-norm are intentionally absent: they are
    queued separately by the tiled task-queue executor.
    """
    q, k, v = self.q(h_), self.k(h_), self.v(h_)

    b, c, h, w = q.shape
    hw = h * w
    q = q.reshape(b, c, hw).permute(0, 2, 1)   # b, hw, c
    k = k.reshape(b, c, hw)                    # b, c, hw
    # scores[b, i, j] = sum_c q[b, i, c] * k[b, c, j], scaled by 1/sqrt(c)
    attn = torch.nn.functional.softmax(torch.bmm(q, k) * (int(c) ** (-0.5)), dim=2)

    # out[b, c, j] = sum_i v[b, c, i] * attn[b, j, i]
    v = v.reshape(b, c, hw)
    out = torch.bmm(v, attn.permute(0, 2, 1)).reshape(b, c, h, w)
    return self.proj_out(out)
|
||||
|
||||
|
||||
def xformer_attn_forward(self, h_):
    """xformers memory-efficient attention for a VAE attn block,
    using the module's own `attention_op` (residual/norm handled by caller)."""
    q, k, v = self.q(h_), self.k(h_), self.v(h_)

    # flatten spatial dims and add a single-head axis for xformers
    B, C, H, W = q.shape
    q, k, v = (
        rearrange(t, 'b c h w -> b (h w) c')
        .unsqueeze(3)
        .reshape(B, -1, 1, C)
        .permute(0, 2, 1, 3)
        .reshape(B, -1, C)
        .contiguous()
        for t in (q, k, v)
    )
    out = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=None, op=self.attention_op)

    # undo the single-head reshape and restore b c h w layout
    out = out.unsqueeze(0).reshape(B, 1, -1, C).permute(0, 2, 1, 3).reshape(B, -1, C)
    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
    return self.proj_out(out)
|
||||
|
||||
|
||||
def attn2task(task_queue, net):
    """Append the ops of a VAE attention block to the tile task queue.

    The residual ('store_res'/'add_res') and the pre-norm are queued as
    separate steps so the tiled executor can suspend at GroupNorm; the
    attention computation itself is delegated to whichever optimized
    forward `get_attn_func()` selects from the webui flags.

    Fix: the two branches previously duplicated the same queue-building
    code, and the MemoryEfficientAttnBlock branch still called the
    module-level `xformer_attn_forward`, which this file no longer
    defines (it moved to tile_utils.attn) -> NameError at runtime.
    Both block types now go through `get_attn_func`.
    """
    if isinstance(net, (AttnBlock, MemoryEfficientAttnBlock)):
        attn_forward = get_attn_func(isinstance(net, MemoryEfficientAttnBlock))
        task_queue.append(('store_res', lambda x: x))
        task_queue.append(('pre_norm', net.norm))
        # bind net as a default arg so each closure keeps its own block
        task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
        task_queue.append(['add_res', None])
|
||||
|
||||
|
||||
def resblock2task(queue, block):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,267 @@
|
|||
'''
|
||||
This file is modified from the sd_hijack_optimizations.py to remove the residual and norm part,
|
||||
So that the Tiled VAE can support other types of attention.
|
||||
'''
|
||||
import math
|
||||
import torch
|
||||
|
||||
from modules import shared, devices, errors
|
||||
from modules.shared import cmd_opts
|
||||
from einops import rearrange
|
||||
from modules.sub_quadratic_attention import efficient_dot_product_attention
|
||||
from modules.sd_hijack_optimizations import get_available_vram
|
||||
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def get_attn_func(memory_efficient):
    """Pick the AttnBlock forward implementation matching the webui
    command-line attention-optimization flags.

    `memory_efficient` chooses the fallback used when no optimization
    flag applies: xformers-style forward vs. vanilla attention.
    """
    # Default when none of the flag-driven branches below match.
    fallback = xformer_attn_forward if memory_efficient else attn_forward
    # scaled_dot_product_attention only exists on torch 2.x.
    can_use_sdp = callable(getattr(torch.nn.functional, "scaled_dot_product_attention", None))

    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
        return xformers_attnblock_forward
    if cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
        return sdp_no_mem_attnblock_forward
    if cmd_opts.opt_sdp_attention and can_use_sdp:
        return sdp_attnblock_forward
    if cmd_opts.opt_sub_quad_attention:
        return sub_quad_attnblock_forward
    if cmd_opts.opt_split_attention_v1:
        return fallback
    if not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
        return fallback
    if not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        return cross_attention_attnblock_forward
    return fallback
|
||||
|
||||
# The following functions are all copied from modules.sd_hijack_optimizations
|
||||
# However, the residual & normalization are removed and computed later.
|
||||
|
||||
|
||||
def attn_forward(self, h_):
    """Vanilla VAE self-attention forward.

    Copied from modules.sd_hijack_optimizations with the residual add and
    the pre-norm removed; those are applied later by the task queue.
    """
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    b, c, h, w = q.shape
    n = h * w
    # (b, hw, c) @ (b, c, hw) -> (b, hw, hw) attention scores, 1/sqrt(c) scaled
    scores = torch.bmm(q.reshape(b, c, n).permute(0, 2, 1), k.reshape(b, c, n))
    scores = torch.nn.functional.softmax(scores * (int(c) ** (-0.5)), dim=2)

    # weight values: out[b, c, j] = sum_i v[b, c, i] * scores[b, j, i]
    out = torch.bmm(v.reshape(b, c, n), scores.permute(0, 2, 1))
    return self.proj_out(out.reshape(b, c, h, w))
|
||||
|
||||
|
||||
def xformers_attnblock_forward(self, h_):
    """xformers attention for a VAE attn block; falls back to the sliced
    cross_attention implementation when xformers rejects the inputs."""
    try:
        q, k, v = self.q(h_), self.k(h_), self.v(h_)
        b, c, h, w = q.shape
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
        dtype = q.dtype
        if shared.opts.upcast_attn:
            # NOTE: only q/k are upcast; v stays in the original dtype.
            q, k = q.float(), k.float()
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        out = xformers.ops.memory_efficient_attention(
            q, k, v, op=get_xformers_flash_attention_op(q, k, v))
        out = rearrange(out.to(dtype), 'b (h w) c -> b c h w', h=h)
        return self.proj_out(out)
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, h_)
|
||||
|
||||
|
||||
def cross_attention_attnblock_forward(self, h_):
    """Memory-sliced attention for a VAE attn block.

    The (hw x hw) score matrix is computed in row slices sized to the
    currently free VRAM so huge images do not OOM. Residual add and
    pre-norm are handled by the task queue, not here.
    """
    q_proj = self.q(h_)
    k_proj = self.k(h_)
    v = self.v(h_)

    b, c, h, w = q_proj.shape
    n = h * w

    q = q_proj.reshape(b, c, n).permute(0, 2, 1)   # b, hw, c
    k = k_proj.reshape(b, c, n)                    # b, c, hw
    del q_proj, k_proj

    # accumulator for the attended output, written slice by slice
    out = torch.zeros_like(k, device=q.device)

    # estimate how many slices are needed to fit the score matrix in VRAM
    mem_free_total = get_available_vram()
    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1
    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

    # fall back to a single slice when hw is not divisible by steps
    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for start in range(0, q.shape[1], slice_size):
        end = start + slice_size

        scores = torch.bmm(q[:, start:end], k)     # b, slice, hw
        scores = torch.nn.functional.softmax(
            scores * (int(c) ** (-0.5)), dim=2, dtype=q.dtype)

        # out[b, c, j] = sum_i v[b, c, i] * scores[b, j, i]
        out[:, :, start:end] = torch.bmm(v.reshape(b, c, n), scores.permute(0, 2, 1))
        del scores

    result = self.proj_out(out.reshape(b, c, h, w))
    del out
    return result
|
||||
|
||||
|
||||
def sdp_no_mem_attnblock_forward(self, x):
    """SDP attention with the memory-efficient kernel disabled
    (flash/math kernels only), delegating to sdp_attnblock_forward."""
    kernel_ctx = torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=True, enable_mem_efficient=False)
    with kernel_ctx:
        return sdp_attnblock_forward(self, x)
|
||||
|
||||
|
||||
def sdp_attnblock_forward(self, h_):
    """torch 2.x scaled-dot-product attention for a VAE attn block
    (residual/norm handled by the task queue)."""
    q, k, v = self.q(h_), self.k(h_), self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    dtype = q.dtype
    if shared.opts.upcast_attn:
        # NOTE: only q/k are upcast; v stays in the original dtype.
        q, k = q.float(), k.float()
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, dropout_p=0.0, is_causal=False)
    out = rearrange(out.to(dtype), 'b (h w) c -> b c h w', h=h)
    return self.proj_out(out)
|
||||
|
||||
def sub_quad_attnblock_forward(self, h_):
    """Sub-quadratic (chunked) attention for a VAE attn block, driven by
    the webui --sub-quad-* command-line options."""
    q, k, v = self.q(h_), self.k(h_), self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c').contiguous() for t in (q, k, v))
    out = sub_quad_attention(
        q, k, v,
        q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size,
        kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size,
        chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold,
        use_checkpoint=self.training,
    )
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    return self.proj_out(out)
|
||||
|
||||
|
||||
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    """Chunked ("sub-quadratic") attention wrapper around
    efficient_dot_product_attention.

    q/k/v are (batch*heads, tokens, channels) tensors. `chunk_threshold`
    is the percentage of free VRAM the q@k matmul may occupy (None = auto
    heuristic, 0 = unlimited). Returns the attended values, same shape as q.

    Fix: the fast path used to assign a local `query_chunk_size` that was
    never passed on — the call below forwarded `q_chunk_size` instead, so
    the "everything in 1 chunk" optimization was silently dropped.
    """
    bytes_per_token = torch.finfo(q.dtype).bits // 8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        # mps gets more headroom than cuda by default
        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    query_chunk_size = q_chunk_size
    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        query_chunk_size = q_tokens
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
        return efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=query_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )
|
||||
|
||||
|
||||
def get_xformers_flash_attention_op(q, k, v):
    """Return the xformers FlashAttention op pair when the user asked for
    it (--xformers-flash-attention) and the inputs are supported;
    otherwise None, letting xformers auto-select an op."""
    if not shared.cmd_opts.xformers_flash_attention:
        return None
    try:
        op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        fw, _bw = op
        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
            return op
    except Exception as e:
        # best-effort: report once, then fall back to auto selection
        errors.display_once(e, "enabling flash attention")
    return None
|
||||
|
||||
|
||||
def xformer_attn_forward(self, h_):
    """xformers memory-efficient attention for a VAE attn block, using
    the module's own `attention_op`. Residual/norm handled by the caller."""
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # flatten spatial dims, then add a single-head axis for xformers
    B, C, H, W = q.shape
    q, k, v = (
        rearrange(t, 'b c h w -> b (h w) c')
        .unsqueeze(3)
        .reshape(B, -1, 1, C)
        .permute(0, 2, 1, 3)
        .reshape(B, -1, C)
        .contiguous()
        for t in (q, k, v)
    )
    out = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=None, op=self.attention_op)

    # undo the single-head reshape and restore the b c h w layout
    out = out.unsqueeze(0).reshape(B, 1, -1, C).permute(0, 2, 1, 3).reshape(B, -1, C)
    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
    return self.proj_out(out)
|
||||
Loading…
Reference in New Issue