automatic/modules/token_merge.py

90 lines
3.1 KiB
Python

from modules.logger import log
from modules import shared
def apply_token_merging(sd_model, p=None):
    """Apply the configured token-merging patch (ToMe or ToDo) to ``sd_model``.

    Each setting is taken from ``p`` when ``p`` carries a non-``None``
    attribute of that name, otherwise from ``shared.opts``. The ratio that was
    successfully applied is recorded on the model as ``applied_tome`` /
    ``applied_todo``; the attribute is reset to 0 when the corresponding
    method is not selected or its ratio is not positive.

    The function returns immediately (skipping everything after that point)
    when the requested ratio is already recorded on the model, or when
    HyperTile for UNet is enabled and ``shared.cmd_opts.experimental`` is not
    set. Failures while patching are logged as warnings and never raised.

    Args:
        sd_model: object to patch; ToDo patches ``sd_model.unet``.
        p: optional object whose attributes override ``shared.opts``.
    """

    def _opt(key):
        # A per-call override on `p` wins; None there means "not set".
        if p is not None:
            val = getattr(p, key, None)
            if val is not None:
                return val
        return getattr(shared.opts, key, None)

    current_tome = getattr(sd_model, 'applied_tome', 0)
    current_todo = getattr(sd_model, 'applied_todo', 0)
    method = _opt('token_merging_method')
    # An option missing from both sources comes back as None; treat that as
    # ratio 0 (disabled) so the `> 0` checks below cannot raise TypeError.
    tome = _opt('tome_ratio') or 0
    todo = _opt('todo_ratio') or 0
    hypertile = _opt('hypertile_unet_enabled')

    if method == 'ToMe' and tome > 0:
        if current_tome == tome:
            return  # this ratio is already recorded as applied
        if hypertile and not shared.cmd_opts.experimental:
            log.warning('Token merging not supported with HyperTile for UNet')
            return
        try:
            # Imported lazily so the dependency is only installed/loaded when
            # ToMe is actually requested.
            import installer
            installer.install('tomesd', 'tomesd', ignore=False)
            import tomesd
            tomesd.apply_patch(
                sd_model,
                ratio=tome,
                use_rand=False,  # can cause issues with some samplers
                merge_attn=True,
                merge_crossattn=False,
                merge_mlp=False
            )
            log.info(f'Applying ToMe: ratio={tome}')
            sd_model.applied_tome = tome
        except Exception:
            log.warning(f'Token merging not supported: pipeline={sd_model.__class__.__name__}')
    else:
        sd_model.applied_tome = 0

    if method == 'ToDo' and todo > 0:
        if current_todo == todo:
            return  # this ratio is already recorded as applied
        if hypertile and not shared.cmd_opts.experimental:
            log.warning('Token merging not supported with HyperTile for UNet')
            return
        try:
            from modules.todo.todo_utils import patch_attention_proc
            token_merge_args = {
                "ratio": todo,
                "merge_tokens": "keys/values",
                "merge_method": "downsample",
                "downsample_method": "nearest",
                "downsample_factor": 2,
                "timestep_threshold_switch": 0.0,
                "timestep_threshold_stop": 0.0,
                "downsample_factor_level_2": 1,
                "ratio_level_2": 0.0,
            }
            patch_attention_proc(sd_model.unet, token_merge_args=token_merge_args)
            log.info(f'Applying ToDo: ratio={todo}')
            sd_model.applied_todo = todo
        except Exception:
            log.warning(f'Token merging not supported: pipeline={sd_model.__class__.__name__}')
    else:
        sd_model.applied_todo = 0
def remove_token_merging(sd_model):
    """Remove any ToMe / ToDo patch recorded on ``sd_model``.

    A patch is considered present when ``sd_model.applied_tome`` or
    ``sd_model.applied_todo`` is greater than 0 (a missing attribute counts
    as 0). After a successful removal the corresponding attribute is reset
    to 0.

    Removal is best-effort: a failure is logged as a warning and not raised,
    and the attribute is left unchanged so the model still reports the patch
    as applied.

    Args:
        sd_model: object previously passed to ``apply_token_merging``.
    """
    current_tome = getattr(sd_model, 'applied_tome', 0)
    current_todo = getattr(sd_model, 'applied_todo', 0)

    try:
        if current_tome > 0:
            import tomesd
            tomesd.remove_patch(sd_model)
            sd_model.applied_tome = 0
    except Exception as e:
        # Report instead of silently swallowing, but keep going so the ToDo
        # removal below is still attempted.
        log.warning(f'Token merging remove failed: method=ToMe pipeline={sd_model.__class__.__name__} {e}')

    try:
        if current_todo > 0:
            from modules.todo.todo_utils import remove_patch
            remove_patch(sd_model)
            sd_model.applied_todo = 0
    except Exception as e:
        log.warning(f'Token merging remove failed: method=ToDo pipeline={sd_model.__class__.__name__} {e}')