759 lines
30 KiB
Python
759 lines
30 KiB
Python
'''
|
|
# ------------------------------------------------------------------------
|
|
#
|
|
# Tiled VAE
|
|
#
|
|
# Introducing a revolutionary new optimization designed to make
|
|
# the VAE work with giant images on limited VRAM!
|
|
# Say goodbye to the frustration of OOM and hello to seamless output!
|
|
#
|
|
# ------------------------------------------------------------------------
|
|
#
|
|
# This script is a wild hack that splits the image into tiles,
|
|
# encodes each tile separately, and merges the result back together.
|
|
#
|
|
# Advantages:
|
|
# - The VAE can now work with giant images on limited VRAM
|
|
# (~10 GB for 8K images!)
|
|
# - The merged output is completely seamless without any post-processing.
|
|
#
|
|
# Drawbacks:
|
|
# - NaNs always appear for 8k images when you use fp16 (half) VAE
|
|
# You must use --no-half-vae to disable half VAE for that giant image.
|
|
# - The gradient calculation is not compatible with this hack. It
|
|
# will break any backward() or torch.autograd.grad() that passes VAE.
|
|
# (But you can still use the VAE to generate training data.)
|
|
#
|
|
# How it works:
|
|
# 1. The image is split into tiles, which are then padded with 11/32 pixels in the decoder/encoder.
|
|
# 2. When Fast Mode is disabled:
|
|
# 1. The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile.
|
|
# 2. When GroupNorm is needed, it suspends, stores current GroupNorm mean and var, sends everything to RAM, and turns to the next tile.
|
|
# 3. After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues.
|
|
# 4. A zigzag execution order is used to reduce unnecessary data transfer.
|
|
# 3. When Fast Mode is enabled:
|
|
# 1. The original input is downsampled and passed to a separate task queue.
|
|
# 2. Its group norm parameters are recorded and used by all tiles' task queues.
|
|
# 3. Each tile is separately processed without any RAM-VRAM data transfer.
|
|
# 4. After all tiles are processed, tiles are written to a result buffer and returned.
|
|
# Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode.
|
|
#
|
|
# Enjoy!
|
|
#
|
|
# @Author: LI YI @ Nanyang Technological University - Singapore
|
|
# @Date: 2023-03-02
|
|
# @License: CC BY-NC-SA 4.0
|
|
#
|
|
# Please give me a star if you like this project!
|
|
#
|
|
# -------------------------------------------------------------------------
|
|
'''
|
|
|
|
import gc
|
|
import math
|
|
from time import time
|
|
from tqdm import tqdm
|
|
|
|
import torch
|
|
import torch.version
|
|
import torch.nn.functional as F
|
|
import gradio as gr
|
|
|
|
import modules.scripts as scripts
|
|
import modules.devices as devices
|
|
from modules.shared import state, opts
|
|
from modules.ui import gr_show
|
|
from modules.processing import opt_f
|
|
from modules.sd_vae_approx import cheap_approximation
|
|
from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock
|
|
|
|
from tile_utils.attn import get_attn_func
|
|
from tile_utils.typing import Processing
|
|
|
|
if hasattr(opts, 'hypertile_enable_unet'): # webui >= 1.7
|
|
from modules.ui_components import InputAccordion
|
|
else:
|
|
InputAccordion = None
|
|
|
|
|
|
def get_rcmd_enc_tsize():
    """
    Recommend an encoder tile size (in pixels) based on the total VRAM of the current device.

    Returns 512 when CUDA is unavailable or the webui device is the CPU; otherwise picks the
    largest size whose VRAM threshold (in MiB) is exceeded, falling back to 960.
    """
    if not torch.cuda.is_available() or devices.device in ['cpu', devices.cpu]:
        return 512
    vram_mib = torch.cuda.get_device_properties(devices.device).total_memory // 2**20
    for threshold_mib, size in ((16*1000, 3072), (12*1000, 2048), (8*1000, 1536)):
        if vram_mib > threshold_mib:
            return size
    return 960
|
|
|
|
|
|
def get_rcmd_dec_tsize():
    """
    Recommend a decoder tile size (in latent pixels) based on the total VRAM of the current device.

    Returns 64 when CUDA is unavailable or the webui device is the CPU; otherwise picks the
    largest size whose VRAM threshold (in MiB) is exceeded, falling back to 64.
    """
    if not torch.cuda.is_available() or devices.device in ['cpu', devices.cpu]:
        return 64
    vram_mib = torch.cuda.get_device_properties(devices.device).total_memory // 2**20
    for threshold_mib, size in ((30*1000, 256), (16*1000, 192), (12*1000, 128), (8*1000, 96)):
        if vram_mib > threshold_mib:
            return size
    return 64
|
|
|
|
|
|
def inplace_nonlinearity(x):
    """Apply SiLU to ``x`` in place (overwriting its storage) and return it."""
    # Test: fix for Nans
    return F.silu(input=x, inplace=True)
|
|
|
|
|
|
def attn2task(task_queue, net):
    """
    Append the tasks reproducing an attention block to ``task_queue``.

    The sequence is: store the identity residual, group-normalize with ``net.norm``, run the
    attention function returned by ``get_attn_func()``, then add the stored residual back.
    The 'add_res' entry is a mutable list because its second slot is filled in at run time.

    @param task_queue: the target task queue
    @param net: attention block (must expose ``norm``)
    """
    attn_forward = get_attn_func()
    task_queue.extend([
        ('store_res', lambda x: x),
        ('pre_norm', net.norm),
        ('attn', lambda x, net=net: attn_forward(net, x)),
        ['add_res', None],
    ])
|
|
|
|
|
|
def resblock2task(queue, block):
    """
    Turn a ResNetBlock into a sequence of tasks and append them to the task queue.

    The shortcut branch is stored first ('store_res'), followed by
    norm1 -> SiLU -> conv1 -> norm2 -> SiLU -> conv2, and finally 'add_res',
    whose (mutable) second slot receives the stored shortcut at run time.

    @param queue: the target task queue
    @param block: ResNetBlock
    """
    if block.in_channels == block.out_channels:
        shortcut = lambda x: x
    elif block.use_conv_shortcut:
        shortcut = block.conv_shortcut
    else:
        shortcut = block.nin_shortcut
    queue.append(('store_res', shortcut))

    stages = (
        (block.norm1, 'conv1', block.conv1),
        (block.norm2, 'conv2', block.conv2),
    )
    for norm_layer, conv_name, conv_layer in stages:
        queue.append(('pre_norm', norm_layer))
        queue.append(('silu', inplace_nonlinearity))
        queue.append((conv_name, conv_layer))

    queue.append(['add_res', None])
|
|
|
|
|
|
def build_sampling(task_queue, net, is_decoder):
    """
    Build the sampling part of a task queue.

    For the decoder: middle block first, then the ``net.up`` levels from the highest index down,
    each with ``num_res_blocks + 1`` ResNet blocks and an 'upsample' on every level except 0.
    For the encoder: the ``net.down`` levels in order, each with ``num_res_blocks`` ResNet blocks
    and a 'downsample' on every level except the last, then the middle block.

    @param task_queue: the target task queue
    @param net: the network
    @param is_decoder: currently building decoder or encoder
    """
    def append_mid_block():
        resblock2task(task_queue, net.mid.block_1)
        attn2task(task_queue, net.mid.attn_1)
        resblock2task(task_queue, net.mid.block_2)

    if is_decoder:
        append_mid_block()
        levels = reversed(range(net.num_resolutions))
        blocks_per_level = net.num_res_blocks + 1
        level_without_sampler = 0
        stages = net.up
        sampler_name = 'upsample'
    else:
        levels = range(net.num_resolutions)
        blocks_per_level = net.num_res_blocks
        level_without_sampler = net.num_resolutions - 1
        stages = net.down
        sampler_name = 'downsample'

    for level in levels:
        stage = stages[level]
        for block_idx in range(blocks_per_level):
            resblock2task(task_queue, stage.block[block_idx])
        if level != level_without_sampler:
            task_queue.append((sampler_name, getattr(stage, sampler_name)))

    if not is_decoder:
        append_mid_block()
|
|
|
|
|
|
def build_task_queue(net, is_decoder):
    """
    Build a single task queue for the encoder or decoder.

    The queue is a list of ``(name, callable)`` entries ('add_res' entries are mutable lists):
    conv_in, the shared sampling part, and (unless the decoder has ``give_pre_end`` set) the
    output head norm_out -> SiLU -> conv_out, plus tanh for a decoder with ``tanh_out``.

    @param net: the VAE decoder or encoder network
    @param is_decoder: currently building decoder or encoder
    @return: the task queue
    """
    queue = [('conv_in', net.conv_in)]

    # encoder and decoder share the same architecture, so the sampling part is built by one helper
    build_sampling(queue, net, is_decoder)

    if not (is_decoder and net.give_pre_end):
        queue += [
            ('pre_norm', net.norm_out),
            ('silu', inplace_nonlinearity),
            ('conv_out', net.conv_out),
        ]
        if is_decoder and net.tanh_out:
            queue.append(('tanh', torch.tanh))

    return queue
|
|
|
|
|
|
def clone_task_queue(task_queue):
    """
    Clone a task queue.

    Every task is copied into a fresh list, so the per-task slots (e.g. the residual stored in
    an 'add_res' entry) can be assigned independently in each clone. The callables themselves
    are shared, not copied.

    @param task_queue: the task queue to be cloned
    @return: the cloned task queue (a list of lists)
    """
    return [list(task) for task in task_queue]
|
|
|
|
|
|
def get_var_mean(input, num_groups, eps=1e-6):
    """
    Get the per-group (biased) variance and mean used by group norm.

    The input is viewed as ``(1, batch * num_groups, channels // num_groups, H, W)`` and reduced
    over everything except the group axis, so both results have ``batch * num_groups`` entries.
    ``eps`` is accepted for signature compatibility and is not used.

    @return: (var, mean)
    """
    batch, channels = input.size(0), input.size(1)
    grouped = input.contiguous().view(1, batch * num_groups, channels // num_groups, *input.size()[2:])
    return torch.var_mean(grouped, dim=[0, 2, 3, 4], unbiased=False)
|
|
|
|
|
|
def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """
    Custom group norm with fixed (pre-computed) mean and var.

    Groups are normalized with ``F.batch_norm`` in inference mode on a grouped view of the
    input, then the optional per-channel affine transform is applied in place.

    @param input: input tensor
    @param num_groups: number of groups. by default, num_groups = 32
    @param mean: mean, must be pre-calculated by get_var_mean
    @param var: var, must be pre-calculated by get_var_mean
    @param weight: weight, should be fetched from the original group norm
    @param bias: bias, should be fetched from the original group norm
    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm
    @return: normalized tensor
    """
    batch, channels = input.size(0), input.size(1)
    spatial = input.size()[2:]
    grouped = input.contiguous().view(1, batch * num_groups, channels // num_groups, *spatial)

    normed = F.batch_norm(
        grouped, mean.to(input), var.to(input),
        weight=None, bias=None, training=False, momentum=0, eps=eps,
    ).view(batch, channels, *spatial)

    # post affine transform (per channel)
    if weight is not None:
        normed *= weight.view(1, -1, 1, 1)
    if bias is not None:
        normed += bias.view(1, -1, 1, 1)
    return normed
|
|
|
|
|
|
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Crop the valid region from a processed tile.

    Bounding boxes use the layout ``[x1, x2, y1, y2]``. ``input_bbox`` is scaled to output
    resolution (x8 for the decoder, //8 for the encoder), and the tile is trimmed by the offset
    between that scaled box and ``target_bbox`` on each side.

    @param x: processed tile
    @param input_bbox: original (padded) input bounding box, in input resolution
    @param target_bbox: output bounding box, in output resolution
    @param is_decoder: True to scale input coordinates up by 8, False to scale them down by 8
    @return: cropped tile
    """
    scaled = [v * 8 if is_decoder else v // 8 for v in input_bbox]
    left, right, top, bottom = (target_bbox[k] - scaled[k] for k in range(4))
    return x[:, :, top:x.size(2) + bottom, left:x.size(3) + right]
|
|
|
|
|
|
# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
|
|
|
|
def perfcount(fn):
    """
    Decorator that reports the wall-clock time of ``fn`` (and its peak CUDA memory when CUDA
    is available).

    Before and after the call it runs ``devices.torch_gc()`` and ``gc.collect()``. When CUDA is
    available, the peak-memory counter of ``devices.device`` is reset first, so the printed value
    reflects this call only. The wrapper preserves ``fn``'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``) via ``functools.wraps``.
    """
    from functools import wraps

    @wraps(fn)
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret

    return wrapper
|
|
|
|
# ↑↑↑ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↑↑↑
|
|
|
|
|
|
class GroupNormParam:
    """
    Collects per-tile GroupNorm statistics so one image-wide group norm can be applied to every
    tile (used when group norm has to be synchronised across tiles).
    """

    def __init__(self):
        self.var_list = []      # per-tile group variances, one entry per add_tile() call
        self.mean_list = []     # per-tile group means, one entry per add_tile() call
        self.pixel_list = []    # per-tile H * W, used as the averaging weight
        self.weight = None      # affine weight of the last GroupNorm layer seen (or None)
        self.bias = None        # affine bias of the last GroupNorm layer seen (or None)

    def add_tile(self, tile, layer):
        """
        Record the group statistics of ``tile`` for the GroupNorm ``layer``.

        @param tile: tile tensor about to be normalized
        @param layer: the GroupNorm module whose affine parameters are kept
        """
        var, mean = get_var_mean(tile, 32)
        # For giant images, the variance can be larger than max float16
        # In this case we recompute the statistics on a float32 copy
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
        self.var_list.append(var)
        self.mean_list.append(mean)
        self.pixel_list.append(tile.shape[2] * tile.shape[3])
        if hasattr(layer, 'weight'):
            self.weight = layer.weight
            self.bias = layer.bias
        else:
            self.weight = None
            self.bias = None

    def summary(self):
        """
        Summarize the collected mean and var (weighted by tile pixel count) and return a function
        that applies group norm with them to a tile, or None if no tile was added.
        """
        if len(self.var_list) == 0:
            return None

        var = torch.vstack(self.var_list)
        mean = torch.vstack(self.mean_list)
        max_value = max(self.pixel_list)
        # Build the weights on the device that holds the statistics: the tiles (and therefore the
        # statistics) live on the VAE's device, which need not be the webui's default device.
        pixels = torch.tensor(self.pixel_list, dtype=torch.float32, device=var.device) / max_value
        sum_pixels = torch.sum(pixels)
        pixels = pixels.unsqueeze(1) / sum_pixels
        var = torch.sum(var * pixels, dim=0)
        mean = torch.sum(mean * pixels, dim=0)
        return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)

    @staticmethod
    def from_tile(tile, norm):
        """
        Create a group norm function from the statistics of a single tile (no summary).

        @param tile: tile tensor whose statistics are used
        @param norm: the GroupNorm module whose affine parameters are used
        @return: function applying the fixed group norm to a tensor
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
            # if it is a macbook, we need to convert back to float16
            if var.device.type == 'mps':
                # clamp to avoid overflow
                var = torch.clamp(var, 0, 60000)
                var = var.half()
                mean = mean.half()
        if hasattr(norm, 'weight'):
            weight = norm.weight
            bias = norm.bias
        else:
            weight = None
            bias = None

        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func
|
|
|
|
|
|
class VAEHook:
    """
    Callable installed as ``net.forward`` of a VAE encoder/decoder to run it tile by tile.

    Inputs small enough to fit in one tile are passed to ``net.original_forward`` (which must
    have been saved by the caller); larger inputs go through ``vae_tile_forward``.
    """

    def __init__(self, net, tile_size, is_decoder:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool, to_gpu:bool=False):
        self.net = net                  # encoder | decoder
        self.tile_size = tile_size
        self.is_decoder = is_decoder
        self.fast_mode = (fast_encoder and not is_decoder) or (fast_decoder and is_decoder)
        self.color_fix = color_fix and not is_decoder
        self.to_gpu = to_gpu
        self.pad = 11 if is_decoder else 32         # FIXME: magic number

    def __call__(self, x):
        """Run the (optionally GPU-moved) network on ``x``; the net is moved back afterwards."""
        original_device = next(self.net.parameters()).device
        try:
            if self.to_gpu:
                self.net = self.net.to(devices.get_optimal_device())

            _, _, H, W = x.shape
            if max(H, W) <= self.pad * 2 + self.tile_size:
                print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
                return self.net.original_forward(x)
            else:
                return self.vae_tile_forward(x)
        finally:
            self.net = self.net.to(original_device)

    def get_best_tile_size(self, lowerbound, upperbound):
        """
        Get the best tile size for GPU memory: round ``lowerbound`` up to a multiple of the
        largest power of two (32 down to 2) that stays within ``upperbound``.
        """
        divider = 32
        while divider >= 2:
            remainer = lowerbound % divider
            if remainer == 0:
                return lowerbound
            candidate = lowerbound - remainer + divider
            if candidate <= upperbound:
                return candidate
            divider //= 2
        return lowerbound

    def split_tiles(self, h, w):
        """
        Tool function to split the image into tiles.

        Bounding boxes use the layout [x1, x2, y1, y2].

        @param h: height of the image
        @param w: width of the image
        @return: tile_input_bboxes (padded, input resolution), tile_output_bboxes (output resolution)
        """
        tile_input_bboxes, tile_output_bboxes = [], []
        tile_size = self.tile_size
        pad = self.pad
        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
        # If any of the numbers are 0, we let it be 1
        # This is to deal with long and thin images
        num_height_tiles = max(num_height_tiles, 1)
        num_width_tiles = max(num_width_tiles, 1)

        # Suggestions from https://github.com/Kahsolt: auto shrink the tile size
        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)

        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')

        for i in range(num_height_tiles):
            for j in range(num_width_tiles):
                # the padding is unnecessary for image borders. So we directly start from (pad, pad)
                input_bbox = [
                    pad + j * real_tile_width,
                    min(pad + (j + 1) * real_tile_width, w),
                    pad + i * real_tile_height,
                    min(pad + (i + 1) * real_tile_height, h),
                ]

                # if the output bbox is close to the image boundary, we extend it to the image boundary
                output_bbox = [
                    input_bbox[0] if input_bbox[0] > pad else 0,
                    input_bbox[1] if input_bbox[1] < w - pad else w,
                    input_bbox[2] if input_bbox[2] > pad else 0,
                    input_bbox[3] if input_bbox[3] < h - pad else h,
                ]

                # scale to get the final output bbox
                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
                tile_output_bboxes.append(output_bbox)

                # expand the input bbox by pad pixels on every side (clipped to the image)
                tile_input_bboxes.append([
                    max(0, input_bbox[0] - pad),
                    min(w, input_bbox[1] + pad),
                    max(0, input_bbox[2] - pad),
                    min(h, input_bbox[3] + pad),
                ])

        return tile_input_bboxes, tile_output_bboxes

    @torch.no_grad()
    def estimate_group_norm(self, z, task_queue, color_fix):
        """
        Run ``task_queue`` once on ``z`` and replace its 'pre_norm' tasks with fixed 'apply_norm'
        tasks built from the observed statistics (mutates ``task_queue`` in place).

        With ``color_fix``, estimation stops at the first 'downsample' and the remaining
        'store_res' tasks are turned into 'store_res_cpu'.

        @return: True on success, False if NaNs were detected during estimation
        @raise ValueError: if the queue contains no group norm
        """
        device = z.device
        tile = z
        last_id = len(task_queue) - 1
        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
            last_id -= 1
        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
            raise ValueError('No group norm found in the task queue')
        # estimate until the last group norm
        for i in range(last_id + 1):
            task = task_queue[i]
            if task[0] == 'pre_norm':
                group_norm_func = GroupNormParam.from_tile(tile, task[1])
                task_queue[i] = ('apply_norm', group_norm_func)
                if i == last_id:
                    return True
                tile = group_norm_func(tile)
            elif task[0] == 'store_res':
                task_id = i + 1
                while task_id < last_id and task_queue[task_id][0] != 'add_res':
                    task_id += 1
                if task_id >= last_id:
                    # the matching add_res lies beyond the estimated part; leave it to the tiles
                    continue
                task_queue[task_id][1] = task[1](tile)
            elif task[0] == 'add_res':
                tile += task[1].to(device)
                task[1] = None
            elif color_fix and task[0] == 'downsample':
                for j in range(i, last_id + 1):
                    if task_queue[j][0] == 'store_res':
                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
                return True
            else:
                tile = task[1](tile)
            try:
                devices.test_for_nans(tile, "vae")
            except Exception:
                print(f'Nan detected in fast mode estimation. Fast mode disabled.')
                return False

        raise IndexError('Should not reach here')

    @perfcount
    @torch.no_grad()
    def vae_tile_forward(self, z):
        """
        Run the wrapped encoder/decoder on ``z`` tile by tile and stitch the results together.

        @param z: input tensor (latent for the decoder, image for the encoder)
        @return: output tensor. If the job is interrupted, the partially filled result is returned;
                 if no tile finished, a cheap latent preview (decoder only) or zeros are returned.
        """
        device = next(self.net.parameters()).device
        dtype = next(self.net.parameters()).dtype
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach()  # detach the input to avoid backprop

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        out_height = height * 8 if is_decoder else height // 8
        out_width = width * 8 if is_decoder else width // 8
        net.last_z_shape = z.shape

        # Split the input into tiles and build a task queue for each tile
        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Prepare tiles by splitting the input (bbox layout [x1, x2, y1, y2]); keep them in RAM
        tiles = [z[:, :, y1:y2, x1:x2].cpu() for x1, x2, y1, y2 in in_bboxes]

        num_tiles = len(tiles)
        num_completed = 0

        # Build task queues
        single_task_queue = build_task_queue(net, is_decoder)
        # True only if group norm estimation actually succeeded. Otherwise the queues still contain
        # 'pre_norm' tasks and run in the synchronised (slow) mode, which must offload residuals
        # to RAM like the regular slow mode does.
        fast_mode = False
        if self.fast_mode:
            # Fast mode: downsample the input image to the tile size,
            # then estimate the group norm parameters on the downsampled image
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            # use nearest-exact to keep statistics as close as possible
            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
            # The downsampling will heavily distort its mean and std, so we need to recover it.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
            del std_old, mean_old, std_new, mean_new
            # occasionally the std_new is too small or too large, which exceeds the range of float16
            # so we need to clamp it to max z's range.
            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
            estimate_task_queue = clone_task_queue(single_task_queue)
            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
                single_task_queue = estimate_task_queue
                fast_mode = True
            del downsampled_z

        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]

        # Dummy result
        result = None
        # Cheap preview returned if interrupted before any tile finishes. Only meaningful for
        # latents, i.e. the decoder input; it is best effort and failures are ignored.
        result_approx = None
        if is_decoder:
            try:
                with devices.autocast():
                    result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()
            except Exception:
                pass
        # Free memory of input latent tensor
        del z

        # Task queue execution
        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")

        # execute the task back and forth when switch tiles so that we always
        # keep one tile on the GPU to reduce unnecessary data transfer
        forward = True
        interrupted = False
        try:
            while True:
                if state.interrupted:
                    interrupted = True
                    break

                group_norm_param = GroupNormParam()
                for i in range(num_tiles) if forward else reversed(range(num_tiles)):
                    if state.interrupted:
                        interrupted = True
                        break

                    tile = tiles[i].to(device)
                    task_queue = task_queues[i]

                    while len(task_queue) > 0:
                        if state.interrupted:
                            interrupted = True
                            break

                        task = task_queue.pop(0)
                        if task[0] == 'pre_norm':
                            # suspend this tile until the image-wide statistics are known
                            group_norm_param.add_tile(tile, task[1])
                            break
                        elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                            res = task[1](tile)
                            if not fast_mode or task[0] == 'store_res_cpu':
                                res = res.cpu()
                            task_id = 0
                            while task_queue[task_id][0] != 'add_res':
                                task_id += 1
                            task_queue[task_id][1] = res
                        elif task[0] == 'add_res':
                            tile += task[1].to(device)
                            task[1] = None
                        else:
                            tile = task[1](tile)
                        pbar.update(1)

                    if interrupted:
                        break

                    # check for NaNs in the tile.
                    # If there are NaNs, we abort the process to save user's time
                    devices.test_for_nans(tile, "vae")

                    if len(task_queue) == 0:
                        tiles[i] = None
                        num_completed += 1
                        if result is None:      # NOTE: dim C varies from different cases, can only be inited dynamically
                            result = torch.zeros((N, tile.shape[1], out_height, out_width), device=device, requires_grad=False)
                        result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
                        del tile
                    elif i == num_tiles - 1 and forward:
                        # last tile of a forward pass stays on the GPU for the backward pass
                        forward = False
                        tiles[i] = tile
                    elif i == 0 and not forward:
                        # first tile of a backward pass stays on the GPU for the forward pass
                        forward = True
                        tiles[i] = tile
                    else:
                        tiles[i] = tile.cpu()
                        del tile

                if interrupted:
                    break
                if num_completed == num_tiles:
                    break

                # insert the group norm task to the head of each task queue
                group_norm_func = group_norm_param.summary()
                if group_norm_func is not None:
                    for task_queue in task_queues:
                        task_queue.insert(0, ('apply_norm', group_norm_func))
        finally:
            pbar.close()

        # Done!
        if result is not None:
            return result.to(dtype)
        if result_approx is not None:
            return result_approx.to(device, dtype=dtype)
        # Interrupted before any tile finished and no preview available: return an empty output
        # instead of crashing on a missing preview.
        return torch.zeros((N, net.conv_out.out_channels, out_height, out_width), device=device, dtype=dtype)
|
|
|
|
|
|
class Script(scripts.Script):
    """WebUI always-visible script that swaps the VAE encoder/decoder forwards for VAEHook."""

    def __init__(self):
        self.hooked = False     # True while this script has VAEHooks installed

    def title(self):
        return "Tiled VAE"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        tab = 't2i' if not is_img2img else 'i2i'
        uid = lambda name: f'MD-{tab}-{name}'

        with (
            InputAccordion(False, label='Tiled VAE', elem_id=f'MDV-{tab}-enabled') if InputAccordion
            else gr.Accordion('Tiled VAE', open=False, elem_id=f'MDV-{tab}')
            as enabled
        ):
            with gr.Row():
                # older webui: no InputAccordion, so an explicit checkbox provides `enabled`
                if not InputAccordion:
                    enabled = gr.Checkbox(label='Enable Tiled VAE', value=False, elem_id=uid('enable'))
                vae_to_gpu = gr.Checkbox(label='Move VAE to GPU (if possible)', value=True, elem_id=uid('vae2gpu'))

            gr.HTML('<p style="margin-bottom:0.8em"> Recommended to set tile sizes as large as possible before got CUDA error: out of memory. </p>')
            with gr.Row():
                encoder_tile_size = gr.Slider(label='Encoder Tile Size', minimum=256, maximum=4096, step=16, value=get_rcmd_enc_tsize(), elem_id=uid('enc-size'))
                decoder_tile_size = gr.Slider(label='Decoder Tile Size', minimum=48, maximum=512, step=16, value=get_rcmd_dec_tsize(), elem_id=uid('dec-size'))
                reset = gr.Button(value='↻ Reset', variant='tool')
                reset.click(fn=lambda: [get_rcmd_enc_tsize(), get_rcmd_dec_tsize()], outputs=[encoder_tile_size, decoder_tile_size], show_progress=False)

            with gr.Row():
                fast_encoder = gr.Checkbox(label='Fast Encoder', value=True, elem_id=uid('fastenc'))
                color_fix = gr.Checkbox(label='Fast Encoder Color Fix', value=False, visible=True, elem_id=uid('fastenc-colorfix'))
                fast_decoder = gr.Checkbox(label='Fast Decoder', value=True, elem_id=uid('fastdec'))

            fast_encoder.change(fn=gr_show, inputs=fast_encoder, outputs=color_fix, show_progress=False)

        return [
            enabled,
            encoder_tile_size, decoder_tile_size,
            vae_to_gpu, fast_decoder, fast_encoder, color_fix,
        ]

    @staticmethod
    def _unhook(vae):
        """Restore the original encoder/decoder forwards wherever a VAEHook is installed."""
        for net in (vae.encoder, vae.decoder):
            if isinstance(net.forward, VAEHook):
                net.forward.net = None      # drop the hook's reference to the network
                net.forward = net.original_forward

    def process(self, p:Processing,
            enabled:bool,
            encoder_tile_size:int, decoder_tile_size:int,
            vae_to_gpu:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool
        ):
        # for shorthand
        vae = p.sd_model.first_stage_model
        encoder = vae.encoder
        decoder = vae.decoder

        # undo hijack if disabled (also cleans up hooks left behind when a previous run crashed);
        # the isinstance checks in _unhook make this a no-op when nothing is hooked
        if not enabled:
            self._unhook(vae)
            self.hooked = False
            return

        if devices.get_optimal_device_name().startswith('cuda') and vae.device == devices.cpu and not vae_to_gpu:
            print("[Tiled VAE] warn: VAE is not on GPU, check 'Move VAE to GPU' if possible.")

        # do hijack
        kwargs = {
            'fast_decoder': fast_decoder,
            'fast_encoder': fast_encoder,
            'color_fix': color_fix,
            'to_gpu': vae_to_gpu,
        }

        # save original forward (only once)
        if not hasattr(encoder, 'original_forward'): setattr(encoder, 'original_forward', encoder.forward)
        if not hasattr(decoder, 'original_forward'): setattr(decoder, 'original_forward', decoder.forward)

        self.hooked = True

        encoder.forward = VAEHook(encoder, encoder_tile_size, is_decoder=False, **kwargs)
        decoder.forward = VAEHook(decoder, decoder_tile_size, is_decoder=True, **kwargs)

    def postprocess(self, p:Processing, processed, enabled:bool, *args):
        if not enabled: return

        self._unhook(p.sd_model.first_stage_model)
        self.hooked = False
|