mirror of https://github.com/vladmandic/automatic
merge: modules/sd_hijack_hypertile.py
parent
cb224651a0
commit
35803746df
|
|
@ -2,7 +2,7 @@
|
|||
# based on: https://github.com/tfernd/HyperTile/tree/main/hyper_tile/utils.py + https://github.com/tfernd/HyperTile/tree/main/hyper_tile/hyper_tile.py
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Callable
|
||||
from functools import wraps, cache
|
||||
from contextlib import contextmanager, nullcontext
|
||||
import random
|
||||
|
|
@ -10,9 +10,15 @@ import math
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from modules.logger import log
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from installer import log
|
||||
|
||||
|
||||
def _p_or_opt(p, key):
|
||||
val = getattr(p, key, None)
|
||||
if val is not None:
|
||||
return val
|
||||
from modules import shared
|
||||
return getattr(shared.opts, key, None)
|
||||
|
||||
|
||||
# global variables to keep track of changing image size in multiple passes
|
||||
|
|
@ -178,17 +184,17 @@ def split_attention(layer: nn.Module, tile_size: int=256, min_tile_size: int=128
|
|||
|
||||
def context_hypertile_vae(p):
|
||||
from modules import shared
|
||||
if shared.sd_model is None or not shared.opts.hypertile_vae_enabled:
|
||||
if p.sd_model is None or not _p_or_opt(p, 'hypertile_vae_enabled'):
|
||||
return nullcontext()
|
||||
if shared.opts.cross_attention_optimization == 'Sub-quadratic':
|
||||
log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
|
||||
shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
|
||||
return nullcontext()
|
||||
global max_h, max_w, error_reported # pylint: disable=global-statement
|
||||
error_reported = False
|
||||
error_reported = False
|
||||
set_resolution(p)
|
||||
max_h, max_w = 0, 0
|
||||
vae = getattr(shared.sd_model, "vae", None)
|
||||
vae = getattr(p.sd_model, "vae", None)
|
||||
if height == 0 or width == 0:
|
||||
log.warning('Hypertile VAE disabled: resolution unknown')
|
||||
return nullcontext()
|
||||
|
|
@ -198,25 +204,27 @@ def context_hypertile_vae(p):
|
|||
if vae is None:
|
||||
return nullcontext()
|
||||
else:
|
||||
tile_size = shared.opts.hypertile_vae_tile if shared.opts.hypertile_vae_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
|
||||
min_tile_size = shared.opts.hypertile_unet_min_tile if shared.opts.hypertile_unet_min_tile > 0 else 128
|
||||
log.info(f'Applying HyperTile: vae={min_tile_size}/{tile_size}')
|
||||
_vae_tile = _p_or_opt(p, 'hypertile_vae_tile')
|
||||
tile_size = _vae_tile if _vae_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
|
||||
_min_tile = _p_or_opt(p, 'hypertile_unet_min_tile')
|
||||
min_tile_size = _min_tile if _min_tile > 0 else 128
|
||||
shared.log.info(f'Applying HyperTile: vae={min_tile_size}/{tile_size}')
|
||||
p.extra_generation_params['Hypertile VAE'] = tile_size
|
||||
return split_attention(vae, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=shared.opts.hypertile_vae_swap_size)
|
||||
return split_attention(vae, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=_p_or_opt(p, 'hypertile_vae_swap_size'))
|
||||
|
||||
|
||||
def context_hypertile_unet(p):
|
||||
from modules import shared
|
||||
if shared.sd_model is None or not shared.opts.hypertile_unet_enabled:
|
||||
if p.sd_model is None or not _p_or_opt(p, 'hypertile_unet_enabled'):
|
||||
return nullcontext()
|
||||
if shared.opts.cross_attention_optimization == 'Sub-quadratic' and not shared.cmd_opts.experimental:
|
||||
log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
|
||||
shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
|
||||
return nullcontext()
|
||||
global max_h, max_w, error_reported # pylint: disable=global-statement
|
||||
error_reported = False
|
||||
set_resolution(p)
|
||||
max_h, max_w = 0, 0
|
||||
unet = getattr(shared.sd_model, "unet", None)
|
||||
unet = getattr(p.sd_model, "unet", None)
|
||||
if height == 0 or width == 0:
|
||||
log.warning('Hypertile VAE disabled: resolution unknown')
|
||||
return nullcontext()
|
||||
|
|
@ -224,24 +232,25 @@ def context_hypertile_unet(p):
|
|||
log.warning(f'Hypertile UNet disabled: width={width} height={height} are not divisible by 8')
|
||||
return nullcontext()
|
||||
if unet is None:
|
||||
# log.warning('Hypertile UNet is enabled but no Unet model was found')
|
||||
# shared.log.warning('Hypertile UNet is enabled but no Unet model was found')
|
||||
return nullcontext()
|
||||
else:
|
||||
tile_size = shared.opts.hypertile_unet_tile if shared.opts.hypertile_unet_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
|
||||
min_tile_size = shared.opts.hypertile_unet_min_tile if shared.opts.hypertile_unet_min_tile > 0 else 128
|
||||
log.info(f'Applying HyperTile: unet={min_tile_size}/{tile_size}')
|
||||
_unet_tile = _p_or_opt(p, 'hypertile_unet_tile')
|
||||
tile_size = _unet_tile if _unet_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
|
||||
_min_tile = _p_or_opt(p, 'hypertile_unet_min_tile')
|
||||
min_tile_size = _min_tile if _min_tile > 0 else 128
|
||||
shared.log.info(f'Applying HyperTile: unet={min_tile_size}/{tile_size}')
|
||||
p.extra_generation_params['Hypertile UNet'] = tile_size
|
||||
return split_attention(unet, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=shared.opts.hypertile_unet_swap_size, depth=shared.opts.hypertile_unet_depth)
|
||||
return split_attention(unet, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=_p_or_opt(p, 'hypertile_unet_swap_size'), depth=_p_or_opt(p, 'hypertile_unet_depth'))
|
||||
|
||||
|
||||
def hypertile_set(p, hr=False):
|
||||
from modules import shared
|
||||
global error_reported, reset_needed, skip_hypertile # pylint: disable=global-statement
|
||||
if not shared.opts.hypertile_unet_enabled:
|
||||
if not _p_or_opt(p, 'hypertile_unet_enabled'):
|
||||
return
|
||||
error_reported = False
|
||||
set_resolution(p, hr=hr)
|
||||
skip_hypertile = shared.opts.hypertile_hires_only and not getattr(p, 'is_hr_pass', False)
|
||||
skip_hypertile = _p_or_opt(p, 'hypertile_hires_only') and not getattr(p, 'is_hr_pass', False)
|
||||
reset_needed = True
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue