From 5201d457e3bd3d31a57bfaebe2e4abb898618034 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 15 Dec 2022 13:45:10 -0600 Subject: [PATCH] Updates --- .gitignore | 2 +- configs/v1-inference.yaml | 10 +++--- configs/v2-inference-v.yaml | 10 +++--- configs/v2-inference.yaml | 10 +++--- scripts/patch_fixer.py | 72 +++++++++++++++++++++++++++++++++---- 5 files changed, 82 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index d9005f2..b441b2f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ __pycache__/ *.py[cod] *$py.class - +.idea/* # C extensions *.so diff --git a/configs/v1-inference.yaml b/configs/v1-inference.yaml index e04b839..15de0d4 100644 --- a/configs/v1-inference.yaml +++ b/configs/v1-inference.yaml @@ -56,12 +56,12 @@ model: out_ch: 3 ch: 128 ch_mult: - - 1 - - 2 - - 4 - - 4 + - 1 + - 2 + - 4 + - 4 num_res_blocks: 2 - attn_resolutions: [] + attn_resolutions: [ ] dropout: 0.0 lossconfig: target: torch.nn.Identity diff --git a/configs/v2-inference-v.yaml b/configs/v2-inference-v.yaml index 513cd63..bd0f898 100644 --- a/configs/v2-inference-v.yaml +++ b/configs/v2-inference-v.yaml @@ -51,12 +51,12 @@ model: out_ch: 3 ch: 128 ch_mult: - - 1 - - 2 - - 4 - - 4 + - 1 + - 2 + - 4 + - 4 num_res_blocks: 2 - attn_resolutions: [] + attn_resolutions: [ ] dropout: 0.0 lossconfig: target: torch.nn.Identity diff --git a/configs/v2-inference.yaml b/configs/v2-inference.yaml index 0eb2539..4b07468 100644 --- a/configs/v2-inference.yaml +++ b/configs/v2-inference.yaml @@ -50,12 +50,12 @@ model: out_ch: 3 ch: 128 ch_mult: - - 1 - - 2 - - 4 - - 4 + - 1 + - 2 + - 4 + - 4 num_res_blocks: 2 - attn_resolutions: [] + attn_resolutions: [ ] dropout: 0.0 lossconfig: target: torch.nn.Identity diff --git a/scripts/patch_fixer.py b/scripts/patch_fixer.py index 8714704..c35d728 100644 --- a/scripts/patch_fixer.py +++ b/scripts/patch_fixer.py @@ -4,9 +4,12 @@ import os.path import traceback from collections import namedtuple +import ldm.modules.attention import torch +from einops import rearrange, repeat from ldm.util import instantiate_from_config from omegaconf import OmegaConf +from torch import einsum from modules import shared, devices, script_callbacks, sd_models from modules.paths import models_path @@ -21,11 +24,56 @@ CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'mod checkpoints_list = {} checkpoints_loaded = collections.OrderedDict() +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def fixed_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION == "fp32": + with torch.autocast(enabled=False, device_type='cuda'): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + def get_config(checkpoint_info): path = checkpoint_info[0] model_config = checkpoint_info.config checkpoint_dir = os.path.join(shared.script_path, "extensions", "sd_auto_fix", "configs") + is_v21 = False if model_config == shared.cmd_opts.config: try: checkpoint = torch.load(path) @@ -38,18 +86,24 @@ def get_config(checkpoint_info): else: model_config = os.path.join(checkpoint_dir, "v2-inference-v.yaml") print(f"V2 Model detected, selecting model config: {model_config}") + v21_keys = ["callbacks", "lr_schedulers", "native_amp_scaling_state"] + for v22key in v21_keys: + if v22key in checkpoint: + is_v21 = True + break + del checkpoint except Exception as e: print(f"Exception: {e}") traceback.print_exc() pass - return model_config + return model_config, is_v21 def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() - model_config = get_config(checkpoint_info) + model_config, is_v21 = get_config(checkpoint_info) if model_config != shared.cmd_opts.config: print(f"Loading config from: {model_config}") @@ -74,9 +128,6 @@ def load_model(checkpoint_info=None): do_inpainting_hijack() - if shared.cmd_opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) @@ -87,6 +138,10 @@ def load_model(checkpoint_info=None): sd_hijack.model_hijack.hijack(sd_model) + if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half: + print("Fixing attention for v21 model.") + ldm.modules.attention.CrossAttention.forward = fixed_forward + sd_model.eval() shared.sd_model = sd_model @@ -106,7 +161,7 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - model_config = get_config(checkpoint_info) + model_config, is_v21 = get_config(checkpoint_info) checkpoint_info = checkpoint_info._replace(config=model_config) if sd_model.sd_checkpoint_info.config != model_config or should_hijack_inpainting( checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): @@ -125,6 +180,11 @@ def reload_model_weights(sd_model=None, info=None): load_model_weights(sd_model, checkpoint_info) sd_hijack.model_hijack.hijack(sd_model) + + if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half: + print("Fixing attention for v21 model.") + ldm.modules.attention.CrossAttention.forward = fixed_forward + script_callbacks.model_loaded_callback(sd_model) if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: