v0.6.4 (fix SDXL compatibility)
parent
5a247c9d9e
commit
d6b295e978
|
|
@ -27,7 +27,7 @@ except ImportError:
|
|||
from modules.ui import create_refresh_button
|
||||
|
||||
|
||||
__version__ = "0.6.3"
|
||||
__version__ = "0.6.4"
|
||||
|
||||
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
|
||||
|
||||
|
|
|
|||
|
|
@ -48,12 +48,20 @@ def mark_prompt_context(x, positive):
|
|||
x.schedules = mark_prompt_context(x.schedules, positive)
|
||||
return x
|
||||
if isinstance(x, ScheduledPromptConditioning):
|
||||
cond = x.cond
|
||||
if prompt_context_is_marked(cond):
|
||||
return x
|
||||
mark = POSITIVE_MARK_TOKEN if positive else NEGATIVE_MARK_TOKEN
|
||||
cond = torch.cat([torch.zeros_like(cond)[:1] + mark, cond], dim=0)
|
||||
return ScheduledPromptConditioning(end_at_step=x.end_at_step, cond=cond)
|
||||
if isinstance(x.cond, dict):
|
||||
cond = x.cond['crossattn']
|
||||
if prompt_context_is_marked(cond):
|
||||
return x
|
||||
mark = POSITIVE_MARK_TOKEN if positive else NEGATIVE_MARK_TOKEN
|
||||
cond = torch.cat([torch.zeros_like(cond)[:1] + mark, cond], dim=0)
|
||||
return ScheduledPromptConditioning(end_at_step=x.end_at_step, cond=dict(crossattn=cond, vector=x.cond['vector']))
|
||||
else:
|
||||
cond = x.cond
|
||||
if prompt_context_is_marked(cond):
|
||||
return x
|
||||
mark = POSITIVE_MARK_TOKEN if positive else NEGATIVE_MARK_TOKEN
|
||||
cond = torch.cat([torch.zeros_like(cond)[:1] + mark, cond], dim=0)
|
||||
return ScheduledPromptConditioning(end_at_step=x.end_at_step, cond=cond)
|
||||
return x
|
||||
|
||||
|
||||
|
|
@ -67,21 +75,22 @@ def unmark_prompt_context(x):
|
|||
# After you mark the prompts, the ControlNet will know which prompt is cond/uncond and works as expected.
|
||||
# After you mark the prompts, the mismatch errors will disappear.
|
||||
mark_batch = torch.ones(size=(x.shape[0], 1, 1, 1), dtype=x.dtype, device=x.device)
|
||||
uc_indices = []
|
||||
context = x
|
||||
return mark_batch, uc_indices, context
|
||||
return mark_batch, [], [], context
|
||||
mark = x[:, 0, :]
|
||||
context = x[:, 1:, :]
|
||||
mark = torch.mean(torch.abs(mark - NEGATIVE_MARK_TOKEN), dim=1)
|
||||
mark = (mark > MARK_EPS).float()
|
||||
mark_batch = mark[:, None, None, None].to(x.dtype).to(x.device)
|
||||
uc_indices = mark.detach().cpu().numpy().tolist()
|
||||
uc_indices = [i for i, item in enumerate(uc_indices) if item < 0.5]
|
||||
|
||||
mark = mark.detach().cpu().numpy().tolist()
|
||||
uc_indices = [i for i, item in enumerate(mark) if item < 0.5]
|
||||
c_indices = [i for i, item in enumerate(mark) if not item < 0.5]
|
||||
|
||||
StableDiffusionProcessing.cached_c = [None, None]
|
||||
StableDiffusionProcessing.cached_uc = [None, None]
|
||||
|
||||
return mark_batch, uc_indices, context
|
||||
return mark_batch, uc_indices, c_indices, context
|
||||
|
||||
|
||||
def apply_marking_patch(process):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,11 @@ import torchvision.transforms.functional as functional
|
|||
from modules import devices, images, shared
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img
|
||||
|
||||
from ldm.modules.attention import BasicTransformerBlock
|
||||
import ldm.modules.attention
|
||||
import sgm.modules.attention
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ldm.models.diffusion.ddpm import extract_into_tensor
|
||||
from sgm.models.diffusion import DiffusionEngine
|
||||
|
||||
from scripts.marking import apply_marking_patch, unmark_prompt_context
|
||||
from scripts.fabric_utils import image_hash
|
||||
|
|
@ -14,6 +18,10 @@ from scripts.weighted_attention import weighted_attention
|
|||
from scripts.merging import compute_merge
|
||||
|
||||
|
||||
SD15 = "sd15"
|
||||
SDXL = "sdxl"
|
||||
|
||||
|
||||
def encode_to_latent(p, image, w, h):
|
||||
image = images.resize_image(1, image, w, h)
|
||||
x = functional.pil_to_tensor(image)
|
||||
|
|
@ -24,9 +32,24 @@ def encode_to_latent(p, image, w, h):
|
|||
# TODO: use caching to make this faster
|
||||
with devices.autocast():
|
||||
vae_output = p.sd_model.encode_first_stage(x)
|
||||
if torch.isnan(vae_output).any():
|
||||
print(f"[FABRIC] NaNs in VAE output found, retrying with 32-bit precision. To always start with 32-bit VAE, use --no-half-vae commandline flag.")
|
||||
devices.dtype_vae = torch.float32
|
||||
x = x.to(devices.dtype_vae)
|
||||
p.sd_model.first_stage_model.to(devices.dtype_vae)
|
||||
vae_output = p.sd_model.encode_first_stage(x)
|
||||
z = p.sd_model.get_first_stage_encoding(vae_output)
|
||||
z = z.to(devices.dtype_unet)
|
||||
return z.squeeze(0)
|
||||
|
||||
def forward_noise(p, x_0, t, noise=None):
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x_0)
|
||||
sqrt_alpha_bar_t = extract_into_tensor(p.sd_model.alphas_cumprod.sqrt(), t, x_0.shape)
|
||||
sqrt_one_minus_alpha_bar_t = extract_into_tensor((1.0 - p.sd_model.alphas_cumprod).sqrt(), t, x_0.shape)
|
||||
x_t = sqrt_alpha_bar_t * x_0 + sqrt_one_minus_alpha_bar_t * noise
|
||||
return x_t
|
||||
|
||||
|
||||
def get_latents_from_params(p, params, width, height):
|
||||
w, h = (width // 8) * 8, (height // 8) * 8
|
||||
|
|
@ -53,8 +76,8 @@ def get_latents_from_params(p, params, width, height):
|
|||
return params.pos_latents, params.neg_latents
|
||||
|
||||
|
||||
def get_curr_feedback_weight(p, params, timestep):
|
||||
progress = 1 - (timestep / (p.sd_model.num_timesteps - 1))
|
||||
def get_curr_feedback_weight(p, params, timestep, num_timesteps=1000):
|
||||
progress = 1 - (timestep / (num_timesteps - 1))
|
||||
if progress >= params.start and progress <= params.end:
|
||||
w = params.max_weight
|
||||
else:
|
||||
|
|
@ -70,9 +93,28 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
if not hasattr(unet, "_fabric_old_forward"):
|
||||
unet._fabric_old_forward = unet.forward
|
||||
|
||||
if isinstance(p.sd_model, LatentDiffusion):
|
||||
sd_version = SD15
|
||||
num_timesteps = p.sd_model.num_timesteps
|
||||
BasicTransformerBlock = ldm.modules.attention.BasicTransformerBlock
|
||||
elif isinstance(p.sd_model, DiffusionEngine):
|
||||
sd_version = SDXL
|
||||
num_timesteps = len(p.sd_model.alphas_cumprod)
|
||||
BasicTransformerBlock = sgm.modules.attention.BasicTransformerBlock
|
||||
else:
|
||||
raise ValueError(f"[FABRIC] Unsupported SD model: {type(p.sd_model)}")
|
||||
|
||||
batch_size = p.batch_size
|
||||
|
||||
null_ctx = p.sd_model.get_learned_conditioning([""]).to(devices.device, dtype=devices.dtype_unet)
|
||||
null_ctx = p.sd_model.get_learned_conditioning([""])
|
||||
if isinstance(null_ctx, torch.Tensor): # SD1.5
|
||||
null_ctx = null_ctx.to(devices.device, dtype=devices.dtype_unet)
|
||||
elif isinstance(null_ctx, dict): # SDXL
|
||||
for key in null_ctx:
|
||||
if isinstance(null_ctx[key], torch.Tensor):
|
||||
null_ctx[key] = null_ctx[key].to(devices.device, dtype=devices.dtype_unet)
|
||||
else:
|
||||
raise ValueError(f"[FABRIC] Unsupported context type: {type(null_ctx)}")
|
||||
|
||||
width = (p.width // 8) * 8
|
||||
height = (p.height // 8) * 8
|
||||
|
|
@ -99,8 +141,7 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
}
|
||||
|
||||
def new_forward(self, x, timesteps=None, context=None, **kwargs):
|
||||
_, uncond_ids, context = unmark_prompt_context(context)
|
||||
cond_ids = [i for i in range(context.size(0)) if i not in uncond_ids]
|
||||
_, uncond_ids, cond_ids, context = unmark_prompt_context(context)
|
||||
has_cond = len(cond_ids) > 0
|
||||
has_uncond = len(uncond_ids) > 0
|
||||
|
||||
|
|
@ -111,7 +152,7 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
print("[FABRIC] Skipping feedback during high-res fix")
|
||||
return self._fabric_old_forward(x, timesteps, context, **kwargs)
|
||||
|
||||
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item())
|
||||
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item(), num_timesteps=num_timesteps)
|
||||
if pos_weight <= 0 and neg_weight <= 0:
|
||||
return self._fabric_old_forward(x, timesteps, context, **kwargs)
|
||||
|
||||
|
|
@ -132,11 +173,16 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
return self._fabric_old_forward(x, timesteps, context, **kwargs)
|
||||
|
||||
# add noise to reference latents
|
||||
all_zs = []
|
||||
for latent in all_latents:
|
||||
z = p.sd_model.q_sample(latent.unsqueeze(0), torch.round(timesteps.float()).long())[0]
|
||||
all_zs.append(z)
|
||||
all_zs = torch.stack(all_zs, dim=0)
|
||||
if sd_version == SD15:
|
||||
all_zs = []
|
||||
for latent in all_latents:
|
||||
z = p.sd_model.q_sample(latent.unsqueeze(0), timesteps[0].unsqueeze(0))[0]
|
||||
all_zs.append(z)
|
||||
all_zs = torch.stack(all_zs, dim=0)
|
||||
else: # SDXL
|
||||
xs_0 = torch.stack(all_latents, dim=0)
|
||||
ts = timesteps[0, None].expand(xs_0.size(0)) # (bs,)
|
||||
all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long())
|
||||
|
||||
# save original forward pass
|
||||
for module in self.modules():
|
||||
|
|
@ -178,8 +224,13 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
|
||||
ts = timesteps[:1].expand(zs.size(0)) # (bs,)
|
||||
# use the null prompt for pre-computing hidden states on feedback images
|
||||
ctx = null_ctx.expand(zs.size(0), -1, -1) # (bs, p_seq, p_dim)
|
||||
_ = self._fabric_old_forward(zs, ts, ctx)
|
||||
ctx_args = {}
|
||||
if sd_version == SD15:
|
||||
ctx_args["context"] = null_ctx.expand(zs.size(0), -1, -1) # (bs, seq_len, d_model)
|
||||
else: # SDXL
|
||||
ctx_args["context"] = null_ctx["crossattn"].expand(zs.size(0), -1, -1) # (bs, seq_len, d_model)
|
||||
ctx_args["y"] = null_ctx["vector"].expand(zs.size(0), -1) # (bs, d_vector)
|
||||
_ = self._fabric_old_forward(zs, ts, **ctx_args)
|
||||
|
||||
num_pos = len(pos_latents)
|
||||
num_neg = len(neg_latents)
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs):
|
|||
return attn_fn(x, context=context, **kwargs)
|
||||
|
||||
weighted_attn_fn = get_weighted_attn_fn()
|
||||
return weighted_attn_fn(self, x, context=context, weights=weights, **kwargs)
|
||||
return weighted_attn_fn(self, x, context=context, weights=weights, mask=kwargs.get('mask', None))
|
||||
|
||||
|
||||
def _get_attn_bias(weights, shape=None, dtype=torch.float32):
|
||||
|
|
|
|||
Loading…
Reference in New Issue