Updates
parent
56f05dd638
commit
5201d457e3
|
|
@ -2,7 +2,7 @@
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
*$py.class
|
*$py.class
|
||||||
|
.idea/*
|
||||||
# C extensions
|
# C extensions
|
||||||
*.so
|
*.so
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,12 +56,12 @@ model:
|
||||||
out_ch: 3
|
out_ch: 3
|
||||||
ch: 128
|
ch: 128
|
||||||
ch_mult:
|
ch_mult:
|
||||||
- 1
|
- 1
|
||||||
- 2
|
- 2
|
||||||
- 4
|
- 4
|
||||||
- 4
|
- 4
|
||||||
num_res_blocks: 2
|
num_res_blocks: 2
|
||||||
attn_resolutions: []
|
attn_resolutions: [ ]
|
||||||
dropout: 0.0
|
dropout: 0.0
|
||||||
lossconfig:
|
lossconfig:
|
||||||
target: torch.nn.Identity
|
target: torch.nn.Identity
|
||||||
|
|
|
||||||
|
|
@ -51,12 +51,12 @@ model:
|
||||||
out_ch: 3
|
out_ch: 3
|
||||||
ch: 128
|
ch: 128
|
||||||
ch_mult:
|
ch_mult:
|
||||||
- 1
|
- 1
|
||||||
- 2
|
- 2
|
||||||
- 4
|
- 4
|
||||||
- 4
|
- 4
|
||||||
num_res_blocks: 2
|
num_res_blocks: 2
|
||||||
attn_resolutions: []
|
attn_resolutions: [ ]
|
||||||
dropout: 0.0
|
dropout: 0.0
|
||||||
lossconfig:
|
lossconfig:
|
||||||
target: torch.nn.Identity
|
target: torch.nn.Identity
|
||||||
|
|
|
||||||
|
|
@ -50,12 +50,12 @@ model:
|
||||||
out_ch: 3
|
out_ch: 3
|
||||||
ch: 128
|
ch: 128
|
||||||
ch_mult:
|
ch_mult:
|
||||||
- 1
|
- 1
|
||||||
- 2
|
- 2
|
||||||
- 4
|
- 4
|
||||||
- 4
|
- 4
|
||||||
num_res_blocks: 2
|
num_res_blocks: 2
|
||||||
attn_resolutions: []
|
attn_resolutions: [ ]
|
||||||
dropout: 0.0
|
dropout: 0.0
|
||||||
lossconfig:
|
lossconfig:
|
||||||
target: torch.nn.Identity
|
target: torch.nn.Identity
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,12 @@ import os.path
|
||||||
import traceback
|
import traceback
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
import ldm.modules.attention
|
||||||
import torch
|
import torch
|
||||||
|
from einops import rearrange, repeat
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
from torch import einsum
|
||||||
|
|
||||||
from modules import shared, devices, script_callbacks, sd_models
|
from modules import shared, devices, script_callbacks, sd_models
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
|
|
@ -21,11 +24,56 @@ CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'mod
|
||||||
checkpoints_list = {}
|
checkpoints_list = {}
|
||||||
checkpoints_loaded = collections.OrderedDict()
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
|
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
||||||
|
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
|
||||||
|
def fixed_forward(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
|
q = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
|
|
||||||
|
# force cast to fp32 to avoid overflowing
|
||||||
|
if _ATTN_PRECISION == "fp32":
|
||||||
|
with torch.autocast(enabled=False, device_type='cuda'):
|
||||||
|
q, k = q.float(), k.float()
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
else:
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
|
||||||
|
del q, k
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
mask = rearrange(mask, 'b ... -> b (...)')
|
||||||
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
|
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||||
|
sim.masked_fill_(~mask, max_neg_value)
|
||||||
|
|
||||||
|
# attention, what we cannot get enough of
|
||||||
|
sim = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
out = einsum('b i j, b j d -> b i d', sim, v)
|
||||||
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
def get_config(checkpoint_info):
|
def get_config(checkpoint_info):
|
||||||
path = checkpoint_info[0]
|
path = checkpoint_info[0]
|
||||||
model_config = checkpoint_info.config
|
model_config = checkpoint_info.config
|
||||||
checkpoint_dir = os.path.join(shared.script_path, "extensions", "sd_auto_fix", "configs")
|
checkpoint_dir = os.path.join(shared.script_path, "extensions", "sd_auto_fix", "configs")
|
||||||
|
is_v21 = False
|
||||||
if model_config == shared.cmd_opts.config:
|
if model_config == shared.cmd_opts.config:
|
||||||
try:
|
try:
|
||||||
checkpoint = torch.load(path)
|
checkpoint = torch.load(path)
|
||||||
|
|
@ -38,18 +86,24 @@ def get_config(checkpoint_info):
|
||||||
else:
|
else:
|
||||||
model_config = os.path.join(checkpoint_dir, "v2-inference-v.yaml")
|
model_config = os.path.join(checkpoint_dir, "v2-inference-v.yaml")
|
||||||
print(f"V2 Model detected, selecting model config: {model_config}")
|
print(f"V2 Model detected, selecting model config: {model_config}")
|
||||||
|
v21_keys = ["callbacks", "lr_schedulers", "native_amp_scaling_state"]
|
||||||
|
for v22key in v21_keys:
|
||||||
|
if v22key in checkpoint:
|
||||||
|
is_v21 = True
|
||||||
|
break
|
||||||
|
|
||||||
del checkpoint
|
del checkpoint
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Exception: {e}")
|
print(f"Exception: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
pass
|
pass
|
||||||
return model_config
|
return model_config, is_v21
|
||||||
|
|
||||||
|
|
||||||
def load_model(checkpoint_info=None):
|
def load_model(checkpoint_info=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
model_config = get_config(checkpoint_info)
|
model_config, is_v21 = get_config(checkpoint_info)
|
||||||
|
|
||||||
if model_config != shared.cmd_opts.config:
|
if model_config != shared.cmd_opts.config:
|
||||||
print(f"Loading config from: {model_config}")
|
print(f"Loading config from: {model_config}")
|
||||||
|
|
@ -74,9 +128,6 @@ def load_model(checkpoint_info=None):
|
||||||
|
|
||||||
do_inpainting_hijack()
|
do_inpainting_hijack()
|
||||||
|
|
||||||
if shared.cmd_opts.no_half:
|
|
||||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
|
||||||
|
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
load_model_weights(sd_model, checkpoint_info)
|
load_model_weights(sd_model, checkpoint_info)
|
||||||
|
|
||||||
|
|
@ -87,6 +138,10 @@ def load_model(checkpoint_info=None):
|
||||||
|
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
|
|
||||||
|
if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half:
|
||||||
|
print("Fixing attention for v21 model.")
|
||||||
|
ldm.modules.attention.CrossAttention.forward = fixed_forward
|
||||||
|
|
||||||
sd_model.eval()
|
sd_model.eval()
|
||||||
shared.sd_model = sd_model
|
shared.sd_model = sd_model
|
||||||
|
|
||||||
|
|
@ -106,7 +161,7 @@ def reload_model_weights(sd_model=None, info=None):
|
||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return
|
||||||
|
|
||||||
model_config = get_config(checkpoint_info)
|
model_config, is_v21 = get_config(checkpoint_info)
|
||||||
checkpoint_info = checkpoint_info._replace(config=model_config)
|
checkpoint_info = checkpoint_info._replace(config=model_config)
|
||||||
if sd_model.sd_checkpoint_info.config != model_config or should_hijack_inpainting(
|
if sd_model.sd_checkpoint_info.config != model_config or should_hijack_inpainting(
|
||||||
checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||||
|
|
@ -125,6 +180,11 @@ def reload_model_weights(sd_model=None, info=None):
|
||||||
load_model_weights(sd_model, checkpoint_info)
|
load_model_weights(sd_model, checkpoint_info)
|
||||||
|
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
|
|
||||||
|
if is_v21 and not shared.cmd_opts.xformers and not shared.cmd_opts.force_enable_xformers and not shared.cmd_opts.no_half:
|
||||||
|
print("Fixing attention for v21 model.")
|
||||||
|
ldm.modules.attention.CrossAttention.forward = fixed_forward
|
||||||
|
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
|
||||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue