pull/28/head
dvruette 2023-08-25 12:19:34 +02:00
parent 8529e075c7
commit 015d810c23
1 changed files with 1 additions and 1 deletions

View File

@ -183,7 +183,7 @@ def patch_unet_forward_pass(p, unet, params):
ctx_uncond = torch.cat([context[uncond_ids], neg_hs], dim=1) # (n_uncond, seq * (1 + n_neg), dim)
ws = torch.ones_like(ctx_uncond[0, :, 0]) # (seq * (1 + n_neg),)
ws[x_uncond.size(1):] = neg_weight
out_uncond = weighted_attention(attn1, attn1._fabric_old_forward, x_uncond, ctx_uncond, **kwargs) # (n_uncond, seq, dim)
out_uncond = weighted_attention(attn1, attn1._fabric_old_forward, x_uncond, ctx_uncond, ws, **kwargs) # (n_uncond, seq, dim)
outs.append(out_uncond)
out = torch.cat(outs, dim=0)
return out