diff --git a/modules/sd_hijack_hypertile.py b/modules/sd_hijack_hypertile.py index 6bf01d5ac..4f67d6c10 100644 --- a/modules/sd_hijack_hypertile.py +++ b/modules/sd_hijack_hypertile.py @@ -2,7 +2,7 @@ # based on: https://github.com/tfernd/HyperTile/tree/main/hyper_tile/utils.py + https://github.com/tfernd/HyperTile/tree/main/hyper_tile/hyper_tile.py from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Callable from functools import wraps, cache from contextlib import contextmanager, nullcontext import random @@ -10,9 +10,15 @@ import math import torch import torch.nn as nn from einops import rearrange -from modules.logger import log -if TYPE_CHECKING: - from collections.abc import Callable +from installer import log + + +def _p_or_opt(p, key): + val = getattr(p, key, None) + if val is not None: + return val + from modules import shared + return getattr(shared.opts, key, None) # global variables to keep track of changing image size in multiple passes @@ -178,17 +184,17 @@ def split_attention(layer: nn.Module, tile_size: int=256, min_tile_size: int=128 def context_hypertile_vae(p): from modules import shared - if shared.sd_model is None or not shared.opts.hypertile_vae_enabled: + if p.sd_model is None or not _p_or_opt(p, 'hypertile_vae_enabled'): return nullcontext() if shared.opts.cross_attention_optimization == 'Sub-quadratic': - log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization') + shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization') return nullcontext() global max_h, max_w, error_reported # pylint: disable=global-statement error_reported = False error_reported = False set_resolution(p) max_h, max_w = 0, 0 - vae = getattr(shared.sd_model, "vae", None) + vae = getattr(p.sd_model, "vae", None) if height == 0 or width == 0: log.warning('Hypertile VAE disabled: resolution unknown') return nullcontext() @@ -198,25 +204,27 @@ def context_hypertile_vae(p): if vae is None: return nullcontext() else: - tile_size = shared.opts.hypertile_vae_tile if shared.opts.hypertile_vae_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128)) - min_tile_size = shared.opts.hypertile_unet_min_tile if shared.opts.hypertile_unet_min_tile > 0 else 128 - log.info(f'Applying HyperTile: vae={min_tile_size}/{tile_size}') + _vae_tile = _p_or_opt(p, 'hypertile_vae_tile') + tile_size = _vae_tile if _vae_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128)) + _min_tile = _p_or_opt(p, 'hypertile_unet_min_tile') + min_tile_size = _min_tile if _min_tile > 0 else 128 + shared.log.info(f'Applying HyperTile: vae={min_tile_size}/{tile_size}') p.extra_generation_params['Hypertile VAE'] = tile_size - return split_attention(vae, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=shared.opts.hypertile_vae_swap_size) + return split_attention(vae, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=_p_or_opt(p, 'hypertile_vae_swap_size')) def context_hypertile_unet(p): from modules import shared - if shared.sd_model is None or not shared.opts.hypertile_unet_enabled: + if p.sd_model is None or not _p_or_opt(p, 'hypertile_unet_enabled'): return nullcontext() if shared.opts.cross_attention_optimization == 'Sub-quadratic' and not shared.cmd_opts.experimental: - log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization') + shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization') return nullcontext() global max_h, max_w, error_reported # pylint: disable=global-statement error_reported = False set_resolution(p) max_h, max_w = 0, 0 - unet = getattr(shared.sd_model, "unet", None) + unet = getattr(p.sd_model, "unet", None) if height == 0 or width == 0: log.warning('Hypertile VAE disabled: resolution unknown') return nullcontext() @@ -224,24 +232,25 @@ def context_hypertile_unet(p): log.warning(f'Hypertile UNet disabled: width={width} height={height} are not divisible by 8') return nullcontext() if unet is None: - # log.warning('Hypertile UNet is enabled but no Unet model was found') + # shared.log.warning('Hypertile UNet is enabled but no Unet model was found') return nullcontext() else: - tile_size = shared.opts.hypertile_unet_tile if shared.opts.hypertile_unet_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128)) - min_tile_size = shared.opts.hypertile_unet_min_tile if shared.opts.hypertile_unet_min_tile > 0 else 128 - log.info(f'Applying HyperTile: unet={min_tile_size}/{tile_size}') + _unet_tile = _p_or_opt(p, 'hypertile_unet_tile') + tile_size = _unet_tile if _unet_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128)) + _min_tile = _p_or_opt(p, 'hypertile_unet_min_tile') + min_tile_size = _min_tile if _min_tile > 0 else 128 + shared.log.info(f'Applying HyperTile: unet={min_tile_size}/{tile_size}') p.extra_generation_params['Hypertile UNet'] = tile_size - return split_attention(unet, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=shared.opts.hypertile_unet_swap_size, depth=shared.opts.hypertile_unet_depth) + return split_attention(unet, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=_p_or_opt(p, 'hypertile_unet_swap_size'), depth=_p_or_opt(p, 'hypertile_unet_depth')) def hypertile_set(p, hr=False): - from modules import shared global error_reported, reset_needed, skip_hypertile # pylint: disable=global-statement - if not shared.opts.hypertile_unet_enabled: + if not _p_or_opt(p, 'hypertile_unet_enabled'): return error_reported = False set_resolution(p, hr=hr) - skip_hypertile = shared.opts.hypertile_hires_only and not getattr(p, 'is_hr_pass', False) + skip_hypertile = _p_or_opt(p, 'hypertile_hires_only') and not getattr(p, 'is_hr_pass', False) reset_needed = True