automatic/modules/sd_hijack_hypertile.py

259 lines
11 KiB
Python

# credits: @tfernd https://github.com/tfernd/HyperTile
# based on: https://github.com/tfernd/HyperTile/tree/main/hyper_tile/utils.py + https://github.com/tfernd/HyperTile/tree/main/hyper_tile/hyper_tile.py
from __future__ import annotations
from typing import Callable
from functools import wraps, cache
from contextlib import contextmanager, nullcontext
import random
import math
import torch
import torch.nn as nn
from einops import rearrange
from installer import log
# module-level state shared by the context builders and the attention wrappers,
# used to follow the image size when it changes between passes (e.g. base pass then hires pass)
height = None # target image height in pixels, assigned by set_resolution()
width = None # target image width in pixels, assigned by set_resolution()
max_h, max_w = 0, 0 # running maxima used by the UNet wrapper to detect a resolution increase
error_reported = False # set once the first hypertile error is logged so later ones are suppressed
reset_needed = False # when True the next UNet attention call recomputes the tile grids
skip_hypertile = False # when True the attention wrappers call the original forward unchanged
def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
    """
    Finds h and w such that h*w = hw and h/w is as close as possible to aspect_ratio
    Every divisor pair of hw (including 1 and hw itself) is checked; on a tie the pair with the smaller h wins
    Raises ValueError if hw is not a positive integer
    """
    if hw < 1:
        raise ValueError(f'hw must be a positive integer, got {hw}')
    pairs = [(h, hw // h) for h in range(1, hw + 1) if hw % h == 0] # all (h, w) with h*w == hw, ascending h
    # compare h/w (not w/h) so the result agrees with the docstring and with find_hw_candidates
    return min(pairs, key=lambda pair: abs(pair[0] / pair[1] - aspect_ratio))
@cache
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
    """
    Finds h and w such that h*w = hw and h/w is approximately aspect_ratio
    Starts from the rounded real-valued solution; if that does not multiply back to hw it keeps h and
    derives w, or keeps w and derives h, whichever division is exact, and otherwise falls back to
    iterative_closest_divisors. Results are memoized per (hw, aspect_ratio).
    """
    # clamp to 1 so extreme aspect ratios cannot round a side to 0 and divide by zero below
    h = max(1, round(math.sqrt(hw * aspect_ratio)))
    w = max(1, round(math.sqrt(hw / aspect_ratio)))
    if h * w == hw:
        return h, w
    w_candidate = hw / h
    if w_candidate.is_integer(): # keep h, exact w exists
        return h, int(w_candidate)
    h_candidate = hw / w
    if h_candidate.is_integer(): # keep w, exact h exists
        return int(h_candidate), w
    return iterative_closest_divisors(hw, aspect_ratio)
def possible_tile_sizes(dimension: int, tile_size: int, min_tile_size: int, tile_options: int) -> list[int]:
    """
    Returns candidate tile *counts* n for splitting `dimension` into n equal tiles
    A count qualifies when dimension // n is exact, a multiple of 8 and at least min_tile_size
    (min_tile_size is first clamped to not exceed tile_size or dimension). Candidates are ordered by how
    close their tile size is to tile_size (ties keep ascending n order) and at most tile_options are returned.
    An empty list is returned when no count qualifies (e.g. dimension not divisible by 8).
    Raises ValueError if tile_options < 1
    """
    if tile_options < 1:
        raise ValueError(f'tile_options must be >= 1, got {tile_options}')
    min_tile_size = min(min_tile_size, tile_size, dimension)
    n = torch.arange(1, dimension + 1)
    sizes = dimension // n
    # exact divisors whose tile size is a multiple of 8 and not below the minimum
    keep = (sizes * n == dimension) & (sizes % 8 == 0) & (sizes >= min_tile_size)
    n, sizes = n[keep], sizes[keep]
    # stable sort so equally distant candidates come back in a deterministic order
    _, pos = torch.sort((sizes - tile_size).abs(), stable=True)
    return n[pos[:tile_options]].tolist()
def parse_list(x: list[int], /) -> str:
    """
    Formats a list for display: a single element is shown bare, anything else as the list repr
    e.g. [4] -> '4', [2, 4] -> '[2, 4]', [] -> '[]'
    """
    if len(x) == 1: # was `== 0`, which indexed an empty list and always raised IndexError
        return str(x[0])
    return str(x)
@contextmanager
def split_attention(layer: nn.Module, tile_size: int=256, min_tile_size: int=128, swap_size: int=1, depth: int=0):
    """
    Context manager that temporarily replaces the forward of self-attention submodules of `layer`
    with a wrapper that runs attention on a grid of spatial tiles instead of the whole image
    Hijacked: submodules whose class qualname is Attention, CrossAttention or AttnBlock, except those whose
    name ends with attn2 / attn_2 (cross-attention). Original forwards are restored on exit.
    Reads the module-level height/width (pixels), which must be set (see set_resolution) before entering.
    tile_size / min_tile_size: target and minimum tile size passed to possible_tile_sizes
    swap_size: number of candidate grids per axis; one is picked at random on every call
    depth: UNet calls are tiled only while round(log2(downscale ratio)) <= depth
    """
    # hijacks AttnBlock from ldm and attention from diffusers
    global reset_needed # pylint: disable=global-statement
    ar = height / width # Aspect ratio
    reset_needed = True # the first UNet attention call re-derives the tile grids
    nhs = possible_tile_sizes(height, tile_size, min_tile_size, swap_size) # possible sub-grids that fit into the image
    nws = possible_tile_sizes(width, tile_size, min_tile_size, swap_size)

    def reset_nhs():
        # recompute aspect ratio and vertical tile counts from the current global height/width
        nonlocal nhs, ar
        ar = height / width # Aspect ratio
        nhs = possible_tile_sizes(height, tile_size, min_tile_size, swap_size)

    def reset_nws():
        # recompute aspect ratio and horizontal tile counts from the current global height/width
        nonlocal nws, ar
        ar = height / width # Aspect ratio
        nws = possible_tile_sizes(width, tile_size, min_tile_size, swap_size)

    def self_attn_forward(forward: Callable) -> Callable:
        # wrap one attention forward so its input is tiled before the call and untiled after it
        @wraps(forward)
        def wrapper(*args, **kwargs):
            global height, width, max_h, max_w, reset_needed, error_reported # pylint: disable=global-statement
            if skip_hypertile:
                return forward(*args, **kwargs)
            x = args[0]
            try:
                # pick one of the candidate grids at random for this call
                nh = nhs[random.randint(0, len(nhs) - 1)]
                nw = nws[random.randint(0, len(nws) - 1)]
            except Exception as e:
                # e.g. an empty candidate list makes randint(0, -1) raise; run untiled instead
                if not error_reported:
                    error_reported = True
                    log.error(f'Hypertile calculate: width={width} height={height} {e}')
                out = forward(x, *args[1:], **kwargs)
                return out
            if x.ndim == 4: # VAE
                # TODO hypertile: vae breaks when using non-standard sizes
                if nh * nw > 1:
                    x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw)
                out = forward(x, *args[1:], **kwargs)
                if nh * nw > 1:
                    out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
            else: # Unet
                hw = x.size(1)
                h, w = round(math.sqrt(ar * hw)), round(math.sqrt(hw / ar))
                # h, w = find_hw_candidates(hw, ar)
                # dynamic height/width based on fact that first two forward calls contain actual height/width
                # and reset if latest hw is larger since we're never downscaling in 2nd pass
                if reset_needed:
                    reset_nhs()
                    reset_nws()
                    # NOTE(review): max_h/max_w are set from pixel height/width here but compared against
                    # latent h/w below, so the growth check rarely triggers; confirm whether height // 8 was intended
                    max_h = height
                    max_w = width
                    reset_needed = False
                else:
                    if h > max_h:
                        height = 8 * h
                        max_h = max(max_h, h)
                        reset_nhs()
                    if w > max_w:
                        width = 8 * w
                        max_w = max(max_w, w)
                        reset_nws()
                down_ratio = max(height // 8 // h, 1)
                curr_depth = round(math.log(down_ratio, 2))
                # scale-up the tile-size the deeper we go
                nh = max(1, nh // down_ratio)
                nw = max(1, nw // down_ratio)
                do_split = curr_depth <= depth and h % nh == 0 and w % nw == 0 and nh * nw > 1
                try:
                    # tile into a separate variable so the fallback below still receives the untiled input
                    x_in = x
                    if do_split:
                        x_in = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
                    out = forward(x_in, *args[1:], **kwargs)
                    if do_split:
                        out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                        out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
                except Exception as e:
                    if not error_reported:
                        error_reported = True
                        log.error(f'Hypertile apply: cls={layer.__class__} width={width} height={height} {e}')
                    out = forward(x, *args[1:], **kwargs) # x is the original, untiled input
            return out
        return wrapper

    try: # hijack forward method and restore
        for name, module in layer.named_modules():
            if module.__class__.__qualname__ in ("Attention", "CrossAttention", "AttnBlock"):
                if name.endswith("attn2") or name.endswith("attn_2"): # skip cross-attention layers
                    continue
                setattr(module, "_original_forward", module.forward) # save original forward for recovery later # noqa: B010
                setattr(module, "forward", self_attn_forward(module.forward)) # noqa: B010
        yield
    finally:
        for _name, module in layer.named_modules():
            if hasattr(module, "_original_forward"): # remove hijack
                setattr(module, "forward", module._original_forward) # pylint: disable=protected-access # noqa: B010
                del module._original_forward
def context_hypertile_vae(p):
    """
    Builds the context manager that applies HyperTile to the VAE of p.sd_model
    Returns nullcontext() (no-op) when there is no model, VAE hypertile is disabled, Sub-quadratic
    cross-attention is selected, the resolution is unknown or not divisible by 8, or the model has no vae;
    otherwise returns split_attention() configured from shared.opts and records the tile size in
    p.extra_generation_params['Hypertile VAE']
    Side effects: resets error_reported and max_h/max_w and refreshes the global height/width via set_resolution(p)
    """
    from modules import shared
    global max_h, max_w, error_reported # pylint: disable=global-statement
    if p.sd_model is None or not shared.opts.hypertile_vae_enabled:
        return nullcontext()
    if shared.opts.cross_attention_optimization == 'Sub-quadratic':
        shared.log.warning('Hypertile VAE is not compatible with Sub-quadratic cross-attention optimization')
        return nullcontext()
    error_reported = False
    set_resolution(p)
    max_h, max_w = 0, 0
    vae = getattr(p.sd_model, "vae", None)
    if height == 0 or width == 0:
        log.warning('Hypertile VAE disabled: resolution unknown')
        return nullcontext()
    if height % 8 != 0 or width % 8 != 0:
        log.warning(f'Hypertile VAE disabled: width={width} height={height} are not divisible by 8')
        return nullcontext()
    if vae is None:
        return nullcontext()
    tile_size = shared.opts.hypertile_vae_tile if shared.opts.hypertile_vae_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
    # NOTE(review): the VAE reuses the UNet minimum-tile option; confirm whether a VAE-specific option was intended
    min_tile_size = shared.opts.hypertile_unet_min_tile if shared.opts.hypertile_unet_min_tile > 0 else 128
    shared.log.info(f'Applying HyperTile: vae={min_tile_size}/{tile_size}')
    p.extra_generation_params['Hypertile VAE'] = tile_size
    return split_attention(vae, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=shared.opts.hypertile_vae_swap_size)
def context_hypertile_unet(p):
    """
    Builds the context manager that applies HyperTile to the UNet of p.sd_model
    Returns nullcontext() (no-op) when there is no model, UNet hypertile is disabled, Sub-quadratic
    cross-attention is selected without --experimental, the resolution is unknown or not divisible by 8,
    or the model has no unet; otherwise returns split_attention() configured from shared.opts
    (including hypertile_unet_depth) and records the tile size in p.extra_generation_params['Hypertile UNet']
    Side effects: resets error_reported and max_h/max_w and refreshes the global height/width via set_resolution(p)
    """
    from modules import shared
    global max_h, max_w, error_reported # pylint: disable=global-statement
    if p.sd_model is None or not shared.opts.hypertile_unet_enabled:
        return nullcontext()
    if shared.opts.cross_attention_optimization == 'Sub-quadratic' and not shared.cmd_opts.experimental:
        shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
        return nullcontext()
    error_reported = False
    set_resolution(p)
    max_h, max_w = 0, 0
    unet = getattr(p.sd_model, "unet", None)
    if height == 0 or width == 0:
        log.warning('Hypertile UNet disabled: resolution unknown')
        return nullcontext()
    if height % 8 != 0 or width % 8 != 0:
        log.warning(f'Hypertile UNet disabled: width={width} height={height} are not divisible by 8')
        return nullcontext()
    if unet is None:
        # shared.log.warning('Hypertile UNet is enabled but no Unet model was found')
        return nullcontext()
    tile_size = shared.opts.hypertile_unet_tile if shared.opts.hypertile_unet_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
    min_tile_size = shared.opts.hypertile_unet_min_tile if shared.opts.hypertile_unet_min_tile > 0 else 128
    shared.log.info(f'Applying HyperTile: unet={min_tile_size}/{tile_size}')
    p.extra_generation_params['Hypertile UNet'] = tile_size
    return split_attention(unet, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=shared.opts.hypertile_unet_swap_size, depth=shared.opts.hypertile_unet_depth)
def hypertile_set(p, hr=False):
    """
    Prepares the module state for the next (base or hires) UNet pass
    Does nothing unless UNet hypertile is enabled. Otherwise it clears the error flag, refreshes the
    target resolution from p (hires target when hr=True), decides whether this pass bypasses hypertile
    (hires-only mode outside a hires pass) and asks the wrappers to rebuild their tile grids.
    """
    from modules import shared
    global error_reported, reset_needed, skip_hypertile # pylint: disable=global-statement
    if shared.opts.hypertile_unet_enabled:
        error_reported = False
        set_resolution(p, hr=hr)
        skip_hypertile = shared.opts.hypertile_hires_only and not getattr(p, 'is_hr_pass', False)
        reset_needed = True
def set_resolution(p, hr=False):
    """
    Stores the target image size of p (pixels) in the module-level width/height
    hr=True uses the hires upscale target (hr_upscale_to_x is the width, hr_upscale_to_y the height),
    falling back to p.width/p.height when a target is unset, None or 0.
    If either side ends up 0, the size of the first init image is used; PIL Image.size is (width, height).
    """
    global height, width # pylint: disable=global-statement
    if hr:
        x = getattr(p, 'hr_upscale_to_x', 0) or 0
        y = getattr(p, 'hr_upscale_to_y', 0) or 0
        # previously swapped (width took y, height took x), which inverted non-square hires passes
        width = x if x > 0 else p.width
        height = y if y > 0 else p.height
    else:
        width = p.width
        height = p.height
    if height == 0 or width == 0:
        init_images = getattr(p, 'init_images', None)
        if isinstance(init_images, list) and len(init_images) > 0:
            width, height = init_images[0].size # PIL order is (width, height)