# source: mirror of https://github.com/vladmandic/automatic (Python, 259 lines, 11 KiB)
# credits: @tfernd https://github.com/tfernd/HyperTile
|
|
# based on: https://github.com/tfernd/HyperTile/tree/main/hyper_tile/utils.py + https://github.com/tfernd/HyperTile/tree/main/hyper_tile/hyper_tile.py
|
|
|
|
from __future__ import annotations
|
|
from typing import Callable
|
|
from functools import wraps, cache
|
|
from contextlib import contextmanager, nullcontext
|
|
import random
|
|
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
from einops import rearrange
|
|
from installer import log
|
|
|
|
|
|
# module-level state shared by every hijacked attention layer; tracks the image size as it changes across passes
height, width = None, None # current image resolution, assigned by set_resolution()
max_h, max_w = 0, 0 # high-water marks the attention wrapper compares against to detect a larger pass
error_reported = False # set after the first logged error so the same failure is not logged on every call
reset_needed = False # asks the attention wrapper to recompute its tile candidates on its next Unet call
skip_hypertile = False # when true the attention wrapper calls the original forward untouched
|
|
|
|
|
|
def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
    """
    Split hw into two integer factors whose ratio best matches aspect_ratio.

    Every divisor i of hw in [2, hw] yields a candidate pair (i, hw // i). The pair whose
    second/first ratio is closest to aspect_ratio is returned; on ties the smallest i wins.
    hw == 1 has no divisor >= 2, so its only factorization (1, 1) is returned directly.

    Raises:
        ValueError: if hw is smaller than 1.
    """
    if hw < 1:
        raise ValueError(f"hw must be a positive integer, got {hw}")
    if hw == 1:
        return 1, 1 # no candidates exist below; min() over an empty list would raise
    pairs = [(i, hw // i) for i in range(2, hw + 1) if hw % i == 0] # all factor pairs with first factor >= 2
    # min() keeps the first best candidate, same as looking up the index of the closest ratio
    return min(pairs, key=lambda pair: abs(pair[1] / pair[0] - aspect_ratio))
|
|
|
|
|
|
@cache
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
    """
    Factor hw into (h, w) with h * w == hw and h / w close to aspect_ratio.

    Starts from the rounded square-root estimates; if their product misses hw, first w and then h
    is replaced by the exact quotient when that quotient is an integer. When neither quotient is
    an integer the search is delegated to iterative_closest_divisors. Results are memoized.
    """
    rows = round(math.sqrt(hw * aspect_ratio))
    cols = round(math.sqrt(hw / aspect_ratio))
    if rows * cols == hw:
        return rows, cols
    exact_cols = hw / rows # keep the estimated h, solve for w
    if exact_cols.is_integer():
        return rows, int(exact_cols)
    exact_rows = hw / cols # keep the estimated w, solve for h
    if exact_rows.is_integer():
        return int(exact_rows), cols
    return iterative_closest_divisors(hw, aspect_ratio)
|
|
|
|
|
|
def possible_tile_sizes(dimension: int, tile_size: int, min_tile_size: int, tile_options: int) -> list[int]:
    """
    List the tile counts worth trying along one image dimension.

    A count n qualifies when it splits dimension into equal tiles whose size is a multiple of 8
    and not smaller than min_tile_size (which is itself capped by tile_size and dimension).
    Qualifying counts are ranked by how close their tile size is to tile_size and the first
    tile_options of them are returned.
    """
    assert tile_options >= 1
    smallest_tile = min(min_tile_size, tile_size, dimension)
    counts = torch.arange(1, dimension + 1)
    # rounding the tile size down to a multiple of 8 must still reproduce dimension exactly
    splits_into_multiples_of_8 = dimension // counts // 8 * 8 * counts == dimension
    counts = counts[splits_into_multiples_of_8]
    counts = counts[dimension // counts >= smallest_tile]
    distance = torch.abs(dimension // counts - tile_size)
    ranking = torch.argsort(distance)[:tile_options]
    return counts[ranking].tolist()
|
|
|
|
|
|
def parse_list(x: list[int], /) -> str:
    """
    Format a list for display: a single-element list is shown as that element, anything else as the list itself.
    """
    if len(x) == 1: # was `== 0`, which indexed into an empty list and never unwrapped single items
        return str(x[0])
    return str(x)
|
|
|
|
|
|
@contextmanager
def split_attention(layer: nn.Module, tile_size: int=256, min_tile_size: int=128, swap_size: int=1, depth: int=0):
    """
    Context manager that temporarily tiles the self-attention of the matching modules in layer.

    While active, the forward of every submodule whose class is named Attention, CrossAttention or
    AttnBlock (except modules whose name ends in attn2/attn_2) is wrapped: the input is split into
    nh x nw tiles, the original forward runs on the tiles and the result is merged back.
    On exit the original forward methods are restored, also when the managed body raised.

    Args:
        layer: model whose attention submodules are wrapped
        tile_size: preferred tile size, forwarded to possible_tile_sizes
        min_tile_size: smallest accepted tile size, forwarded to possible_tile_sizes
        swap_size: number of tile-count candidates; one is picked at random on every forward call
        depth: largest downscale level (log2 of the downscale ratio) at which Unet inputs are still tiled

    Reads the module-level height/width on entry, so they must be set beforehand.
    """
    # hijacks AttnBlock from ldm and attention from diffusers
    global reset_needed # pylint: disable=global-statement
    ar = height / width # Aspect ratio
    reset_needed = True
    nhs = possible_tile_sizes(height, tile_size, min_tile_size, swap_size) # possible sub-grids that fit into the image
    nws = possible_tile_sizes(width, tile_size, min_tile_size, swap_size)

    def reset_nhs():
        # recompute vertical tile candidates after the global height changed
        nonlocal nhs, ar
        ar = height / width # Aspect ratio
        nhs = possible_tile_sizes(height, tile_size, min_tile_size, swap_size)

    def reset_nws():
        # recompute horizontal tile candidates after the global width changed
        nonlocal nws, ar
        ar = height / width # Aspect ratio
        nws = possible_tile_sizes(width, tile_size, min_tile_size, swap_size)

    def self_attn_forward(forward: Callable) -> Callable:
        # builds the tiling replacement for one module's forward
        @wraps(forward)
        def wrapper(*args, **kwargs):
            global height, width, max_h, max_w, reset_needed, error_reported # pylint: disable=global-statement
            if skip_hypertile:
                return forward(*args, **kwargs)
            x = args[0]
            try:
                nh = nhs[random.randint(0, len(nhs) - 1)]
                nw = nws[random.randint(0, len(nws) - 1)]
            except Exception as e:
                # no usable tile candidates (e.g. empty list): log once and run untiled
                if not error_reported:
                    error_reported = True
                    log.error(f'Hypertile calculate: width={width} height={height} {e}')
                out = forward(*args, **kwargs)
                return out
            if x.ndim == 4: # VAE
                # TODO hypertile: vae breaks when using non-standard sizes
                if nh * nw > 1:
                    x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw)
                out = forward(x, *args[1:], **kwargs)
                if nh * nw > 1:
                    out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
            else: # Unet
                hw = x.size(1)
                h, w = round(math.sqrt(ar * hw)), round(math.sqrt(hw / ar))
                # h, w = find_hw_candidates(hw, ar)
                # dynamic height/width based on fact that first two forward calls contain actual height/width
                # and reset if latest hw is larger since we're never downscaling in 2nd pass
                if reset_needed:
                    reset_nhs()
                    reset_nws()
                    max_h = height
                    max_w = width
                    reset_needed = False
                else:
                    if h > max_h:
                        height = 8 * h
                        max_h = max(max_h, h)
                        reset_nhs()
                    if w > max_w:
                        width = 8 * w
                        max_w = max(max_w, w)
                        reset_nws()
                down_ratio = max(height // 8 // h, 1)
                curr_depth = round(math.log(down_ratio, 2))
                # scale-up the tile-size the deeper we go
                nh = max(1, nh // down_ratio)
                nw = max(1, nw // down_ratio)
                do_split = curr_depth <= depth and h % nh == 0 and w % nw == 0 and nh * nw > 1
                try:
                    if do_split:
                        x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
                    out = forward(x, *args[1:], **kwargs)
                    if do_split:
                        out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                        out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
                except Exception as e:
                    if not error_reported:
                        error_reported = True
                        log.error(f'Hypertile apply: cls={layer.__class__} width={width} height={height} {e}')
                    # fall back to the untouched input: x may already have been split into tiles above,
                    # and running forward on it would return a tile-shaped result to the caller
                    out = forward(*args, **kwargs)
            return out
        return wrapper

    try: # hijack forward method and restore
        for name, module in layer.named_modules():
            if module.__class__.__qualname__ in ("Attention", "CrossAttention", "AttnBlock"):
                if name.endswith("attn2") or name.endswith("attn_2"): # skip cross-attention layers
                    continue
                setattr(module, "_original_forward", module.forward) # save original forward for recovery later # noqa: B010
                setattr(module, "forward", self_attn_forward(module.forward)) # noqa: B010
        yield
    finally:
        for _name, module in layer.named_modules():
            if hasattr(module, "_original_forward"): # remove hijack
                setattr(module, "forward", module._original_forward) # pylint: disable=protected-access # noqa: B010
                del module._original_forward
|
|
|
|
|
|
def context_hypertile_vae(p):
    """
    Build the context manager that applies HyperTile to the VAE of p.sd_model.

    Returns nullcontext() when HyperTile for VAE is disabled, no model is loaded, the
    Sub-quadratic cross-attention optimization is selected, the resolution is unknown or not
    divisible by 8, or the model exposes no `vae` attribute. Otherwise returns split_attention()
    for the VAE and records the tile size in p.extra_generation_params.

    Side effects: resets the module-level error_reported/max_h/max_w and refreshes height/width
    through set_resolution(p).
    """
    from modules import shared
    if p.sd_model is None or not shared.opts.hypertile_vae_enabled:
        return nullcontext()
    if shared.opts.cross_attention_optimization == 'Sub-quadratic':
        shared.log.warning('Hypertile VAE is not compatible with Sub-quadratic cross-attention optimization')
        return nullcontext()
    global max_h, max_w, error_reported # pylint: disable=global-statement
    error_reported = False
    set_resolution(p)
    max_h, max_w = 0, 0
    vae = getattr(p.sd_model, "vae", None)
    if height == 0 or width == 0:
        log.warning('Hypertile VAE disabled: resolution unknown')
        return nullcontext()
    if height % 8 != 0 or width % 8 != 0:
        log.warning(f'Hypertile VAE disabled: width={width} height={height} are not divisible by 8')
        return nullcontext()
    if vae is None:
        return nullcontext()
    # a non-positive option means "auto": scale the tile with the image, never below 128
    tile_size = shared.opts.hypertile_vae_tile if shared.opts.hypertile_vae_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
    # NOTE(review): the minimum tile is read from the unet option here as well - confirm this is intended
    min_tile_size = shared.opts.hypertile_unet_min_tile if shared.opts.hypertile_unet_min_tile > 0 else 128
    shared.log.info(f'Applying HyperTile: vae={min_tile_size}/{tile_size}')
    p.extra_generation_params['Hypertile VAE'] = tile_size
    return split_attention(vae, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=shared.opts.hypertile_vae_swap_size)
|
|
|
|
|
|
def context_hypertile_unet(p):
    """
    Build the context manager that applies HyperTile to the UNet of p.sd_model.

    Returns nullcontext() when HyperTile for UNet is disabled, no model is loaded, the
    Sub-quadratic cross-attention optimization is selected without the experimental flag, the
    resolution is unknown or not divisible by 8, or the model exposes no `unet` attribute.
    Otherwise returns split_attention() for the UNet and records the tile size in
    p.extra_generation_params.

    Side effects: resets the module-level error_reported/max_h/max_w and refreshes height/width
    through set_resolution(p).
    """
    from modules import shared
    if p.sd_model is None or not shared.opts.hypertile_unet_enabled:
        return nullcontext()
    if shared.opts.cross_attention_optimization == 'Sub-quadratic' and not shared.cmd_opts.experimental:
        shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
        return nullcontext()
    global max_h, max_w, error_reported # pylint: disable=global-statement
    error_reported = False
    set_resolution(p)
    max_h, max_w = 0, 0
    unet = getattr(p.sd_model, "unet", None)
    if height == 0 or width == 0:
        log.warning('Hypertile UNet disabled: resolution unknown')
        return nullcontext()
    if height % 8 != 0 or width % 8 != 0:
        log.warning(f'Hypertile UNet disabled: width={width} height={height} are not divisible by 8')
        return nullcontext()
    if unet is None:
        # shared.log.warning('Hypertile UNet is enabled but no Unet model was found')
        return nullcontext()
    # a non-positive option means "auto": scale the tile with the image, never below 128
    tile_size = shared.opts.hypertile_unet_tile if shared.opts.hypertile_unet_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
    min_tile_size = shared.opts.hypertile_unet_min_tile if shared.opts.hypertile_unet_min_tile > 0 else 128
    shared.log.info(f'Applying HyperTile: unet={min_tile_size}/{tile_size}')
    p.extra_generation_params['Hypertile UNet'] = tile_size
    return split_attention(unet, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=shared.opts.hypertile_unet_swap_size, depth=shared.opts.hypertile_unet_depth)
|
|
|
|
|
|
def hypertile_set(p, hr=False):
    """
    Re-arm the HyperTile state for the next pass of p.

    Does nothing unless HyperTile for UNet is enabled. Otherwise clears the error latch, refreshes
    the tracked resolution (the hires target when hr is set), decides whether tiling is skipped for
    this pass and asks the attention wrapper to recompute its tile candidates.
    """
    from modules import shared
    global error_reported, reset_needed, skip_hypertile # pylint: disable=global-statement
    if shared.opts.hypertile_unet_enabled:
        error_reported = False
        set_resolution(p, hr=hr)
        hires_only = shared.opts.hypertile_hires_only
        # with the hires-only option, tiling is skipped for every pass not flagged as hires
        skip_hypertile = hires_only and not getattr(p, 'is_hr_pass', False)
        reset_needed = True
|
|
|
|
|
|
def set_resolution(p, hr=False):
    """
    Store the resolution of the upcoming pass in the module-level width/height.

    Args:
        p: processing object providing width/height and optionally hr_upscale_to_x,
           hr_upscale_to_y and init_images
        hr: when true, prefer the hires targets hr_upscale_to_x (width) and hr_upscale_to_y
            (height), falling back to p.width/p.height for a target that is missing or not positive

    If either dimension ends up as 0, the size of the first entry of p.init_images is used instead.
    """
    global height, width # pylint: disable=global-statement
    if hr:
        x = getattr(p, 'hr_upscale_to_x', 0)
        y = getattr(p, 'hr_upscale_to_y', 0)
        # x is the horizontal target and y the vertical one; they were previously assigned crosswise
        width = x if x > 0 else p.width
        height = y if y > 0 else p.height
    else:
        width = p.width
        height = p.height
    if height == 0 or width == 0:
        if hasattr(p, 'init_images') and isinstance(p.init_images, list) and len(p.init_images) > 0:
            # assumes init_images holds PIL images, whose .size is (width, height) - TODO confirm
            width, height = p.init_images[0].size
|