v0.6.4 (fix SDXL compatibility)

pull/38/head
dvruette 2024-03-07 12:48:55 +01:00
parent 5a247c9d9e
commit d6b295e978
4 changed files with 87 additions and 27 deletions

View File

@ -27,7 +27,7 @@ except ImportError:
from modules.ui import create_refresh_button
__version__ = "0.6.3"
__version__ = "0.6.4"
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")

View File

@ -48,12 +48,20 @@ def mark_prompt_context(x, positive):
x.schedules = mark_prompt_context(x.schedules, positive)
return x
if isinstance(x, ScheduledPromptConditioning):
cond = x.cond
if prompt_context_is_marked(cond):
return x
mark = POSITIVE_MARK_TOKEN if positive else NEGATIVE_MARK_TOKEN
cond = torch.cat([torch.zeros_like(cond)[:1] + mark, cond], dim=0)
return ScheduledPromptConditioning(end_at_step=x.end_at_step, cond=cond)
if isinstance(x.cond, dict):
cond = x.cond['crossattn']
if prompt_context_is_marked(cond):
return x
mark = POSITIVE_MARK_TOKEN if positive else NEGATIVE_MARK_TOKEN
cond = torch.cat([torch.zeros_like(cond)[:1] + mark, cond], dim=0)
return ScheduledPromptConditioning(end_at_step=x.end_at_step, cond=dict(crossattn=cond, vector=x.cond['vector']))
else:
cond = x.cond
if prompt_context_is_marked(cond):
return x
mark = POSITIVE_MARK_TOKEN if positive else NEGATIVE_MARK_TOKEN
cond = torch.cat([torch.zeros_like(cond)[:1] + mark, cond], dim=0)
return ScheduledPromptConditioning(end_at_step=x.end_at_step, cond=cond)
return x
@ -67,21 +75,22 @@ def unmark_prompt_context(x):
# After you mark the prompts, the ControlNet will know which prompt is cond/uncond and works as expected.
# After you mark the prompts, the mismatch errors will disappear.
mark_batch = torch.ones(size=(x.shape[0], 1, 1, 1), dtype=x.dtype, device=x.device)
uc_indices = []
context = x
return mark_batch, uc_indices, context
return mark_batch, [], [], context
mark = x[:, 0, :]
context = x[:, 1:, :]
mark = torch.mean(torch.abs(mark - NEGATIVE_MARK_TOKEN), dim=1)
mark = (mark > MARK_EPS).float()
mark_batch = mark[:, None, None, None].to(x.dtype).to(x.device)
uc_indices = mark.detach().cpu().numpy().tolist()
uc_indices = [i for i, item in enumerate(uc_indices) if item < 0.5]
mark = mark.detach().cpu().numpy().tolist()
uc_indices = [i for i, item in enumerate(mark) if item < 0.5]
c_indices = [i for i, item in enumerate(mark) if not item < 0.5]
StableDiffusionProcessing.cached_c = [None, None]
StableDiffusionProcessing.cached_uc = [None, None]
return mark_batch, uc_indices, context
return mark_batch, uc_indices, c_indices, context
def apply_marking_patch(process):

View File

@ -6,7 +6,11 @@ import torchvision.transforms.functional as functional
from modules import devices, images, shared
from modules.processing import StableDiffusionProcessingTxt2Img
from ldm.modules.attention import BasicTransformerBlock
import ldm.modules.attention
import sgm.modules.attention
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.ddpm import extract_into_tensor
from sgm.models.diffusion import DiffusionEngine
from scripts.marking import apply_marking_patch, unmark_prompt_context
from scripts.fabric_utils import image_hash
@ -14,6 +18,10 @@ from scripts.weighted_attention import weighted_attention
from scripts.merging import compute_merge
SD15 = "sd15"
SDXL = "sdxl"
def encode_to_latent(p, image, w, h):
image = images.resize_image(1, image, w, h)
x = functional.pil_to_tensor(image)
@ -24,9 +32,24 @@ def encode_to_latent(p, image, w, h):
# TODO: use caching to make this faster
with devices.autocast():
vae_output = p.sd_model.encode_first_stage(x)
if torch.isnan(vae_output).any():
print(f"[FABRIC] NaNs in VAE output found, retrying with 32-bit precision. To always start with 32-bit VAE, use --no-half-vae commandline flag.")
devices.dtype_vae = torch.float32
x = x.to(devices.dtype_vae)
p.sd_model.first_stage_model.to(devices.dtype_vae)
vae_output = p.sd_model.encode_first_stage(x)
z = p.sd_model.get_first_stage_encoding(vae_output)
z = z.to(devices.dtype_unet)
return z.squeeze(0)
def forward_noise(p, x_0, t, noise=None):
    """Diffuse clean latents ``x_0`` forward to timestep ``t``.

    Implements the closed-form DDPM forward process
    ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps``
    using the model's cumulative alphas schedule.

    Args:
        p: processing object exposing ``p.sd_model.alphas_cumprod``
           (the per-timestep cumulative product of alphas).
        x_0: clean latent tensor to be noised.
        t: integer timestep index tensor, broadcast over the batch.
        noise: optional pre-drawn Gaussian noise; freshly sampled with
            ``torch.randn_like(x_0)`` when omitted.

    Returns:
        The noised latent ``x_t`` with the same shape as ``x_0``.
    """
    eps = torch.randn_like(x_0) if noise is None else noise
    alpha_bar = p.sd_model.alphas_cumprod
    # Per-sample scalar coefficients expanded to x_0's shape.
    signal_scale = extract_into_tensor(alpha_bar.sqrt(), t, x_0.shape)
    noise_scale = extract_into_tensor((1.0 - alpha_bar).sqrt(), t, x_0.shape)
    return signal_scale * x_0 + noise_scale * eps
def get_latents_from_params(p, params, width, height):
w, h = (width // 8) * 8, (height // 8) * 8
@ -53,8 +76,8 @@ def get_latents_from_params(p, params, width, height):
return params.pos_latents, params.neg_latents
def get_curr_feedback_weight(p, params, timestep):
progress = 1 - (timestep / (p.sd_model.num_timesteps - 1))
def get_curr_feedback_weight(p, params, timestep, num_timesteps=1000):
progress = 1 - (timestep / (num_timesteps - 1))
if progress >= params.start and progress <= params.end:
w = params.max_weight
else:
@ -70,9 +93,28 @@ def patch_unet_forward_pass(p, unet, params):
if not hasattr(unet, "_fabric_old_forward"):
unet._fabric_old_forward = unet.forward
if isinstance(p.sd_model, LatentDiffusion):
sd_version = SD15
num_timesteps = p.sd_model.num_timesteps
BasicTransformerBlock = ldm.modules.attention.BasicTransformerBlock
elif isinstance(p.sd_model, DiffusionEngine):
sd_version = SDXL
num_timesteps = len(p.sd_model.alphas_cumprod)
BasicTransformerBlock = sgm.modules.attention.BasicTransformerBlock
else:
raise ValueError(f"[FABRIC] Unsupported SD model: {type(p.sd_model)}")
batch_size = p.batch_size
null_ctx = p.sd_model.get_learned_conditioning([""]).to(devices.device, dtype=devices.dtype_unet)
null_ctx = p.sd_model.get_learned_conditioning([""])
if isinstance(null_ctx, torch.Tensor): # SD1.5
null_ctx = null_ctx.to(devices.device, dtype=devices.dtype_unet)
elif isinstance(null_ctx, dict): # SDXL
for key in null_ctx:
if isinstance(null_ctx[key], torch.Tensor):
null_ctx[key] = null_ctx[key].to(devices.device, dtype=devices.dtype_unet)
else:
raise ValueError(f"[FABRIC] Unsupported context type: {type(null_ctx)}")
width = (p.width // 8) * 8
height = (p.height // 8) * 8
@ -99,8 +141,7 @@ def patch_unet_forward_pass(p, unet, params):
}
def new_forward(self, x, timesteps=None, context=None, **kwargs):
_, uncond_ids, context = unmark_prompt_context(context)
cond_ids = [i for i in range(context.size(0)) if i not in uncond_ids]
_, uncond_ids, cond_ids, context = unmark_prompt_context(context)
has_cond = len(cond_ids) > 0
has_uncond = len(uncond_ids) > 0
@ -111,7 +152,7 @@ def patch_unet_forward_pass(p, unet, params):
print("[FABRIC] Skipping feedback during high-res fix")
return self._fabric_old_forward(x, timesteps, context, **kwargs)
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item())
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item(), num_timesteps=num_timesteps)
if pos_weight <= 0 and neg_weight <= 0:
return self._fabric_old_forward(x, timesteps, context, **kwargs)
@ -132,11 +173,16 @@ def patch_unet_forward_pass(p, unet, params):
return self._fabric_old_forward(x, timesteps, context, **kwargs)
# add noise to reference latents
all_zs = []
for latent in all_latents:
z = p.sd_model.q_sample(latent.unsqueeze(0), torch.round(timesteps.float()).long())[0]
all_zs.append(z)
all_zs = torch.stack(all_zs, dim=0)
if sd_version == SD15:
all_zs = []
for latent in all_latents:
z = p.sd_model.q_sample(latent.unsqueeze(0), timesteps[0].unsqueeze(0))[0]
all_zs.append(z)
all_zs = torch.stack(all_zs, dim=0)
else: # SDXL
xs_0 = torch.stack(all_latents, dim=0)
ts = timesteps[0, None].expand(xs_0.size(0)) # (bs,)
all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long())
# save original forward pass
for module in self.modules():
@ -178,8 +224,13 @@ def patch_unet_forward_pass(p, unet, params):
zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
ts = timesteps[:1].expand(zs.size(0)) # (bs,)
# use the null prompt for pre-computing hidden states on feedback images
ctx = null_ctx.expand(zs.size(0), -1, -1) # (bs, p_seq, p_dim)
_ = self._fabric_old_forward(zs, ts, ctx)
ctx_args = {}
if sd_version == SD15:
ctx_args["context"] = null_ctx.expand(zs.size(0), -1, -1) # (bs, seq_len, d_model)
else: # SDXL
ctx_args["context"] = null_ctx["crossattn"].expand(zs.size(0), -1, -1) # (bs, seq_len, d_model)
ctx_args["y"] = null_ctx["vector"].expand(zs.size(0), -1) # (bs, d_vector)
_ = self._fabric_old_forward(zs, ts, **ctx_args)
num_pos = len(pos_latents)
num_neg = len(neg_latents)

View File

@ -59,7 +59,7 @@ def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs):
return attn_fn(x, context=context, **kwargs)
weighted_attn_fn = get_weighted_attn_fn()
return weighted_attn_fn(self, x, context=context, weights=weights, **kwargs)
return weighted_attn_fn(self, x, context=context, weights=weights, mask=kwargs.get('mask', None))
def _get_attn_bias(weights, shape=None, dtype=torch.float32):