fix #26
parent
8529e075c7
commit
015d810c23
|
|
@ -183,7 +183,7 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
ctx_uncond = torch.cat([context[uncond_ids], neg_hs], dim=1) # (n_uncond, seq * (1 + n_neg), dim)
|
||||
ws = torch.ones_like(ctx_uncond[0, :, 0]) # (seq * (1 + n_neg),)
|
||||
ws[x_uncond.size(1):] = neg_weight
|
||||
out_uncond = weighted_attention(attn1, attn1._fabric_old_forward, x_uncond, ctx_uncond, **kwargs) # (n_uncond, seq, dim)
|
||||
out_uncond = weighted_attention(attn1, attn1._fabric_old_forward, x_uncond, ctx_uncond, ws, **kwargs) # (n_uncond, seq, dim)
|
||||
outs.append(out_uncond)
|
||||
out = torch.cat(outs, dim=0)
|
||||
return out
|
||||
|
|
|
|||
Loading…
Reference in New Issue