merge: modules/sd_hijack_hypertile.py

pull/4690/head
vladmandic 2026-03-13 14:36:25 +01:00
parent cb224651a0
commit 35803746df
1 changed file with 31 additions and 22 deletions

View File

@ -2,7 +2,7 @@
# based on: https://github.com/tfernd/HyperTile/tree/main/hyper_tile/utils.py + https://github.com/tfernd/HyperTile/tree/main/hyper_tile/hyper_tile.py
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Callable
from functools import wraps, cache
from contextlib import contextmanager, nullcontext
import random
@ -10,9 +10,15 @@ import math
import torch
import torch.nn as nn
from einops import rearrange
from modules.logger import log
if TYPE_CHECKING:
from collections.abc import Callable
from installer import log
def _p_or_opt(p, key):
val = getattr(p, key, None)
if val is not None:
return val
from modules import shared
return getattr(shared.opts, key, None)
# global variables to keep track of changing image size in multiple passes
@ -178,17 +184,17 @@ def split_attention(layer: nn.Module, tile_size: int=256, min_tile_size: int=128
# NOTE(review): this span is a rendered diff, not clean Python — removed (old)
# and added (new) lines are interleaved with no +/- markers, and the `@ ... @`
# hunk header below elides several original lines. Do not run as-is.
# Purpose (from the visible code): context manager that enables HyperTile on
# the VAE, returning nullcontext() when disabled or preconditions fail.
def context_hypertile_vae(p):
from modules import shared
# old line, superseded by the next (settings now read via _p_or_opt / p.sd_model):
if shared.sd_model is None or not shared.opts.hypertile_vae_enabled:
if p.sd_model is None or not _p_or_opt(p, 'hypertile_vae_enabled'):
return nullcontext()
if shared.opts.cross_attention_optimization == 'Sub-quadratic':
# NOTE(review): message says 'Hypertile UNet' inside the VAE helper —
# looks like a copy/paste from context_hypertile_unet; confirm intended text.
# old line, superseded by the next (log -> shared.log):
log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
return nullcontext()
global max_h, max_w, error_reported # pylint: disable=global-statement
# duplicate pair from the diff (old line removed, identical new line added):
error_reported = False
error_reported = False
set_resolution(p)
max_h, max_w = 0, 0
# old line, superseded by the next (shared.sd_model -> p.sd_model):
vae = getattr(shared.sd_model, "vae", None)
vae = getattr(p.sd_model, "vae", None)
if height == 0 or width == 0:
# NOTE(review): unqualified `log` here vs `shared.log` in the new lines — verify both resolve.
log.warning('Hypertile VAE disabled: resolution unknown')
return nullcontext()
# hunk boundary: original lines between here and `if vae is None:` are elided from view.
@ -198,25 +204,27 @@ def context_hypertile_vae(p):
if vae is None:
return nullcontext()
else:
# old lines (L58-L60 of this rendering), superseded by the _p_or_opt block below:
tile_size = shared.opts.hypertile_vae_tile if shared.opts.hypertile_vae_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
min_tile_size = shared.opts.hypertile_unet_min_tile if shared.opts.hypertile_unet_min_tile > 0 else 128
log.info(f'Applying HyperTile: vae={min_tile_size}/{tile_size}')
# new lines: per-run override first, then global option; 0 means "auto" size
_vae_tile = _p_or_opt(p, 'hypertile_vae_tile')
tile_size = _vae_tile if _vae_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
# NOTE(review): the VAE path reads the `hypertile_unet_min_tile` key (unet, not vae) —
# present in both old and new versions; confirm this cross-read is intentional.
_min_tile = _p_or_opt(p, 'hypertile_unet_min_tile')
min_tile_size = _min_tile if _min_tile > 0 else 128
shared.log.info(f'Applying HyperTile: vae={min_tile_size}/{tile_size}')
p.extra_generation_params['Hypertile VAE'] = tile_size
# old return, superseded by the next (swap_size now via _p_or_opt):
return split_attention(vae, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=shared.opts.hypertile_vae_swap_size)
return split_attention(vae, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=_p_or_opt(p, 'hypertile_vae_swap_size'))
# NOTE(review): diff rendering — old/new line pairs interleaved, and the
# `@ ... @` hunk header below elides original lines (including the condition
# that guards the 'not divisible by 8' warning). Do not run as-is.
# Purpose (from the visible code): context manager that enables HyperTile on
# the UNet, returning nullcontext() when disabled or preconditions fail.
def context_hypertile_unet(p):
from modules import shared
# old line, superseded by the next (settings now read via _p_or_opt / p.sd_model):
if shared.sd_model is None or not shared.opts.hypertile_unet_enabled:
if p.sd_model is None or not _p_or_opt(p, 'hypertile_unet_enabled'):
return nullcontext()
# unlike the VAE helper, the Sub-quadratic check can be bypassed with cmd_opts.experimental
if shared.opts.cross_attention_optimization == 'Sub-quadratic' and not shared.cmd_opts.experimental:
# old line, superseded by the next (log -> shared.log):
log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
shared.log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
return nullcontext()
global max_h, max_w, error_reported # pylint: disable=global-statement
error_reported = False
set_resolution(p)
max_h, max_w = 0, 0
# old line, superseded by the next (shared.sd_model -> p.sd_model):
unet = getattr(shared.sd_model, "unet", None)
unet = getattr(p.sd_model, "unet", None)
if height == 0 or width == 0:
# NOTE(review): message says 'Hypertile VAE disabled' inside the UNet helper —
# likely should read 'Hypertile UNet disabled'; confirm intended text.
log.warning('Hypertile VAE disabled: resolution unknown')
return nullcontext()
# hunk boundary: the `if` guarding the next warning is elided from view.
@ -224,24 +232,25 @@ def context_hypertile_unet(p):
log.warning(f'Hypertile UNet disabled: width={width} height={height} are not divisible by 8')
return nullcontext()
if unet is None:
# old commented-out line, superseded by the next commented-out line:
# log.warning('Hypertile UNet is enabled but no Unet model was found')
# shared.log.warning('Hypertile UNet is enabled but no Unet model was found')
return nullcontext()
else:
# old lines (L95-L97 of this rendering), superseded by the _p_or_opt block below:
tile_size = shared.opts.hypertile_unet_tile if shared.opts.hypertile_unet_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
min_tile_size = shared.opts.hypertile_unet_min_tile if shared.opts.hypertile_unet_min_tile > 0 else 128
log.info(f'Applying HyperTile: unet={min_tile_size}/{tile_size}')
# new lines: per-run override first, then global option; 0 means "auto" size
_unet_tile = _p_or_opt(p, 'hypertile_unet_tile')
tile_size = _unet_tile if _unet_tile > 0 else max(128, 64 * min(p.width // 128, p.height // 128))
_min_tile = _p_or_opt(p, 'hypertile_unet_min_tile')
min_tile_size = _min_tile if _min_tile > 0 else 128
shared.log.info(f'Applying HyperTile: unet={min_tile_size}/{tile_size}')
p.extra_generation_params['Hypertile UNet'] = tile_size
# old return, superseded by the next (swap_size/depth now via _p_or_opt):
return split_attention(unet, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=shared.opts.hypertile_unet_swap_size, depth=shared.opts.hypertile_unet_depth)
return split_attention(unet, tile_size=tile_size, min_tile_size=min_tile_size, swap_size=_p_or_opt(p, 'hypertile_unet_swap_size'), depth=_p_or_opt(p, 'hypertile_unet_depth'))
def hypertile_set(p, hr=False):
    """Record hypertile state for the current generation pass.

    Reconstructed post-diff version: the rendered diff left the superseded
    old lines (reading shared.opts directly) interleaved with the new
    `_p_or_opt`-based lines; this keeps only the new behavior, with the
    indentation the rendering stripped.

    Args:
        p: processing object; settings are resolved via _p_or_opt (per-run
           override on `p`, falling back to shared.opts).
        hr: True when invoked for the hires pass (forwarded to set_resolution).

    Side effects only (returns None): mutates the module-level globals
    error_reported, reset_needed, and skip_hypertile.
    """
    from modules import shared  # noqa: F401 — kept from the original; settings now flow through _p_or_opt
    global error_reported, reset_needed, skip_hypertile  # pylint: disable=global-statement
    if not _p_or_opt(p, 'hypertile_unet_enabled'):
        return
    error_reported = False
    set_resolution(p, hr=hr)
    # hires-only mode: skip hypertile on any pass that is not the hires pass
    skip_hypertile = _p_or_opt(p, 'hypertile_hires_only') and not getattr(p, 'is_hr_pass', False)
    reset_needed = True